From cb7bb5b004fe1c733671a8db00585f9958f2c2d7 Mon Sep 17 00:00:00 2001
From: zhuqi <zqwerty@users.noreply.github.com>
Date: Sun, 27 Sep 2020 16:43:26 +0800
Subject: [PATCH] Maintenance (#132)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output
---
 .../rule/multiwoz/policy_agenda_multiwoz.py   | 213 ++++++++++--------
 1 file changed, 123 insertions(+), 90 deletions(-)

diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 3e717ff..a13e5e8 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -106,16 +106,20 @@ class UserPolicyAgendaMultiWoz(Policy):
             self.agenda.close_session()
         else:
             sys_action = self._transform_sysact_in(sys_action)
+            # print('sys action before update agenda', sys_action)
             self.agenda.update(sys_action, self.goal)
             if self.goal.task_complete():
                 self.agenda.close_session()
 
-        # A -> A' + user_action
-        # action = self.agenda.get_action(random.randint(2, self.max_initiative))
-        action = self.agenda.get_action(self.max_initiative)
+        action = {}
+        while len(action) == 0:
+            # A -> A' + user_action
+            # action = self.agenda.get_action(random.randint(2, self.max_initiative))
+            action = self.agenda.get_action(self.max_initiative)
 
-        # transform to DA
-        action = self._transform_usract_out(action)
+            # transform to DA
+            action = self._transform_usract_out(action)
+            # print(action)
 
         tuples = []
         for domain_intent, svs in action.items():
@@ -169,6 +173,8 @@ class UserPolicyAgendaMultiWoz(Policy):
                             new_action[new_act].append(['NotBook', 'none'])
                         elif slot is not None:
                             new_action[new_act].append([slot, pairs[1]])
+                    if len(new_action[new_act]) == 0:
+                        new_action.pop(new_act)
                     # new_action[new_act] = [[REF_USR_DA_M[dom.capitalize()].get(pairs[0], pairs[0]), pairs[1]] for pairs in action[act]]
                 else:
                     new_action[act] = action[act]
@@ -848,7 +854,6 @@ class Agenda(object):
         diaacts = []
         slots = []
         values = []
-
         p_diaact, p_slot = self.__check_next_diaact_slot()
         if p_diaact.split('-')[1] == 'inform' and p_slot in BOOK_SLOT:
             for _ in range(10 if self.__cur_push_num == 0 else self.__cur_push_num):
@@ -914,23 +919,23 @@ if __name__ == '__main__':
     from convlab2.dst.rule.multiwoz.dst import RuleDST
     from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU
 
-    seed = 50
+    seed = 41
     np.random.seed(seed)
     random.seed(seed)
     torch.manual_seed(seed)
+    #
+    # sys_nlu = BERTNLU()
+    # sys_dst = RuleDST()
+    # sys_policy = RulePolicy()
+    # sys_nlg = TemplateNLG(is_user=False)
+    # sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
 
-    sys_nlu = BERTNLU()
-    sys_dst = RuleDST()
-    sys_policy = RulePolicy()
-    sys_nlg = TemplateNLG(is_user=False)
-    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
-
-    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
-                       model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
-    user_dst = None
-    user_policy = RulePolicy(character='usr')
-    user_nlg = TemplateNLG(is_user=True)
-    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
+    # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+    #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
+    # user_dst = None
+    # user_policy = RulePolicy(character='usr')
+    # user_nlg = TemplateNLG(is_user=True)
+    # user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
 
     # evaluator = MultiWozEvaluator()
     # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)
@@ -938,14 +943,14 @@ if __name__ == '__main__':
 
 
 
-    # user_policy = UserPolicyAgendaMultiWoz()
+    user_policy = UserPolicyAgendaMultiWoz()
     #
-    # sys_policy = RuleBasedMultiwozBot()
+    sys_policy = RulePolicy(character='sys')
     #
-    # user_nlg = TemplateNLG(is_user=True, mode='manual')
-    # sys_nlg = TemplateNLG(is_user=False, mode='manual')
+    user_nlg = TemplateNLG(is_user=True, mode='manual')
+    sys_nlg = TemplateNLG(is_user=False, mode='manual')
     #
-    # dst = RuleDST()
+    dst = RuleDST()
     #
     # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
     #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
@@ -956,12 +961,24 @@ if __name__ == '__main__':
     #     if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
     #         break
     # # pprint(goal)
-    user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'),
-                 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
-                           'info': {'internet': 'no',
-                                    'parking': 'no',
-                                    'pricerange': 'moderate',
-                                    'area': 'centre'}},
+    user_goal = {'domain_ordering': ('hotel', 'attraction'),
+                 'train': {
+                     'info': {'arriveBy': '16:00',
+                              'day': 'monday',
+                              'departure': 'cambridge',
+                              'destination': 'stansted airport'},
+                     'book': {'people': 2}, 'booked': '?'
+                 },
+                 'attraction': {
+                     'info': {'type': 'museum'},
+                     'reqt': ['phone']
+                 },
+                 'hotel': {
+                           'info': {'internet': 'yes',
+                                    'parking': 'yes',
+                                    'stars': '4',
+                                    'type': 'hotel'},
+                           'reqt': ['postcode']},
                  'restaurant': {'info': {'area': 'centre',
                                          'food': 'portuguese',
                                          'pricerange': 'cheap'},
@@ -986,7 +1003,7 @@ if __name__ == '__main__':
     user_policy.init_session(ini_goal=goal)
     print('init goal:')
     # pprint(user_policy.get_goal())
-    pprint(user_agent.policy.get_goal())
+    # pprint(user_agent.policy.get_goal())
     # pprint(sess.evaluator.goal)
     # print('-' * 50)
     # for i in range(20):
@@ -1005,60 +1022,66 @@ if __name__ == '__main__':
     # print('=' * 100)
 
     history = []
-    user_utt = user_agent.response('')
-    print(user_utt)
-    user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
-    # history.append(['user', user_utt])
-    sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
-    sys_utt = sys_agent.response(user_utt)
-    pprint(sys_agent.dst.state)
-    print(sys_utt)
-    sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
-    # history.append(['user', user_utt])
-
-    user_utt = user_agent.response(sys_utt)
-    print(user_utt)
-    user_utt = "It just needs to be cheap ."
-    sys_utt = sys_agent.response(user_utt)
-    print(sys_utt)
-    sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
-
-    user_utt = user_agent.response(sys_utt)
-    print(user_utt)
-    sys_utt = sys_agent.response(user_utt)
-    print(sys_utt)
-
-    user_utt = user_agent.response(sys_utt)
-    print(user_utt)
-    sys_utt = sys_agent.response(user_utt)
-    print(sys_utt)
-
-    user_utt = user_agent.response(sys_utt)
-    print(user_utt)
-    sys_utt = sys_agent.response(user_utt)
-    print(sys_utt)
-
+    # user_utt = user_agent.response('')
+    # print(user_utt)
+    # user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
+    # # history.append(['user', user_utt])
+    # sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
+    # sys_utt = sys_agent.response(user_utt)
+    # pprint(sys_agent.dst.state)
+    # print(sys_utt)
+    # sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
+    # # history.append(['user', user_utt])
     #
-    # print(user_policy.agenda)
-    # user_act = user_policy.predict([])
-    # print(user_act)
-    # user_utt = user_nlg.generate(user_act)
+    # user_utt = user_agent.response(sys_utt)
     # print(user_utt)
-    # state = dst.state
-    # state['user_action'] = user_act
-    # dst.update(user_act)
-    # # pprint(state)
-    # sys_act = sys_policy.predict(state)
-    # sys_utt = sys_nlg.generate(sys_act)
-    # # sys_act.append(["Request", "Restaurant", "Price", "?"])
-    # # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
-    # print(sys_act)
+    # user_utt = "It just needs to be cheap ."
+    # sys_utt = sys_agent.response(user_utt)
     # print(sys_utt)
+    # sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
     #
-    # user_act = user_policy.predict(sys_act)
-    # print(user_act)
-    # user_utt = user_nlg.generate(user_act)
+    # user_utt = user_agent.response(sys_utt)
+    # print(user_utt)
+    # sys_utt = sys_agent.response(user_utt)
+    # print(sys_utt)
+    #
+    # user_utt = user_agent.response(sys_utt)
+    # print(user_utt)
+    # sys_utt = sys_agent.response(user_utt)
+    # print(sys_utt)
+    #
+    # user_utt = user_agent.response(sys_utt)
     # print(user_utt)
+    # sys_utt = sys_agent.response(user_utt)
+    # print(sys_utt)
+
+    #
+    print(user_policy.agenda)
+    user_act = user_policy.predict([])
+    print(user_act)
+    user_utt = user_nlg.generate(user_act)
+    print(user_utt)
+    history.append(['user', user_utt])
+    state = dst.state
+    state['user_action'] = user_act
+    dst.update(user_act)
+    # pprint(state)
+    sys_act = sys_policy.predict(state)
+    sys_utt = sys_nlg.generate(sys_act)
+    # sys_act.append(["Request", "Restaurant", "Price", "?"])
+    # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
+    sys_act = [['Inform', 'Hotel', 'Post', 'pe296fl']]
+    print(sys_act)
+    history.append(['sys', user_utt])
+
+    # sys_utt = sys_agent.response(user_utt)
+    # print(sys_utt)
+    #
+    user_act = user_policy.predict(sys_act)
+    print(user_act)
+    user_utt = user_nlg.generate(user_act)
+    print(user_utt)
+    history.append(['user', user_utt])
     # state = dst.state
     # state['user_action'] = user_act
     # dst.update(user_act)
@@ -1066,24 +1089,34 @@ if __name__ == '__main__':
     # sys_act = sys_policy.predict(state)
     # # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
     # print(sys_act)
+    sys_act = [
+        ['Inform', 'Hotel', 'Post', 'pe296fl']
+    ]
+    print(sys_act)
+    # sys_utt = sys_agent.response(user_utt)
+    # print(sys_utt)
+    # sys_utt = 'The arrive time is 15:08 . The train will be departing from cambridge . The booking is for arriving in stansted airport . TR6936 will be your perfect fit . How about 14:40 will that work for you ?'
+    # history.append(['sys', user_utt])
     #
     #
-    # user_act = user_policy.predict(sys_act)
-    # print(user_act)
-    # user_utt = user_nlg.generate(user_act)
-    # print(user_utt)
+    # sys_act = user_nlu.predict(sys_utt, history)
+    # print(sys_act)
+    user_act = user_policy.predict(sys_act)
+    print(user_act)
+    user_utt = user_nlg.generate(user_act)
+    print(user_utt)
     # state = dst.state
     # state['user_action'] = user_act
     # dst.update(user_act)
     # # pprint(state)
     # sys_act = sys_policy.predict(state)
-    # # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
-    # print(sys_act)
+    sys_act = [['Request', 'Hotel', 'Price', '?'], ['Request', 'Attraction', 'Price', '?']]
+    print(sys_act)
     # #
-    # user_act = user_policy.predict(sys_act)
-    # print(user_act)
-    # user_utt = user_nlg.generate(user_act)
-    # print(user_utt)
+    user_act = user_policy.predict(sys_act)
+    print(user_act)
+    user_utt = user_nlg.generate(user_act)
+    print(user_utt)
     # state = dst.state
     # state['user_action'] = user_act
     # dst.update(user_act)
-- 
GitLab