From 50424c343574d9c14c63236da02cbd7bcc1af580 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 24 Aug 2020 16:53:19 +0800
Subject: [PATCH] add template for interent-no, parking-no in templatenlg

---
 convlab2/nlg/template/multiwoz/nlg.py               |  7 ++++++-
 .../policy/rule/multiwoz/policy_agenda_multiwoz.py  |  4 ++--
 convlab2/util/analysis_tool/analyzer.py             | 13 +++++++------
 tests/test_end2end.py                               |  2 +-
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/convlab2/nlg/template/multiwoz/nlg.py b/convlab2/nlg/template/multiwoz/nlg.py
index b03fa54..96764ec 100755
--- a/convlab2/nlg/template/multiwoz/nlg.py
+++ b/convlab2/nlg/template/multiwoz/nlg.py
@@ -207,6 +207,11 @@ class TemplateNLG(NLG):
                             "I would prefer something that is {} .".format(value),
                             "it needs to be {} .".format(value)
                         ])
+                    elif slot in ['Internet', 'Parking'] and value == 'no':
+                        sentence = random.choice([
+                            "It does n't need to have {} .".format(slot.lower()),
+                            "I do n't need free {} .".format(slot.lower()),
+                        ])
                     elif dialog_act in template and slot in template[dialog_act]:
                         sentence = random.choice(template[dialog_act][slot])
                         sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
@@ -248,7 +253,7 @@ class TemplateNLG(NLG):
 
 def example():
     # dialog act
-    dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Name', 'fds'], ['welcome', 'general', 'none', 'none']]
+    dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']]
     print(dialog_acts)
 
     # system model for manual, auto, auto_manual
diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 70572b8..3e717ff 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -956,9 +956,9 @@ if __name__ == '__main__':
     #     if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
     #         break
     # # pprint(goal)
-    user_goal = {'domain_ordering': ('restaurant', 'taxi'),
+    user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'),
                  'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
-                           'info': {'internet': 'yes',
+                           'info': {'internet': 'no',
                                     'parking': 'no',
                                     'pricerange': 'moderate',
                                     'area': 'centre'}},
diff --git a/convlab2/util/analysis_tool/analyzer.py b/convlab2/util/analysis_tool/analyzer.py
index fb25a31..eeae31c 100755
--- a/convlab2/util/analysis_tool/analyzer.py
+++ b/convlab2/util/analysis_tool/analyzer.py
@@ -81,6 +81,7 @@ class Analyzer:
         if not os.path.exists(output_dir):
             os.mkdir(output_dir)
         f = open(os.path.join(output_dir, 'res.txt'), 'w')
+        flog = open(os.path.join(output_dir, 'log.txt'), 'w')
 
         for j in tqdm(range(total_dialog), desc="dialogue"):
             sys_response = '' if self.user_agent.nlu else []
@@ -106,13 +107,13 @@ class Analyzer:
             for i in range(40):
                 sys_response, user_response, session_over, reward = sess.next_turn(
                     sys_response)
-                # print('user in', sess.user_agent.get_in_da(),file=f)
-                # print('user out', sess.user_agent.get_out_da(),file=f)
+                print('user in', sess.user_agent.get_in_da(),file=flog)
+                print('user out', sess.user_agent.get_out_da(),file=flog)
                 #
-                # print('sys in', sess.sys_agent.get_in_da(),file=f)
-                # print('sys out', sess.sys_agent.get_out_da(),file=f)
-                # print('user:', user_response,file=f)
-                # print('sys:', sys_response,file=f)
+                print('sys in', sess.sys_agent.get_in_da(),file=flog)
+                print('sys out', sess.sys_agent.get_out_da(),file=flog)
+                print('user:', user_response,file=flog)
+                print('sys:', sys_response,file=flog)
 
                 step += 2
 
diff --git a/tests/test_end2end.py b/tests/test_end2end.py
index bf2a089..2aef321 100755
--- a/tests/test_end2end.py
+++ b/tests/test_end2end.py
@@ -67,7 +67,7 @@ def test_end2end():
     analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
 
     set_seed(20200202)
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000)
+    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=100)
 
 if __name__ == '__main__':
     test_end2end()
-- 
GitLab