From cb7bb5b004fe1c733671a8db00585f9958f2c2d7 Mon Sep 17 00:00:00 2001 From: zhuqi <zqwerty@users.noreply.github.com> Date: Sun, 27 Sep 2020 16:43:26 +0800 Subject: [PATCH] Maintenance (#132) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * fix lower case for int value in nlg.py * fix empty user utterance problem in multiwoz simulator, issue #130 * remove debug output --- .../rule/multiwoz/policy_agenda_multiwoz.py | 213 ++++++++++-------- 1 file changed, 123 insertions(+), 90 deletions(-) diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index 3e717ff..a13e5e8 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -106,16 +106,20 @@ class UserPolicyAgendaMultiWoz(Policy): self.agenda.close_session() else: sys_action = self._transform_sysact_in(sys_action) + # print('sys action before update agenda', sys_action) self.agenda.update(sys_action, self.goal) if self.goal.task_complete(): self.agenda.close_session() - # A -> A' + user_action - # action = self.agenda.get_action(random.randint(2, self.max_initiative)) - action = self.agenda.get_action(self.max_initiative) + action = {} + while len(action) == 0: + # A -> A' + user_action + # action = self.agenda.get_action(random.randint(2, self.max_initiative)) + action = self.agenda.get_action(self.max_initiative) - # transform to DA - action = self._transform_usract_out(action) + # transform to DA + action = self._transform_usract_out(action) + # print(action) tuples = [] for domain_intent, svs in action.items(): @@ -169,6 +173,8 @@ class UserPolicyAgendaMultiWoz(Policy): new_action[new_act].append(['NotBook', 'none']) elif slot is not None: new_action[new_act].append([slot, pairs[1]]) + if len(new_action[new_act]) == 0: + new_action.pop(new_act) # new_action[new_act] = [[REF_USR_DA_M[dom.capitalize()].get(pairs[0], pairs[0]), pairs[1]] for pairs in action[act]] else: new_action[act] = action[act] @@ -848,7 +854,6 @@ class Agenda(object): diaacts = [] slots = [] values = [] - p_diaact, p_slot = self.__check_next_diaact_slot() if p_diaact.split('-')[1] == 'inform' and p_slot in BOOK_SLOT: for _ in range(10 if self.__cur_push_num == 0 else self.__cur_push_num): @@ -914,23 +919,23 @@ if __name__ == '__main__': from convlab2.dst.rule.multiwoz.dst import RuleDST from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU - seed = 50 + seed = 41 np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) + # + # sys_nlu = BERTNLU() + # sys_dst = RuleDST() + # sys_policy = RulePolicy() + # sys_nlg = TemplateNLG(is_user=False) + # sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - sys_nlu = BERTNLU() - sys_dst = RuleDST() - sys_policy = RulePolicy() - sys_nlg = TemplateNLG(is_user=False) - sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') - - user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', - model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') - user_dst = None - user_policy = RulePolicy(character='usr') - user_nlg = TemplateNLG(is_user=True) - user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') + # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') + # user_dst = None + # user_policy = RulePolicy(character='usr') + # user_nlg = TemplateNLG(is_user=True) + # user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') # evaluator = MultiWozEvaluator() # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator) @@ -938,14 +943,14 @@ if __name__ == '__main__': - # user_policy = UserPolicyAgendaMultiWoz() + user_policy = UserPolicyAgendaMultiWoz() # - # sys_policy = RuleBasedMultiwozBot() + sys_policy = RulePolicy(character='sys') # - # user_nlg = TemplateNLG(is_user=True, mode='manual') - # sys_nlg = TemplateNLG(is_user=False, mode='manual') + user_nlg = TemplateNLG(is_user=True, mode='manual') + sys_nlg = TemplateNLG(is_user=False, mode='manual') # - # dst = RuleDST() + dst = RuleDST() # # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') @@ -956,12 +961,24 @@ if __name__ == '__main__': # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: # break # # pprint(goal) - user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'), - 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'}, - 'info': {'internet': 'no', - 'parking': 'no', - 'pricerange': 'moderate', - 'area': 'centre'}}, + user_goal = {'domain_ordering': ('hotel', 'attraction'), + 'train': { + 'info': {'arriveBy': '16:00', + 'day': 'monday', + 'departure': 'cambridge', + 'destination': 'stansted airport'}, + 'book': {'people': 2}, 'booked': '?' + }, + 'attraction': { + 'info': {'type': 'museum'}, + 'reqt': ['phone'] + }, + 'hotel': { + 'info': {'internet': 'yes', + 'parking': 'yes', + 'stars': '4', + 'type': 'hotel'}, + 'reqt': ['postcode']}, 'restaurant': {'info': {'area': 'centre', 'food': 'portuguese', 'pricerange': 'cheap'}, @@ -986,7 +1003,7 @@ if __name__ == '__main__': user_policy.init_session(ini_goal=goal) print('init goal:') # pprint(user_policy.get_goal()) - pprint(user_agent.policy.get_goal()) + # pprint(user_agent.policy.get_goal()) # pprint(sess.evaluator.goal) # print('-' * 50) # for i in range(20): @@ -1005,60 +1022,66 @@ if __name__ == '__main__': # print('=' * 100) history = [] - user_utt = user_agent.response('') - print(user_utt) - user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?' - # history.append(['user', user_utt]) - sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese' - sys_utt = sys_agent.response(user_utt) - pprint(sys_agent.dst.state) - print(sys_utt) - sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ." - # history.append(['user', user_utt]) - - user_utt = user_agent.response(sys_utt) - print(user_utt) - user_utt = "It just needs to be cheap ." - sys_utt = sys_agent.response(user_utt) - print(sys_utt) - sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?" - - user_utt = user_agent.response(sys_utt) - print(user_utt) - sys_utt = sys_agent.response(user_utt) - print(sys_utt) - - user_utt = user_agent.response(sys_utt) - print(user_utt) - sys_utt = sys_agent.response(user_utt) - print(sys_utt) - - user_utt = user_agent.response(sys_utt) - print(user_utt) - sys_utt = sys_agent.response(user_utt) - print(sys_utt) - + # user_utt = user_agent.response('') + # print(user_utt) + # user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?' + # # history.append(['user', user_utt]) + # sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese' + # sys_utt = sys_agent.response(user_utt) + # pprint(sys_agent.dst.state) + # print(sys_utt) + # sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ." + # # history.append(['user', user_utt]) # - # print(user_policy.agenda) - # user_act = user_policy.predict([]) - # print(user_act) - # user_utt = user_nlg.generate(user_act) + # user_utt = user_agent.response(sys_utt) # print(user_utt) - # state = dst.state - # state['user_action'] = user_act - # dst.update(user_act) - # # pprint(state) - # sys_act = sys_policy.predict(state) - # sys_utt = sys_nlg.generate(sys_act) - # # sys_act.append(["Request", "Restaurant", "Price", "?"]) - # # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']] - # print(sys_act) + # user_utt = "It just needs to be cheap ." + # sys_utt = sys_agent.response(user_utt) # print(sys_utt) + # sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?" # - # user_act = user_policy.predict(sys_act) - # print(user_act) - # user_utt = user_nlg.generate(user_act) + # user_utt = user_agent.response(sys_utt) + # print(user_utt) + # sys_utt = sys_agent.response(user_utt) + # print(sys_utt) + # + # user_utt = user_agent.response(sys_utt) + # print(user_utt) + # sys_utt = sys_agent.response(user_utt) + # print(sys_utt) + # + # user_utt = user_agent.response(sys_utt) # print(user_utt) + # sys_utt = sys_agent.response(user_utt) + # print(sys_utt) + + # + print(user_policy.agenda) + user_act = user_policy.predict([]) + print(user_act) + user_utt = user_nlg.generate(user_act) + print(user_utt) + history.append(['user', user_utt]) + state = dst.state + state['user_action'] = user_act + dst.update(user_act) + # pprint(state) + sys_act = sys_policy.predict(state) + sys_utt = sys_nlg.generate(sys_act) + # sys_act.append(["Request", "Restaurant", "Price", "?"]) + # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']] + sys_act = [['Inform', 'Hotel', 'Post', 'pe296fl']] + print(sys_act) + history.append(['sys', user_utt]) + + # sys_utt = sys_agent.response(user_utt) + # print(sys_utt) + # + user_act = user_policy.predict(sys_act) + print(user_act) + user_utt = user_nlg.generate(user_act) + print(user_utt) + history.append(['user', user_utt]) # state = dst.state # state['user_action'] = user_act # dst.update(user_act) @@ -1066,24 +1089,34 @@ if __name__ == '__main__': # sys_act = sys_policy.predict(state) # # sys_act = [['Inform', 'Hotel', 'Choice', '3']] # print(sys_act) + sys_act = [ + ['Inform', 'Hotel', 'Post', 'pe296fl'] + ] + print(sys_act) + # sys_utt = sys_agent.response(user_utt) + # print(sys_utt) + # sys_utt = 'The arrive time is 15:08 . The train will be departing from cambridge . The booking is for arriving in stansted airport . TR6936 will be your perfect fit . How about 14:40 will that work for you ?' + # history.append(['sys', user_utt]) # # - # user_act = user_policy.predict(sys_act) - # print(user_act) - # user_utt = user_nlg.generate(user_act) - # print(user_utt) + # sys_act = user_nlu.predict(sys_utt, history) + # print(sys_act) + user_act = user_policy.predict(sys_act) + print(user_act) + user_utt = user_nlg.generate(user_act) + print(user_utt) # state = dst.state # state['user_action'] = user_act # dst.update(user_act) # # pprint(state) # sys_act = sys_policy.predict(state) - # # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]] - # print(sys_act) + sys_act = [['Request', 'Hotel', 'Price', '?'], ['Request', 'Attraction', 'Price', '?']] + print(sys_act) # # - # user_act = user_policy.predict(sys_act) - # print(user_act) - # user_utt = user_nlg.generate(user_act) - # print(user_utt) + user_act = user_policy.predict(sys_act) + print(user_act) + user_utt = user_nlg.generate(user_act) + print(user_utt) # state = dst.state # state['user_action'] = user_act # dst.update(user_act) -- GitLab