From b190c50ad731defb525ad3ee6616438844401a17 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 24 Aug 2020 16:04:39 +0800 Subject: [PATCH] add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy --- convlab2/dialog_agent/agent.py | 6 +- convlab2/dialog_agent/session.py | 4 +- .../multiwoz/manual_user_template_nlg.json | 6 +- .../rule/multiwoz/policy_agenda_multiwoz.py | 359 +++++++++++------- convlab2/policy/rule/multiwoz/rule.py | 4 +- .../rule/multiwoz/rule_based_multiwoz_bot.py | 20 +- tests/test_end2end.py | 2 +- 7 files changed, 247 insertions(+), 154 deletions(-) diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py index f2d9ee6..adb9d33 100755 --- a/convlab2/dialog_agent/agent.py +++ b/convlab2/dialog_agent/agent.py @@ -32,7 +32,7 @@ class Agent(ABC): pass @abstractmethod - def init_session(self): + def init_session(self, **kwargs): """Reset the class variables to prepare for a new session.""" pass @@ -140,7 +140,7 @@ class PipelineAgent(Agent): return self.policy.get_reward() return None - def init_session(self): + def init_session(self, **kwargs): """Init the attributes of DST and Policy module.""" if self.nlu is not None: self.nlu.init_session() @@ -149,7 +149,7 @@ class PipelineAgent(Agent): if self.name == 'sys': self.dst.state['history'].append([self.name, 'null']) if self.policy is not None: - self.policy.init_session() + self.policy.init_session(**kwargs) if self.nlg is not None: self.nlg.init_session() self.history = [] diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py index f96bab6..abb3e83 100755 --- a/convlab2/dialog_agent/session.py +++ b/convlab2/dialog_agent/session.py @@ -139,9 +139,9 @@ class BiSession(Session): """ self.sys_agent.policy.train() - def init_session(self): + def init_session(self, **kwargs): self.sys_agent.init_session() - self.user_agent.init_session() + self.user_agent.init_session(**kwargs) if self.evaluator: self.evaluator.add_goal(self.user_agent.policy.get_goal()) diff --git a/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json b/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json index 4a44fed..3da8288 100755 --- a/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json +++ b/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json @@ -456,13 +456,13 @@ "I also need a place to dine that is #RESTAURANT-INFORM-PRICE# priced ." ], "Food": [ - "How about #RESTAURANT-INFORM-FOOD# .", + "How about #RESTAURANT-INFORM-FOOD# food .", "are there any #RESTAURANT-INFORM-FOOD# restaurants ?", - "Hmm , I 'll try #RESTAURANT-INFORM-FOOD# .", + "Hmm , I 'll try #RESTAURANT-INFORM-FOOD# food .", "I 'd like to find a #RESTAURANT-INFORM-FOOD# restaurant , if possible .", "Do you have #RESTAURANT-INFORM-FOOD# food ?", "Yes . This restaurant should serve #RESTAURANT-INFORM-FOOD# food too .", - "I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# .", + "I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# food .", "how about a #RESTAURANT-INFORM-FOOD# restaurant ?", "I would prefer #RESTAURANT-INFORM-FOOD# food please ." ], diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index a0af0df..70572b8 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -111,8 +111,8 @@ class UserPolicyAgendaMultiWoz(Policy): self.agenda.close_session() # A -> A' + user_action - action = self.agenda.get_action(random.randint(2, self.max_initiative)) - # action = self.agenda.get_action(self.max_initiative) + # action = self.agenda.get_action(random.randint(2, self.max_initiative)) + action = self.agenda.get_action(self.max_initiative) # transform to DA action = self._transform_usract_out(action) @@ -561,6 +561,7 @@ class Agenda(object): def close_session(self): """ Clear up all actions """ self.__stack = [] + self.__cur_push_num = 0 self.__push(self.CLOSE_ACT) def get_action(self, initiative=1): @@ -867,9 +868,10 @@ class Agenda(object): except Exception as e: break else: - if self.__cur_push_num == 0 or (all([self.__stack[-i]['value'] == DEF_VAL_DNC for i in range(1, self.__cur_push_num+1)])): + if self.__cur_push_num == 0 or (all([self.__stack[-i-1]['value'] == DEF_VAL_DNC for i in + range(0, min(len(self.__stack), self.__cur_push_num))])): # pop more when only dontcare - num2pop = 4 + num2pop = initiative else: num2pop = self.__cur_push_num for _ in range(num2pop): @@ -905,156 +907,247 @@ if __name__ == '__main__': import numpy as np import torch from pprint import pprint + from convlab2.dialog_agent import PipelineAgent, BiSession + from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator + from convlab2.policy.rule.multiwoz import RulePolicy + from convlab2.nlg.template.multiwoz.nlg import TemplateNLG + from convlab2.dst.rule.multiwoz.dst import RuleDST + from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU + seed = 50 np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) - user_policy = UserPolicyAgendaMultiWoz() - from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot - sys_policy = RuleBasedMultiwozBot() - from convlab2.nlg.template.multiwoz.nlg import TemplateNLG - user_nlg = TemplateNLG(is_user=True, mode='manual') - sys_nlg = TemplateNLG(is_user=False, mode='manual') - from convlab2.dst.rule.multiwoz.dst import RuleDST - dst = RuleDST() + sys_nlu = BERTNLU() + sys_dst = RuleDST() + sys_policy = RulePolicy() + sys_nlg = TemplateNLG(is_user=False) + sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') + + user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') + user_dst = None + user_policy = RulePolicy(character='usr') + user_nlg = TemplateNLG(is_user=True) + user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user') + + # evaluator = MultiWozEvaluator() + # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator) + + + + # user_policy = UserPolicyAgendaMultiWoz() + # + # sys_policy = RuleBasedMultiwozBot() + # + # user_nlg = TemplateNLG(is_user=True, mode='manual') + # sys_nlg = TemplateNLG(is_user=False, mode='manual') + # + # dst = RuleDST() + # + # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') + # goal_generator = GoalGenerator() - while True: - goal = goal_generator.get_user_goal() - if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: - break - # pprint(goal) - user_goal = {'domain_ordering': ('hotel', 'restaurant', 'taxi'), + # while True: + # goal = goal_generator.get_user_goal() + # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: + # break + # # pprint(goal) + user_goal = {'domain_ordering': ('restaurant', 'taxi'), 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'}, 'info': {'internet': 'yes', 'parking': 'no', 'pricerange': 'moderate', 'area': 'centre'}}, 'restaurant': {'info': {'area': 'centre', - 'food': 'chinese', - 'pricerange': 'moderate'}, - 'reqt': ['address']}, + 'food': 'portuguese', + 'pricerange': 'cheap'}, + 'fail_info': {'area': 'centre', + 'food': 'portuguese', + 'pricerange': 'expensive'}, + 'reqt': ['postcode']}, 'taxi': {'info': {'arriveBy': '13:00'}, 'reqt': ['car type', 'phone']}} - # user_goal = goal + # # user_goal = goal goal = Goal(goal_generator) goal.set_user_goal(user_goal) + # + # user_policy.init_session(ini_goal=goal) + # sys_policy.init_session() + # + # goal = user_policy.get_goal() + # + # pprint(goal) + sys_response = '' + # sess.init_session(ini_goal=goal) user_policy.init_session(ini_goal=goal) - sys_policy.init_session() - - goal = user_policy.get_goal() - - pprint(goal) + print('init goal:') + # pprint(user_policy.get_goal()) + pprint(user_agent.policy.get_goal()) + # pprint(sess.evaluator.goal) + # print('-' * 50) + # for i in range(20): + # sys_response, user_response, session_over, reward = sess.next_turn(sys_response) + # print('user:', user_response) + # print('sys:', sys_response) + # print() + # if session_over is True: + # break + # print('task success:', sess.evaluator.task_success()) + # print('book rate:', sess.evaluator.book_rate()) + # print('inform precision/recall/f1:', sess.evaluator.inform_F1()) + # print('-' * 50) + # print('final goal:') + # pprint(sess.evaluator.goal) + # print('=' * 100) + + history = [] + user_utt = user_agent.response('') + print(user_utt) + user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?' + # history.append(['user', user_utt]) + sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese' + sys_utt = sys_agent.response(user_utt) + pprint(sys_agent.dst.state) + print(sys_utt) + sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ." + # history.append(['user', user_utt]) + + user_utt = user_agent.response(sys_utt) + print(user_utt) + user_utt = "It just needs to be cheap ." + sys_utt = sys_agent.response(user_utt) + print(sys_utt) + sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?" - print(user_policy.agenda) - user_act = user_policy.predict([]) - print(user_act) - user_utt = user_nlg.generate(user_act) + user_utt = user_agent.response(sys_utt) print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act.append(["Request", "Restaurant", "Price", "?"]) - # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']] - print(sys_act) - - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) + sys_utt = sys_agent.response(user_utt) + print(sys_utt) + + user_utt = user_agent.response(sys_utt) print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [['Inform', 'Hotel', 'Choice', '3']] - print(sys_act) - - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) + sys_utt = sys_agent.response(user_utt) + print(sys_utt) + + user_utt = user_agent.response(sys_utt) print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]] - print(sys_act) + sys_utt = sys_agent.response(user_utt) + print(sys_utt) + # - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Reqmore", "General", "none", "none"]] - print(sys_act) + # print(user_policy.agenda) + # user_act = user_policy.predict([]) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # sys_utt = sys_nlg.generate(sys_act) + # # sys_act.append(["Request", "Restaurant", "Price", "?"]) + # # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']] + # print(sys_act) + # print(sys_utt) # - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Inform", "Hotel", "Parking", "none"]] - print(sys_act) - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Request", "Booking", "people", "?"]] - print(sys_act) - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]] - print(sys_act) - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]] - print(sys_act) - - user_act = user_policy.predict(sys_act) - print(user_act) - user_utt = user_nlg.generate(user_act) - print(user_utt) - state = dst.state - state['user_action'] = user_act - dst.update(user_act) - # pprint(state) - sys_act = sys_policy.predict(state) - # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]] - print(sys_act) + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [['Inform', 'Hotel', 'Choice', '3']] + # print(sys_act) + # + # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]] + # print(sys_act) + # # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Reqmore", "General", "none", "none"]] + # print(sys_act) + # # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Inform", "Hotel", "Parking", "none"]] + # print(sys_act) + # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Request", "Booking", "people", "?"]] + # print(sys_act) + # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]] + # print(sys_act) + # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]] + # print(sys_act) + # + # user_act = user_policy.predict(sys_act) + # print(user_act) + # user_utt = user_nlg.generate(user_act) + # print(user_utt) + # state = dst.state + # state['user_action'] = user_act + # dst.update(user_act) + # # pprint(state) + # sys_act = sys_policy.predict(state) + # # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]] + # print(sys_act) diff --git a/convlab2/policy/rule/multiwoz/rule.py b/convlab2/policy/rule/multiwoz/rule.py index fb0c463..9fdc3a3 100755 --- a/convlab2/policy/rule/multiwoz/rule.py +++ b/convlab2/policy/rule/multiwoz/rule.py @@ -30,11 +30,11 @@ class RulePolicy(Policy): """ return self.policy.predict(state) - def init_session(self): + def init_session(self, **kwargs): """ Restore after one session """ - self.policy.init_session() + self.policy.init_session(**kwargs) def is_terminated(self): if self.character == 'sys': diff --git a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py index e17ba98..7683936 100755 --- a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py +++ b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py @@ -225,16 +225,16 @@ class RuleBasedMultiwozBot(Policy): slot_name = REF_USR_DA[domain].get(slot, slot) DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]]) - p = random.random() - - # Ask user if he wants to change constraint - if p < 0.3: - req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3) - if domain + "-Request" not in DA: - DA[domain + "-Request"] = [] - for i in range(req_num): - slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0]) - DA[domain + "-Request"].append([slot_name, "?"]) + # p = random.random() + + # # Ask user if he wants to change constraint + # if p < 0.3: + # req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3) + # if domain + "-Request" not in DA: + # DA[domain + "-Request"] = [] + # for i in range(req_num): + # slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0]) + # DA[domain + "-Request"].append([slot_name, "?"]) # There's exactly one result matching user's constraint # elif len(state['kb_results_dict']) == 1: diff --git a/tests/test_end2end.py b/tests/test_end2end.py index c07b286..bf2a089 100755 --- a/tests/test_end2end.py +++ b/tests/test_end2end.py @@ -67,7 +67,7 @@ def test_end2end(): analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') set_seed(20200202) - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-RulePolicy-TemplateNLG', total_dialog=1000) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000) if __name__ == '__main__': test_end2end() -- GitLab