From 50424c343574d9c14c63236da02cbd7bcc1af580 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 24 Aug 2020 16:53:19 +0800 Subject: [PATCH] add template for interent-no, parking-no in templatenlg --- convlab2/nlg/template/multiwoz/nlg.py | 7 ++++++- .../policy/rule/multiwoz/policy_agenda_multiwoz.py | 4 ++-- convlab2/util/analysis_tool/analyzer.py | 13 +++++++------ tests/test_end2end.py | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/convlab2/nlg/template/multiwoz/nlg.py b/convlab2/nlg/template/multiwoz/nlg.py index b03fa54..96764ec 100755 --- a/convlab2/nlg/template/multiwoz/nlg.py +++ b/convlab2/nlg/template/multiwoz/nlg.py @@ -207,6 +207,11 @@ class TemplateNLG(NLG): "I would prefer something that is {} .".format(value), "it needs to be {} .".format(value) ]) + elif slot in ['Internet', 'Parking'] and value == 'no': + sentence = random.choice([ + "It does n't need to have {} .".format(slot.lower()), + "I do n't need free {} .".format(slot.lower()), + ]) elif dialog_act in template and slot in template[dialog_act]: sentence = random.choice(template[dialog_act][slot]) sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value)) @@ -248,7 +253,7 @@ class TemplateNLG(NLG): def example(): # dialog act - dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Name', 'fds'], ['welcome', 'general', 'none', 'none']] + dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']] print(dialog_acts) # system model for manual, auto, auto_manual diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index 70572b8..3e717ff 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -956,9 +956,9 @@ if __name__ == '__main__': # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: # break # # pprint(goal) - user_goal = {'domain_ordering': ('restaurant', 'taxi'), + user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'), 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'}, - 'info': {'internet': 'yes', + 'info': {'internet': 'no', 'parking': 'no', 'pricerange': 'moderate', 'area': 'centre'}}, diff --git a/convlab2/util/analysis_tool/analyzer.py b/convlab2/util/analysis_tool/analyzer.py index fb25a31..eeae31c 100755 --- a/convlab2/util/analysis_tool/analyzer.py +++ b/convlab2/util/analysis_tool/analyzer.py @@ -81,6 +81,7 @@ class Analyzer: if not os.path.exists(output_dir): os.mkdir(output_dir) f = open(os.path.join(output_dir, 'res.txt'), 'w') + flog = open(os.path.join(output_dir, 'log.txt'), 'w') for j in tqdm(range(total_dialog), desc="dialogue"): sys_response = '' if self.user_agent.nlu else [] @@ -106,13 +107,13 @@ class Analyzer: for i in range(40): sys_response, user_response, session_over, reward = sess.next_turn( sys_response) - # print('user in', sess.user_agent.get_in_da(),file=f) - # print('user out', sess.user_agent.get_out_da(),file=f) + print('user in', sess.user_agent.get_in_da(),file=flog) + print('user out', sess.user_agent.get_out_da(),file=flog) # - # print('sys in', sess.sys_agent.get_in_da(),file=f) - # print('sys out', sess.sys_agent.get_out_da(),file=f) - # print('user:', user_response,file=f) - # print('sys:', sys_response,file=f) + print('sys in', sess.sys_agent.get_in_da(),file=flog) + print('sys out', sess.sys_agent.get_out_da(),file=flog) + print('user:', user_response,file=flog) + print('sys:', sys_response,file=flog) step += 2 diff --git a/tests/test_end2end.py b/tests/test_end2end.py index bf2a089..2aef321 100755 --- a/tests/test_end2end.py +++ b/tests/test_end2end.py @@ -67,7 +67,7 @@ def test_end2end(): analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') set_seed(20200202) - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=100) if __name__ == '__main__': test_end2end() -- GitLab