Skip to content
Snippets Groups Projects
Commit b190c50a authored by zqwerty, committed by zhuqi
Browse files

add **kwargs in init_session for self-defined goal; remove request for...

add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy
parent 98b50a81
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ class Agent(ABC): ...@@ -32,7 +32,7 @@ class Agent(ABC):
pass pass
@abstractmethod @abstractmethod
def init_session(self): def init_session(self, **kwargs):
"""Reset the class variables to prepare for a new session.""" """Reset the class variables to prepare for a new session."""
pass pass
...@@ -140,7 +140,7 @@ class PipelineAgent(Agent): ...@@ -140,7 +140,7 @@ class PipelineAgent(Agent):
return self.policy.get_reward() return self.policy.get_reward()
return None return None
def init_session(self): def init_session(self, **kwargs):
"""Init the attributes of DST and Policy module.""" """Init the attributes of DST and Policy module."""
if self.nlu is not None: if self.nlu is not None:
self.nlu.init_session() self.nlu.init_session()
...@@ -149,7 +149,7 @@ class PipelineAgent(Agent): ...@@ -149,7 +149,7 @@ class PipelineAgent(Agent):
if self.name == 'sys': if self.name == 'sys':
self.dst.state['history'].append([self.name, 'null']) self.dst.state['history'].append([self.name, 'null'])
if self.policy is not None: if self.policy is not None:
self.policy.init_session() self.policy.init_session(**kwargs)
if self.nlg is not None: if self.nlg is not None:
self.nlg.init_session() self.nlg.init_session()
self.history = [] self.history = []
......
...@@ -139,9 +139,9 @@ class BiSession(Session): ...@@ -139,9 +139,9 @@ class BiSession(Session):
""" """
self.sys_agent.policy.train() self.sys_agent.policy.train()
def init_session(self): def init_session(self, **kwargs):
self.sys_agent.init_session() self.sys_agent.init_session()
self.user_agent.init_session() self.user_agent.init_session(**kwargs)
if self.evaluator: if self.evaluator:
self.evaluator.add_goal(self.user_agent.policy.get_goal()) self.evaluator.add_goal(self.user_agent.policy.get_goal())
......
...@@ -456,13 +456,13 @@ ...@@ -456,13 +456,13 @@
"I also need a place to dine that is #RESTAURANT-INFORM-PRICE# priced ." "I also need a place to dine that is #RESTAURANT-INFORM-PRICE# priced ."
], ],
"Food": [ "Food": [
"How about #RESTAURANT-INFORM-FOOD# .", "How about #RESTAURANT-INFORM-FOOD# food .",
"are there any #RESTAURANT-INFORM-FOOD# restaurants ?", "are there any #RESTAURANT-INFORM-FOOD# restaurants ?",
"Hmm , I 'll try #RESTAURANT-INFORM-FOOD# .", "Hmm , I 'll try #RESTAURANT-INFORM-FOOD# food .",
"I 'd like to find a #RESTAURANT-INFORM-FOOD# restaurant , if possible .", "I 'd like to find a #RESTAURANT-INFORM-FOOD# restaurant , if possible .",
"Do you have #RESTAURANT-INFORM-FOOD# food ?", "Do you have #RESTAURANT-INFORM-FOOD# food ?",
"Yes . This restaurant should serve #RESTAURANT-INFORM-FOOD# food too .", "Yes . This restaurant should serve #RESTAURANT-INFORM-FOOD# food too .",
"I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# .", "I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# food .",
"how about a #RESTAURANT-INFORM-FOOD# restaurant ?", "how about a #RESTAURANT-INFORM-FOOD# restaurant ?",
"I would prefer #RESTAURANT-INFORM-FOOD# food please ." "I would prefer #RESTAURANT-INFORM-FOOD# food please ."
], ],
......
...@@ -111,8 +111,8 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -111,8 +111,8 @@ class UserPolicyAgendaMultiWoz(Policy):
self.agenda.close_session() self.agenda.close_session()
# A -> A' + user_action # A -> A' + user_action
action = self.agenda.get_action(random.randint(2, self.max_initiative)) # action = self.agenda.get_action(random.randint(2, self.max_initiative))
# action = self.agenda.get_action(self.max_initiative) action = self.agenda.get_action(self.max_initiative)
# transform to DA # transform to DA
action = self._transform_usract_out(action) action = self._transform_usract_out(action)
...@@ -561,6 +561,7 @@ class Agenda(object): ...@@ -561,6 +561,7 @@ class Agenda(object):
def close_session(self): def close_session(self):
""" Clear up all actions """ """ Clear up all actions """
self.__stack = [] self.__stack = []
self.__cur_push_num = 0
self.__push(self.CLOSE_ACT) self.__push(self.CLOSE_ACT)
def get_action(self, initiative=1): def get_action(self, initiative=1):
...@@ -867,9 +868,10 @@ class Agenda(object): ...@@ -867,9 +868,10 @@ class Agenda(object):
except Exception as e: except Exception as e:
break break
else: else:
if self.__cur_push_num == 0 or (all([self.__stack[-i]['value'] == DEF_VAL_DNC for i in range(1, self.__cur_push_num+1)])): if self.__cur_push_num == 0 or (all([self.__stack[-i-1]['value'] == DEF_VAL_DNC for i in
range(0, min(len(self.__stack), self.__cur_push_num))])):
# pop more when only dontcare # pop more when only dontcare
num2pop = 4 num2pop = initiative
else: else:
num2pop = self.__cur_push_num num2pop = self.__cur_push_num
for _ in range(num2pop): for _ in range(num2pop):
...@@ -905,156 +907,247 @@ if __name__ == '__main__': ...@@ -905,156 +907,247 @@ if __name__ == '__main__':
import numpy as np import numpy as np
import torch import torch
from pprint import pprint from pprint import pprint
from convlab2.dialog_agent import PipelineAgent, BiSession
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
from convlab2.dst.rule.multiwoz.dst import RuleDST
from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU
seed = 50 seed = 50
np.random.seed(seed) np.random.seed(seed)
random.seed(seed) random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
user_policy = UserPolicyAgendaMultiWoz() sys_nlu = BERTNLU()
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot sys_dst = RuleDST()
sys_policy = RuleBasedMultiwozBot() sys_policy = RulePolicy()
from convlab2.nlg.template.multiwoz.nlg import TemplateNLG sys_nlg = TemplateNLG(is_user=False)
user_nlg = TemplateNLG(is_user=True, mode='manual') sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
sys_nlg = TemplateNLG(is_user=False, mode='manual')
from convlab2.dst.rule.multiwoz.dst import RuleDST user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
dst = RuleDST() model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
user_dst = None
user_policy = RulePolicy(character='usr')
user_nlg = TemplateNLG(is_user=True)
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
# evaluator = MultiWozEvaluator()
# sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)
# user_policy = UserPolicyAgendaMultiWoz()
#
# sys_policy = RuleBasedMultiwozBot()
#
# user_nlg = TemplateNLG(is_user=True, mode='manual')
# sys_nlg = TemplateNLG(is_user=False, mode='manual')
#
# dst = RuleDST()
#
# user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
# model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
#
goal_generator = GoalGenerator() goal_generator = GoalGenerator()
while True: # while True:
goal = goal_generator.get_user_goal() # goal = goal_generator.get_user_goal()
if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
break # break
# pprint(goal) # # pprint(goal)
user_goal = {'domain_ordering': ('hotel', 'restaurant', 'taxi'), user_goal = {'domain_ordering': ('restaurant', 'taxi'),
'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'}, 'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
'info': {'internet': 'yes', 'info': {'internet': 'yes',
'parking': 'no', 'parking': 'no',
'pricerange': 'moderate', 'pricerange': 'moderate',
'area': 'centre'}}, 'area': 'centre'}},
'restaurant': {'info': {'area': 'centre', 'restaurant': {'info': {'area': 'centre',
'food': 'chinese', 'food': 'portuguese',
'pricerange': 'moderate'}, 'pricerange': 'cheap'},
'reqt': ['address']}, 'fail_info': {'area': 'centre',
'food': 'portuguese',
'pricerange': 'expensive'},
'reqt': ['postcode']},
'taxi': {'info': {'arriveBy': '13:00'}, 'reqt': ['car type', 'phone']}} 'taxi': {'info': {'arriveBy': '13:00'}, 'reqt': ['car type', 'phone']}}
# user_goal = goal # # user_goal = goal
goal = Goal(goal_generator) goal = Goal(goal_generator)
goal.set_user_goal(user_goal) goal.set_user_goal(user_goal)
#
# user_policy.init_session(ini_goal=goal)
# sys_policy.init_session()
#
# goal = user_policy.get_goal()
#
# pprint(goal)
sys_response = ''
# sess.init_session(ini_goal=goal)
user_policy.init_session(ini_goal=goal) user_policy.init_session(ini_goal=goal)
sys_policy.init_session() print('init goal:')
# pprint(user_policy.get_goal())
goal = user_policy.get_goal() pprint(user_agent.policy.get_goal())
# pprint(sess.evaluator.goal)
pprint(goal) # print('-' * 50)
# for i in range(20):
# sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
# print('user:', user_response)
# print('sys:', sys_response)
# print()
# if session_over is True:
# break
# print('task success:', sess.evaluator.task_success())
# print('book rate:', sess.evaluator.book_rate())
# print('inform precision/recall/f1:', sess.evaluator.inform_F1())
# print('-' * 50)
# print('final goal:')
# pprint(sess.evaluator.goal)
# print('=' * 100)
history = []
user_utt = user_agent.response('')
print(user_utt)
user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
# history.append(['user', user_utt])
sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
sys_utt = sys_agent.response(user_utt)
pprint(sys_agent.dst.state)
print(sys_utt)
sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
# history.append(['user', user_utt])
user_utt = user_agent.response(sys_utt)
print(user_utt)
user_utt = "It just needs to be cheap ."
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
print(user_policy.agenda) user_utt = user_agent.response(sys_utt)
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt) print(user_utt)
state = dst.state sys_utt = sys_agent.response(user_utt)
state['user_action'] = user_act print(sys_utt)
dst.update(user_act)
# pprint(state) user_utt = user_agent.response(sys_utt)
sys_act = sys_policy.predict(state)
# sys_act.append(["Request", "Restaurant", "Price", "?"])
# sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt) print(user_utt)
state = dst.state sys_utt = sys_agent.response(user_utt)
state['user_action'] = user_act print(sys_utt)
dst.update(user_act)
# pprint(state) user_utt = user_agent.response(sys_utt)
sys_act = sys_policy.predict(state)
# sys_act = [['Inform', 'Hotel', 'Choice', '3']]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt) print(user_utt)
state = dst.state sys_utt = sys_agent.response(user_utt)
state['user_action'] = user_act print(sys_utt)
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
print(sys_act)
# #
user_act = user_policy.predict(sys_act) # print(user_policy.agenda)
print(user_act) # user_act = user_policy.predict([])
user_utt = user_nlg.generate(user_act) # print(user_act)
print(user_utt) # user_utt = user_nlg.generate(user_act)
state = dst.state # print(user_utt)
state['user_action'] = user_act # state = dst.state
dst.update(user_act) # state['user_action'] = user_act
# pprint(state) # dst.update(user_act)
sys_act = sys_policy.predict(state) # # pprint(state)
# sys_act = [["Reqmore", "General", "none", "none"]] # sys_act = sys_policy.predict(state)
print(sys_act) # sys_utt = sys_nlg.generate(sys_act)
# # sys_act.append(["Request", "Restaurant", "Price", "?"])
# # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
# print(sys_act)
# print(sys_utt)
# #
user_act = user_policy.predict(sys_act) # user_act = user_policy.predict(sys_act)
print(user_act) # print(user_act)
user_utt = user_nlg.generate(user_act) # user_utt = user_nlg.generate(user_act)
print(user_utt) # print(user_utt)
state = dst.state # state = dst.state
state['user_action'] = user_act # state['user_action'] = user_act
dst.update(user_act) # dst.update(user_act)
# pprint(state) # # pprint(state)
sys_act = sys_policy.predict(state) # sys_act = sys_policy.predict(state)
# sys_act = [["Inform", "Hotel", "Parking", "none"]] # # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
print(sys_act) # print(sys_act)
#
user_act = user_policy.predict(sys_act) #
print(user_act) # user_act = user_policy.predict(sys_act)
user_utt = user_nlg.generate(user_act) # print(user_act)
print(user_utt) # user_utt = user_nlg.generate(user_act)
state = dst.state # print(user_utt)
state['user_action'] = user_act # state = dst.state
dst.update(user_act) # state['user_action'] = user_act
# pprint(state) # dst.update(user_act)
sys_act = sys_policy.predict(state) # # pprint(state)
# sys_act = [["Request", "Booking", "people", "?"]] # sys_act = sys_policy.predict(state)
print(sys_act) # # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
# print(sys_act)
user_act = user_policy.predict(sys_act) # #
print(user_act) # user_act = user_policy.predict(sys_act)
user_utt = user_nlg.generate(user_act) # print(user_act)
print(user_utt) # user_utt = user_nlg.generate(user_act)
state = dst.state # print(user_utt)
state['user_action'] = user_act # state = dst.state
dst.update(user_act) # state['user_action'] = user_act
# pprint(state) # dst.update(user_act)
sys_act = sys_policy.predict(state) # # pprint(state)
# sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]] # sys_act = sys_policy.predict(state)
print(sys_act) # # sys_act = [["Reqmore", "General", "none", "none"]]
# print(sys_act)
user_act = user_policy.predict(sys_act) # #
print(user_act) # user_act = user_policy.predict(sys_act)
user_utt = user_nlg.generate(user_act) # print(user_act)
print(user_utt) # user_utt = user_nlg.generate(user_act)
state = dst.state # print(user_utt)
state['user_action'] = user_act # state = dst.state
dst.update(user_act) # state['user_action'] = user_act
# pprint(state) # dst.update(user_act)
sys_act = sys_policy.predict(state) # # pprint(state)
sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]] # sys_act = sys_policy.predict(state)
print(sys_act) # # sys_act = [["Inform", "Hotel", "Parking", "none"]]
# print(sys_act)
user_act = user_policy.predict(sys_act) #
print(user_act) # user_act = user_policy.predict(sys_act)
user_utt = user_nlg.generate(user_act) # print(user_act)
print(user_utt) # user_utt = user_nlg.generate(user_act)
state = dst.state # print(user_utt)
state['user_action'] = user_act # state = dst.state
dst.update(user_act) # state['user_action'] = user_act
# pprint(state) # dst.update(user_act)
sys_act = sys_policy.predict(state) # # pprint(state)
# sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]] # sys_act = sys_policy.predict(state)
print(sys_act) # # sys_act = [["Request", "Booking", "people", "?"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]]
# print(sys_act)
...@@ -30,11 +30,11 @@ class RulePolicy(Policy): ...@@ -30,11 +30,11 @@ class RulePolicy(Policy):
""" """
return self.policy.predict(state) return self.policy.predict(state)
def init_session(self): def init_session(self, **kwargs):
""" """
Restore after one session Restore after one session
""" """
self.policy.init_session() self.policy.init_session(**kwargs)
def is_terminated(self): def is_terminated(self):
if self.character == 'sys': if self.character == 'sys':
......
...@@ -225,16 +225,16 @@ class RuleBasedMultiwozBot(Policy): ...@@ -225,16 +225,16 @@ class RuleBasedMultiwozBot(Policy):
slot_name = REF_USR_DA[domain].get(slot, slot) slot_name = REF_USR_DA[domain].get(slot, slot)
DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]]) DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]])
p = random.random() # p = random.random()
# Ask user if he wants to change constraint # # Ask user if he wants to change constraint
if p < 0.3: # if p < 0.3:
req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3) # req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
if domain + "-Request" not in DA: # if domain + "-Request" not in DA:
DA[domain + "-Request"] = [] # DA[domain + "-Request"] = []
for i in range(req_num): # for i in range(req_num):
slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0]) # slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
DA[domain + "-Request"].append([slot_name, "?"]) # DA[domain + "-Request"].append([slot_name, "?"])
# There's exactly one result matching user's constraint # There's exactly one result matching user's constraint
# elif len(state['kb_results_dict']) == 1: # elif len(state['kb_results_dict']) == 1:
......
...@@ -67,7 +67,7 @@ def test_end2end(): ...@@ -67,7 +67,7 @@ def test_end2end():
analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz') analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
set_seed(20200202) set_seed(20200202)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-RulePolicy-TemplateNLG', total_dialog=1000) analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000)
if __name__ == '__main__': if __name__ == '__main__':
test_end2end() test_end2end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment