Skip to content
Snippets Groups Projects
Commit ba07de11 authored by Nurul Fithria Lubis
Browse files

fix conflicts

parents 42484e8c 4b3d84a0
Branches
No related tags found
No related merge requests found
......@@ -64,7 +64,12 @@ class PipelineAgent(Agent):
===== ===== ====== === == ===
"""
<<<<<<< HEAD
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str):
=======
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str, return_semantic_acts: bool = False,
word_level_policy_nlu: NLU = None):
>>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805
"""The constructor of PipelineAgent class.
Here are some special combination cases:
......@@ -95,6 +100,7 @@ class PipelineAgent(Agent):
self.dst = dst
self.policy = policy
self.nlg = nlg
self.init_session()
self.agent_saves = []
self.history = []
......@@ -180,14 +186,27 @@ class PipelineAgent(Agent):
model_response = self.output_action
# print(model_response)
if self.policy_nlu and type(self.output_action) != list:
self.semantic_output_action = self.policy_nlu.predict(self.output_action,
context=[x[1] for x in self.history])
else:
self.semantic_output_action = None
if self.dst is not None:
self.dst.state['history'].append([self.name, model_response])
if self.name == 'sys':
self.dst.state['system_action'] = self.output_action
self.dst.state['system_action'] = self.semantic_output_action if self.semantic_output_action else self.output_action
if type(self.output_action) == list:
for intent, domain, slot, value in self.output_action:
if intent.lower() == "book":
<<<<<<< HEAD
=======
self.dst.state['booked'][domain] = [{slot: value}]
elif self.semantic_output_action:
for intent, domain, slot, value in self.semantic_output_action:
if intent.lower() == "book":
>>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805
self.dst.state['booked'][domain] = [{slot: value}]
else:
self.dst.state['user_action'] = self.output_action
......@@ -251,6 +270,9 @@ class PipelineAgent(Agent):
# Accessor: returns whatever is currently stored in self.input_action_eval.
# NOTE(review): the assignment to input_action_eval is not visible in this
# excerpt — confirm its type and when it is populated against the caller.
def get_in_da_eval(self):
return self.input_action_eval
# Accessor: returns self.semantic_output_action. In the visible response code
# that attribute is set to self.policy_nlu.predict(...) when a policy NLU is
# configured and the output action is not a list, and to None otherwise, so
# callers must be prepared for a None result.
def get_out_da_eval(self):
return self.semantic_output_action
# Accessor: returns whatever is currently stored in self.input_action.
# NOTE(review): where input_action is assigned is not visible in this
# excerpt — confirm its type against the response code path.
def get_in_da(self):
return self.input_action
......
{
"dataset_name": "multiwoz21",
"data_dir": "unified_datasets/data/multiwoz21/sys/context_window_size_3",
"output_dir": "unified_datasets/output/multiwoz21/sys/context_window_size_3",
"zipped_model_path": "unified_datasets/output/multiwoz21/sys/context_window_size_3/bertnlu_unified_multiwoz21_sys_context3.zip",
"log_dir": "unified_datasets/output/multiwoz21/sys/context_window_size_3/log",
"DEVICE": "cuda:0",
"seed": 2019,
"cut_sen_len": 40,
"use_bert_tokenizer": true,
"context_window_size": 3,
"model": {
"finetune": true,
"context": true,
"context_grad": true,
"pretrained_weights": "bert-base-uncased",
"check_step": 1000,
"max_step": 10000,
"batch_size": 128,
"learning_rate": 1e-4,
"adam_epsilon": 1e-8,
"warmup_steps": 0,
"weight_decay": 0.0,
"dropout": 0.1,
"hidden_units": 1536
}
}
......@@ -995,6 +995,262 @@ class LAVA(Policy):
# response.append("1")
else:
response.append(str(num_results))
elif slot == 'place':
if 'arriv' in " ".join(tokens[index-2:index]) or "to" in " ".join(tokens[index-2:index]):
if active_domain == "train":
try:
response.append(
top_results[active_domain]["destination"])
except:
response.append(
state[active_domain]['semi']["destination"])
elif active_domain == "taxi":
response.append(
state[active_domain]["destination"])
elif 'leav' in " ".join(tokens[index-2:index]) or "from" in tokens[index-2:index] or "depart" in " ".join(tokens[index-2:index]):
if active_domain == "train":
try:
response.append(
top_results[active_domain]["departure"])
except:
response.append(
state[active_domain]['semi']["departure"])
elif active_domain == "taxi":
response.append(
state[active_domain]['semi']["departure"])
elif "hospital" in template:
response.append("Cambridge")
else:
try:
for d in state:
if d == 'history':
continue
for s in ['destination', 'departure']:
if s in state[d]:
response.append(
state[d][s])
raise
except:
pass
else:
response.append(token)
elif slot == 'time':
if 'arrive' in ' '.join(response[-5:]) or 'arrival' in ' '.join(response[-5:]) or 'arriving' in ' '.join(response[-3:]):
if active_domain == "train" and 'arriveBy' in top_results[active_domain]:
# print('{} -> {}'.format(token, top_results[active_domain]['arriveBy']))
response.append(
top_results[active_domain]['arriveBy'])
continue
for d in state:
if d == 'history':
continue
if 'arrive by' in state[d]:
response.append(
state[d]['arrive by'])
break
elif 'leave' in ' '.join(response[-5:]) or 'leaving' in ' '.join(response[-5:]) or 'departure' in ' '.join(response[-3:]):
if active_domain == "train" and 'leaveAt' in top_results[active_domain]:
# print('{} -> {}'.format(token, top_results[active_domain]['leaveAt']))
response.append(
top_results[active_domain]['leaveAt'])
continue
for d in state:
if d == 'history':
continue
if 'leave at' in state[d]:
response.append(
state[d]['leave at'])
break
elif 'book' in response or "booked" in response:
if state['restaurant']['book time'] != "":
response.append(
state['restaurant']['book time'])
else:
try:
for d in state:
if d == 'history':
continue
for s in ['arrive by', 'leave at']:
if s in state[d]:
response.append(
state[d][s])
raise
except:
pass
else:
response.append(token)
elif slot == 'price':
if active_domain == 'attraction':
# .split()[0]
value = top_results['attraction']['entrance fee']
if "?" in value:
value = "unknown"
# if "?" not in value:
# try:
# value = str(int(value))
# except:
# value = 'free'
# else:
# value = "unknown"
response.append(value)
elif active_domain == "train":
value = top_results[active_domain][slot].split()[0]
if state[active_domain]['book people'] not in ["", "dontcare"]:
try:
value = str(float(value) * int(state[active_domain]['book people']))
except ValueError:
int_map = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5, "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10}
value = str(float(value) * int_map[state[active_domain]['book people']])
response.append(value)
elif slot == "day" and active_domain in ["restaurant", "hotel"]:
if state[active_domain]['book day'] != "":
response.append(
state[active_domain]['book day'])
else:
# slot-filling based on query results
for d in top_results:
if slot in top_results[d]:
response.append(top_results[d][slot])
break
else:
# slot-filling based on belief state
for d in state:
if d == 'history':
continue
if slot in state[d]:
response.append(state[d][slot])
break
else:
response.append(token)
else:
if domain == 'hospital':
if slot == 'phone':
response.append('01223216297')
elif slot == 'department':
if state['hospital']['department'] != "":
response.append(state['hospital']['department'])
else:
response.append('neurosciences critical care unit')
elif slot == 'address':
response.append("56 Lincoln street")
elif slot == "postcode":
response.append('533421')
elif domain == 'police':
if slot == 'phone':
response.append('01223358966')
elif slot == 'name':
response.append('Parkside Police Station')
elif slot == 'address':
response.append('Parkside, Cambridge')
elif slot == 'postcode':
response.append('533420')
elif domain == 'taxi':
if slot == 'phone':
response.append('01223358966')
elif slot == 'color':
# response.append(random.choice(["black","white","red","yellow","blue",'grey']))
response.append("black")
elif slot == 'type':
# response.append(random.choice(["toyota","skoda","bmw",'honda','ford','audi','lexus','volvo','volkswagen','tesla']))
response.append("toyota")
else:
# print(token)
response.append(token)
else:
if token == "pounds" and len(response) > 0 and ("pounds" in response[-1] or "unknown" in response[-1] or "free" in response[-1]):
pass
else:
response.append(token)
try:
response = ' '.join(response)
except Exception as e:
# pprint(response)
raise
response = response.replace(' -s', 's')
response = response.replace(' -ly', 'ly')
response = response.replace(' .', '.')
response = response.replace(' ?', '?')
# if "not mentioned" in response:
# pdb.set_trace()
return response
def populate_template_unified(self, template, top_results, num_results, state, active_domain):
print("template:",template)
# print("top_results:",top_results)
# active_domain = None if len(
# top_results.keys()) == 0 else list(top_results.keys())[0]
template = template.replace(
'book [value_count] of', 'book one of')
tokens = template.split()
response = []
for index, token in enumerate(tokens):
if token.startswith('[') and (token.endswith(']') or token.endswith('].') or token.endswith('],')):
domain = token[1:-1].split('_')[0]
slot = token[1:-1].split('_')[1]
if slot.endswith(']'):
slot = slot[:-1]
if domain == 'train' and slot == 'id':
slot = 'trainID'
elif active_domain != 'train' and slot == 'price':
slot = 'price range'
elif slot == 'reference':
slot = 'Ref'
if domain in top_results and len(top_results[domain]) > 0 and slot in top_results[domain]:
# print('{} -> {}'.format(token, top_results[domain][slot]))
response.append(top_results[domain][slot])
elif domain == 'value':
if slot == 'count':
if "there are" in " ".join(tokens[index-2:index]) or "i have" in " ".join(tokens[index-2:index]):
response.append(str(num_results))
# the first [value_count], the last [value_count]
elif "the" in tokens[index-2]:
response.append("one")
elif active_domain == "restaurant":
if "people" in tokens[index:index+1] or "table" in tokens[index-2:index]:
response.append(
state[active_domain]["book people"])
elif active_domain == "train":
if "ticket" in " ".join(tokens[index-2:index+1]) or "people" in tokens[index:]:
response.append(
state[active_domain]["book people"])
elif index+1 < len(tokens) and "minute" in tokens[index+1]:
response.append(
top_results['train']['duration'].split()[0])
elif active_domain == "hotel":
if index+1 < len(tokens):
if "star" in tokens[index+1]:
response.append(top_results['hotel']['stars'])
elif "nights" in tokens[index+1]:
response.append(
state[active_domain]["book stay"])
elif "people" in tokens[index+1]:
response.append(
state[active_domain]["book people"])
elif active_domain == "attraction":
if index + 1 < len(tokens):
if "pounds" in tokens[index+1] and "entrance fee" in " ".join(tokens[index-3:index]):
value = top_results[active_domain]['entrance fee']
if "?" in value:
value = "unknown"
# if "?" not in value:
# try:
# value = str(int(value))
# except:
# value = 'free'
# else:
# value = "unknown"
response.append(value)
# if "there are" in " ".join(tokens[index-2:index]):
# response.append(str(num_results))
# elif "the" in tokens[index-2]: # the first [value_count], the last [value_count]
# response.append("1")
else:
response.append(str(num_results))
elif slot == 'place':
if 'arriv' in " ".join(tokens[index-2:index]) or "to" in " ".join(tokens[index-2:index]):
if active_domain == "train":
......
def default_state():
state = dict(user_action=[],
system_action=[],
belief_state={},
belief_state={
'attraction': {'type': '', 'name': '', 'area': ''},
'hotel': {'name': '', 'area': '', 'parking': '', 'price range': '', 'stars': '4', 'internet': 'yes', 'type': 'hotel', 'book stay': '', 'book day': '', 'book people': ''},
'restaurant': {'food': '', 'price range': '', 'name': '', 'area': '', 'book time': '', 'book day': '', 'book people': ''},
'taxi': {'leave at': '', 'destination': '', 'departure': '', 'arrive by': ''},
'train': {'leave at': '', 'destination': '', 'day': '', 'arrive by': '', 'departure': '', 'book people': ''},
'hospital': {'department': ''}
},
booked={},
request_state={},
terminated=False,
......
......@@ -9,7 +9,11 @@ import random
from pprint import pprint
from argparse import ArgumentParser
from convlab.nlu.jointBERT.unified_datasets import BERTNLU
<<<<<<< HEAD
# from convlab.nlu.jointBERT.multiwoz import BERTNLU as BERTNLU_woz
=======
from convlab.nlu.jointBERT.multiwoz import BERTNLU as BERTNLU_woz
>>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805
# from convlab.nlu.milu.multiwoz import MILU
# available DST models
from convlab.dst.rule.multiwoz import RuleDST
......@@ -80,6 +84,7 @@ def test_end2end(args, model_dir):
# where the models are saved from training
# lava_dir = "/gpfs/project/lubis/ConvLab-3/convlab/policy/lava/multiwoz/experiments_woz/sys_config_log_model/"
lava_dir = "/gpfs/project/lubis/LAVA_code/LAVA_published/experiments_woz/sys_config_log_model/"
if "rl" in model_dir:
......@@ -121,7 +126,11 @@ def test_end2end(args, model_dir):
user_agent = PipelineAgent(
user_nlu, user_dst, user_policy, user_nlg, name='user')
sys_agent = PipelineAgent(
<<<<<<< HEAD
sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
=======
sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True, word_level_policy_nlu=user_nlu)
>>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805
sys_agent.add_booking_info = False
......@@ -132,7 +141,11 @@ def test_end2end(args, model_dir):
set_seed(args.seed)
model_name = '{}_{}_lava_{}'.format(args.US_type, args.dst_type, model_dir)
<<<<<<< HEAD
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=500)
=======
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=10)
>>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805
if __name__ == '__main__':
parser = ArgumentParser()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment