From b190c50ad731defb525ad3ee6616438844401a17 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 24 Aug 2020 16:04:39 +0800
Subject: [PATCH] add **kwargs in init_session for self-defined goal; remove
 request for nooffer-slot in rule-sys-policy

---
 convlab2/dialog_agent/agent.py                |   6 +-
 convlab2/dialog_agent/session.py              |   4 +-
 .../multiwoz/manual_user_template_nlg.json    |   6 +-
 .../rule/multiwoz/policy_agenda_multiwoz.py   | 359 +++++++++++-------
 convlab2/policy/rule/multiwoz/rule.py         |   4 +-
 .../rule/multiwoz/rule_based_multiwoz_bot.py  |  20 +-
 tests/test_end2end.py                         |   2 +-
 7 files changed, 247 insertions(+), 154 deletions(-)

diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py
index f2d9ee6..adb9d33 100755
--- a/convlab2/dialog_agent/agent.py
+++ b/convlab2/dialog_agent/agent.py
@@ -32,7 +32,7 @@ class Agent(ABC):
         pass
 
     @abstractmethod
-    def init_session(self):
+    def init_session(self, **kwargs):
         """Reset the class variables to prepare for a new session."""
         pass
 
@@ -140,7 +140,7 @@ class PipelineAgent(Agent):
             return self.policy.get_reward()
         return None
 
-    def init_session(self):
+    def init_session(self, **kwargs):
         """Init the attributes of DST and Policy module."""
         if self.nlu is not None:
             self.nlu.init_session()
@@ -149,7 +149,7 @@ class PipelineAgent(Agent):
             if self.name == 'sys':
                 self.dst.state['history'].append([self.name, 'null'])
         if self.policy is not None:
-            self.policy.init_session()
+            self.policy.init_session(**kwargs)
         if self.nlg is not None:
             self.nlg.init_session()
         self.history = []
diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py
index f96bab6..abb3e83 100755
--- a/convlab2/dialog_agent/session.py
+++ b/convlab2/dialog_agent/session.py
@@ -139,9 +139,9 @@ class BiSession(Session):
         """
         self.sys_agent.policy.train()
 
-    def init_session(self):
+    def init_session(self, **kwargs):
         self.sys_agent.init_session()
-        self.user_agent.init_session()
+        self.user_agent.init_session(**kwargs)
         if self.evaluator:
             self.evaluator.add_goal(self.user_agent.policy.get_goal())
 
diff --git a/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json b/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json
index 4a44fed..3da8288 100755
--- a/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json
+++ b/convlab2/nlg/template/multiwoz/manual_user_template_nlg.json
@@ -456,13 +456,13 @@
             "I also need a place to dine that is #RESTAURANT-INFORM-PRICE# priced ."
         ],
         "Food": [
-            "How about #RESTAURANT-INFORM-FOOD# .",
+            "How about #RESTAURANT-INFORM-FOOD# food .",
             "are there any #RESTAURANT-INFORM-FOOD# restaurants ?",
-            "Hmm , I 'll try #RESTAURANT-INFORM-FOOD# .",
+            "Hmm , I 'll try #RESTAURANT-INFORM-FOOD# food .",
             "I 'd like to find a #RESTAURANT-INFORM-FOOD# restaurant , if possible .",
             "Do you have #RESTAURANT-INFORM-FOOD# food ?",
             "Yes . This restaurant should serve #RESTAURANT-INFORM-FOOD# food too .",
-            "I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# .",
+            "I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# food .",
             "how about a #RESTAURANT-INFORM-FOOD# restaurant ?",
             "I would prefer #RESTAURANT-INFORM-FOOD# food please ."
         ],
diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index a0af0df..70572b8 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -111,8 +111,8 @@ class UserPolicyAgendaMultiWoz(Policy):
                 self.agenda.close_session()
 
         # A -> A' + user_action
-        action = self.agenda.get_action(random.randint(2, self.max_initiative))
-        # action = self.agenda.get_action(self.max_initiative)
+        # action = self.agenda.get_action(random.randint(2, self.max_initiative))
+        action = self.agenda.get_action(self.max_initiative)
 
         # transform to DA
         action = self._transform_usract_out(action)
@@ -561,6 +561,7 @@ class Agenda(object):
     def close_session(self):
         """ Clear up all actions """
         self.__stack = []
+        self.__cur_push_num = 0
         self.__push(self.CLOSE_ACT)
 
     def get_action(self, initiative=1):
@@ -867,9 +868,10 @@ class Agenda(object):
                 except Exception as e:
                     break
         else:
-            if self.__cur_push_num == 0 or (all([self.__stack[-i]['value'] == DEF_VAL_DNC for i in range(1, self.__cur_push_num+1)])):
+            if self.__cur_push_num == 0 or (all([self.__stack[-i-1]['value'] == DEF_VAL_DNC for i in
+                                                 range(0, min(len(self.__stack), self.__cur_push_num))])):
                 # pop more when only dontcare
-                num2pop = 4
+                num2pop = initiative
             else:
                 num2pop = self.__cur_push_num
             for _ in range(num2pop):
@@ -905,156 +907,247 @@ if __name__ == '__main__':
     import numpy as np
     import torch
     from pprint import pprint
+    from convlab2.dialog_agent import PipelineAgent, BiSession
+    from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
+    from convlab2.policy.rule.multiwoz import RulePolicy
+    from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
+    from convlab2.dst.rule.multiwoz.dst import RuleDST
+    from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU
+
     seed = 50
     np.random.seed(seed)
     random.seed(seed)
     torch.manual_seed(seed)
 
-    user_policy = UserPolicyAgendaMultiWoz()
-    from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
-    sys_policy = RuleBasedMultiwozBot()
-    from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
-    user_nlg = TemplateNLG(is_user=True, mode='manual')
-    sys_nlg = TemplateNLG(is_user=False, mode='manual')
-    from convlab2.dst.rule.multiwoz.dst import RuleDST
-    dst = RuleDST()
+    sys_nlu = BERTNLU()
+    sys_dst = RuleDST()
+    sys_policy = RulePolicy()
+    sys_nlg = TemplateNLG(is_user=False)
+    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
+
+    user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+                       model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
+    user_dst = None
+    user_policy = RulePolicy(character='usr')
+    user_nlg = TemplateNLG(is_user=True)
+    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
+
+    # evaluator = MultiWozEvaluator()
+    # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)
+
+
 
+
+    # user_policy = UserPolicyAgendaMultiWoz()
+    #
+    # sys_policy = RuleBasedMultiwozBot()
+    #
+    # user_nlg = TemplateNLG(is_user=True, mode='manual')
+    # sys_nlg = TemplateNLG(is_user=False, mode='manual')
+    #
+    # dst = RuleDST()
+    #
+    # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
+    #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
+    #
     goal_generator = GoalGenerator()
-    while True:
-        goal = goal_generator.get_user_goal()
-        if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
-            break
-    # pprint(goal)
-    user_goal = {'domain_ordering': ('hotel', 'restaurant', 'taxi'),
+    # while True:
+    #     goal = goal_generator.get_user_goal()
+    #     if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
+    #         break
+    # # pprint(goal)
+    user_goal = {'domain_ordering': ('restaurant', 'taxi'),
                  'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
                            'info': {'internet': 'yes',
                                     'parking': 'no',
                                     'pricerange': 'moderate',
                                     'area': 'centre'}},
                  'restaurant': {'info': {'area': 'centre',
-                                         'food': 'chinese',
-                                         'pricerange': 'moderate'},
-                                'reqt': ['address']},
+                                         'food': 'portuguese',
+                                         'pricerange': 'cheap'},
+                                'fail_info': {'area': 'centre',
+                                         'food': 'portuguese',
+                                         'pricerange': 'expensive'},
+                                'reqt': ['postcode']},
                  'taxi': {'info': {'arriveBy': '13:00'}, 'reqt': ['car type', 'phone']}}
-    # user_goal = goal
+    # # user_goal = goal
     goal = Goal(goal_generator)
     goal.set_user_goal(user_goal)
+    #
+    # user_policy.init_session(ini_goal=goal)
+    # sys_policy.init_session()
+    #
+    # goal = user_policy.get_goal()
+    #
+    # pprint(goal)
 
+    sys_response = ''
+    # sess.init_session(ini_goal=goal)
     user_policy.init_session(ini_goal=goal)
-    sys_policy.init_session()
-
-    goal = user_policy.get_goal()
-
-    pprint(goal)
+    print('init goal:')
+    # pprint(user_policy.get_goal())
+    pprint(user_agent.policy.get_goal())
+    # pprint(sess.evaluator.goal)
+    # print('-' * 50)
+    # for i in range(20):
+    #     sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
+    #     print('user:', user_response)
+    #     print('sys:', sys_response)
+    #     print()
+    #     if session_over is True:
+    #         break
+    # print('task success:', sess.evaluator.task_success())
+    # print('book rate:', sess.evaluator.book_rate())
+    # print('inform precision/recall/f1:', sess.evaluator.inform_F1())
+    # print('-' * 50)
+    # print('final goal:')
+    # pprint(sess.evaluator.goal)
+    # print('=' * 100)
+
+    history = []
+    user_utt = user_agent.response('')
+    print(user_utt)
+    user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
+    # history.append(['user', user_utt])
+    sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
+    sys_utt = sys_agent.response(user_utt)
+    pprint(sys_agent.dst.state)
+    print(sys_utt)
+    sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
+    # history.append(['user', user_utt])
+
+    user_utt = user_agent.response(sys_utt)
+    print(user_utt)
+    user_utt = "It just needs to be cheap ."
+    sys_utt = sys_agent.response(user_utt)
+    print(sys_utt)
+    sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
 
-    print(user_policy.agenda)
-    user_act = user_policy.predict([])
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
+    user_utt = user_agent.response(sys_utt)
     print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act.append(["Request", "Restaurant", "Price", "?"])
-    # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
-    print(sys_act)
-
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
+    sys_utt = sys_agent.response(user_utt)
+    print(sys_utt)
+
+    user_utt = user_agent.response(sys_utt)
     print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
-    print(sys_act)
-
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
+    sys_utt = sys_agent.response(user_utt)
+    print(sys_utt)
+
+    user_utt = user_agent.response(sys_utt)
     print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
-    print(sys_act)
+    sys_utt = sys_agent.response(user_utt)
+    print(sys_utt)
+
     #
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Reqmore", "General", "none", "none"]]
-    print(sys_act)
+    # print(user_policy.agenda)
+    # user_act = user_policy.predict([])
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # sys_utt = sys_nlg.generate(sys_act)
+    # # sys_act.append(["Request", "Restaurant", "Price", "?"])
+    # # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
+    # print(sys_act)
+    # print(sys_utt)
     #
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Inform", "Hotel", "Parking", "none"]]
-    print(sys_act)
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Request", "Booking", "people", "?"]]
-    print(sys_act)
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
-    print(sys_act)
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]]
-    print(sys_act)
-
-    user_act = user_policy.predict(sys_act)
-    print(user_act)
-    user_utt = user_nlg.generate(user_act)
-    print(user_utt)
-    state = dst.state
-    state['user_action'] = user_act
-    dst.update(user_act)
-    # pprint(state)
-    sys_act = sys_policy.predict(state)
-    # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]]
-    print(sys_act)
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
+    # print(sys_act)
+    #
+    #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
+    # print(sys_act)
+    # #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Reqmore", "General", "none", "none"]]
+    # print(sys_act)
+    # #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Inform", "Hotel", "Parking", "none"]]
+    # print(sys_act)
+    #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Request", "Booking", "people", "?"]]
+    # print(sys_act)
+    #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
+    # print(sys_act)
+    #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]]
+    # print(sys_act)
+    #
+    # user_act = user_policy.predict(sys_act)
+    # print(user_act)
+    # user_utt = user_nlg.generate(user_act)
+    # print(user_utt)
+    # state = dst.state
+    # state['user_action'] = user_act
+    # dst.update(user_act)
+    # # pprint(state)
+    # sys_act = sys_policy.predict(state)
+    # # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]]
+    # print(sys_act)
diff --git a/convlab2/policy/rule/multiwoz/rule.py b/convlab2/policy/rule/multiwoz/rule.py
index fb0c463..9fdc3a3 100755
--- a/convlab2/policy/rule/multiwoz/rule.py
+++ b/convlab2/policy/rule/multiwoz/rule.py
@@ -30,11 +30,11 @@ class RulePolicy(Policy):
         """
         return self.policy.predict(state)
 
-    def init_session(self):
+    def init_session(self, **kwargs):
         """
         Restore after one session
         """
-        self.policy.init_session()
+        self.policy.init_session(**kwargs)
 
     def is_terminated(self):
         if self.character == 'sys':
diff --git a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py
index e17ba98..7683936 100755
--- a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py
+++ b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py
@@ -225,16 +225,16 @@ class RuleBasedMultiwozBot(Policy):
                         slot_name = REF_USR_DA[domain].get(slot, slot)
                         DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]])
 
-                p = random.random()
-
-                # Ask user if he wants to change constraint
-                if p < 0.3:
-                    req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
-                    if domain + "-Request" not in DA:
-                        DA[domain + "-Request"] = []
-                    for i in range(req_num):
-                        slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
-                        DA[domain + "-Request"].append([slot_name, "?"])
+                # p = random.random()
+
+                # # Ask user if he wants to change constraint
+                # if p < 0.3:
+                #     req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
+                #     if domain + "-Request" not in DA:
+                #         DA[domain + "-Request"] = []
+                #     for i in range(req_num):
+                #         slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
+                #         DA[domain + "-Request"].append([slot_name, "?"])
 
             # There's exactly one result matching user's constraint
             # elif len(state['kb_results_dict']) == 1:
diff --git a/tests/test_end2end.py b/tests/test_end2end.py
index c07b286..bf2a089 100755
--- a/tests/test_end2end.py
+++ b/tests/test_end2end.py
@@ -67,7 +67,7 @@ def test_end2end():
     analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
 
     set_seed(20200202)
-    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-RulePolicy-TemplateNLG', total_dialog=1000)
+    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000)
 
 if __name__ == '__main__':
     test_end2end()
-- 
GitLab