Skip to content
Snippets Groups Projects
Commit 50424c34 authored by zqwerty's avatar zqwerty Committed by zhuqi
Browse files

add template for interent-no, parking-no in templatenlg

parent b190c50a
No related branches found
No related tags found
No related merge requests found
...@@ -207,6 +207,11 @@ class TemplateNLG(NLG): ...@@ -207,6 +207,11 @@ class TemplateNLG(NLG):
"I would prefer something that is {} .".format(value), "I would prefer something that is {} .".format(value),
"it needs to be {} .".format(value) "it needs to be {} .".format(value)
]) ])
elif slot in ['Internet', 'Parking'] and value == 'no':
sentence = random.choice([
"It does n't need to have {} .".format(slot.lower()),
"I do n't need free {} .".format(slot.lower()),
])
elif dialog_act in template and slot in template[dialog_act]: elif dialog_act in template and slot in template[dialog_act]:
sentence = random.choice(template[dialog_act][slot]) sentence = random.choice(template[dialog_act][slot])
sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value)) sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
...@@ -248,7 +253,7 @@ class TemplateNLG(NLG): ...@@ -248,7 +253,7 @@ class TemplateNLG(NLG):
def example(): def example():
# dialog act # dialog act
dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Name', 'fds'], ['welcome', 'general', 'none', 'none']] dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']]
print(dialog_acts) print(dialog_acts)
# system model for manual, auto, auto_manual # system model for manual, auto, auto_manual
......
...@@ -956,9 +956,9 @@ if __name__ == '__main__': ...@@ -956,9 +956,9 @@ if __name__ == '__main__':
# if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
# break # break
# # pprint(goal) # # pprint(goal)
user_goal = {'domain_ordering': ('restaurant', 'taxi'), user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'),
'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'}, 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
'info': {'internet': 'yes', 'info': {'internet': 'no',
'parking': 'no', 'parking': 'no',
'pricerange': 'moderate', 'pricerange': 'moderate',
'area': 'centre'}}, 'area': 'centre'}},
......
...@@ -81,6 +81,7 @@ class Analyzer: ...@@ -81,6 +81,7 @@ class Analyzer:
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.mkdir(output_dir) os.mkdir(output_dir)
f = open(os.path.join(output_dir, 'res.txt'), 'w') f = open(os.path.join(output_dir, 'res.txt'), 'w')
flog = open(os.path.join(output_dir, 'log.txt'), 'w')
for j in tqdm(range(total_dialog), desc="dialogue"): for j in tqdm(range(total_dialog), desc="dialogue"):
sys_response = '' if self.user_agent.nlu else [] sys_response = '' if self.user_agent.nlu else []
...@@ -106,13 +107,13 @@ class Analyzer: ...@@ -106,13 +107,13 @@ class Analyzer:
for i in range(40): for i in range(40):
sys_response, user_response, session_over, reward = sess.next_turn( sys_response, user_response, session_over, reward = sess.next_turn(
sys_response) sys_response)
# print('user in', sess.user_agent.get_in_da(),file=f) print('user in', sess.user_agent.get_in_da(),file=flog)
# print('user out', sess.user_agent.get_out_da(),file=f) print('user out', sess.user_agent.get_out_da(),file=flog)
# #
# print('sys in', sess.sys_agent.get_in_da(),file=f) print('sys in', sess.sys_agent.get_in_da(),file=flog)
# print('sys out', sess.sys_agent.get_out_da(),file=f) print('sys out', sess.sys_agent.get_out_da(),file=flog)
# print('user:', user_response,file=f) print('user:', user_response,file=flog)
# print('sys:', sys_response,file=f) print('sys:', sys_response,file=flog)
step += 2 step += 2
......
...@@ -67,7 +67,7 @@ def test_end2end(): ...@@ -67,7 +67,7 @@ def test_end2end():
analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
set_seed(20200202) set_seed(20200202)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000) analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=100)
if __name__ == '__main__': if __name__ == '__main__':
test_end2end() test_end2end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment