Skip to content
Snippets Groups Projects
Commit b190c50a authored by zqwerty, committed by zhuqi
Browse files

add **kwargs in init_session for self-defined goal; remove request for...

add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy
parent 98b50a81
Branches
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ class Agent(ABC):
pass
@abstractmethod
def init_session(self):
def init_session(self, **kwargs):
"""Reset the class variables to prepare for a new session."""
pass
......@@ -140,7 +140,7 @@ class PipelineAgent(Agent):
return self.policy.get_reward()
return None
def init_session(self):
def init_session(self, **kwargs):
"""Init the attributes of DST and Policy module."""
if self.nlu is not None:
self.nlu.init_session()
......@@ -149,7 +149,7 @@ class PipelineAgent(Agent):
if self.name == 'sys':
self.dst.state['history'].append([self.name, 'null'])
if self.policy is not None:
self.policy.init_session()
self.policy.init_session(**kwargs)
if self.nlg is not None:
self.nlg.init_session()
self.history = []
......
......@@ -139,9 +139,9 @@ class BiSession(Session):
"""
self.sys_agent.policy.train()
def init_session(self):
def init_session(self, **kwargs):
self.sys_agent.init_session()
self.user_agent.init_session()
self.user_agent.init_session(**kwargs)
if self.evaluator:
self.evaluator.add_goal(self.user_agent.policy.get_goal())
......
......@@ -456,13 +456,13 @@
"I also need a place to dine that is #RESTAURANT-INFORM-PRICE# priced ."
],
"Food": [
"How about #RESTAURANT-INFORM-FOOD# .",
"How about #RESTAURANT-INFORM-FOOD# food .",
"are there any #RESTAURANT-INFORM-FOOD# restaurants ?",
"Hmm , I 'll try #RESTAURANT-INFORM-FOOD# .",
"Hmm , I 'll try #RESTAURANT-INFORM-FOOD# food .",
"I 'd like to find a #RESTAURANT-INFORM-FOOD# restaurant , if possible .",
"Do you have #RESTAURANT-INFORM-FOOD# food ?",
"Yes . This restaurant should serve #RESTAURANT-INFORM-FOOD# food too .",
"I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# .",
"I ' m visiting Cambridge and would like some suggestions for an restaurant which serves #RESTAURANT-INFORM-FOOD# food .",
"how about a #RESTAURANT-INFORM-FOOD# restaurant ?",
"I would prefer #RESTAURANT-INFORM-FOOD# food please ."
],
......
......@@ -111,8 +111,8 @@ class UserPolicyAgendaMultiWoz(Policy):
self.agenda.close_session()
# A -> A' + user_action
action = self.agenda.get_action(random.randint(2, self.max_initiative))
# action = self.agenda.get_action(self.max_initiative)
# action = self.agenda.get_action(random.randint(2, self.max_initiative))
action = self.agenda.get_action(self.max_initiative)
# transform to DA
action = self._transform_usract_out(action)
......@@ -561,6 +561,7 @@ class Agenda(object):
def close_session(self):
""" Clear up all actions """
self.__stack = []
self.__cur_push_num = 0
self.__push(self.CLOSE_ACT)
def get_action(self, initiative=1):
......@@ -867,9 +868,10 @@ class Agenda(object):
except Exception as e:
break
else:
if self.__cur_push_num == 0 or (all([self.__stack[-i]['value'] == DEF_VAL_DNC for i in range(1, self.__cur_push_num+1)])):
if self.__cur_push_num == 0 or (all([self.__stack[-i-1]['value'] == DEF_VAL_DNC for i in
range(0, min(len(self.__stack), self.__cur_push_num))])):
# pop more when only dontcare
num2pop = 4
num2pop = initiative
else:
num2pop = self.__cur_push_num
for _ in range(num2pop):
......@@ -905,156 +907,247 @@ if __name__ == '__main__':
import numpy as np
import torch
from pprint import pprint
from convlab2.dialog_agent import PipelineAgent, BiSession
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
from convlab2.dst.rule.multiwoz.dst import RuleDST
from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU
seed = 50
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
user_policy = UserPolicyAgendaMultiWoz()
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
sys_policy = RuleBasedMultiwozBot()
from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
user_nlg = TemplateNLG(is_user=True, mode='manual')
sys_nlg = TemplateNLG(is_user=False, mode='manual')
from convlab2.dst.rule.multiwoz.dst import RuleDST
dst = RuleDST()
sys_nlu = BERTNLU()
sys_dst = RuleDST()
sys_policy = RulePolicy()
sys_nlg = TemplateNLG(is_user=False)
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')
user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
user_dst = None
user_policy = RulePolicy(character='usr')
user_nlg = TemplateNLG(is_user=True)
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
# evaluator = MultiWozEvaluator()
# sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)
# user_policy = UserPolicyAgendaMultiWoz()
#
# sys_policy = RuleBasedMultiwozBot()
#
# user_nlg = TemplateNLG(is_user=True, mode='manual')
# sys_nlg = TemplateNLG(is_user=False, mode='manual')
#
# dst = RuleDST()
#
# user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
# model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
#
goal_generator = GoalGenerator()
while True:
goal = goal_generator.get_user_goal()
if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
break
# pprint(goal)
user_goal = {'domain_ordering': ('hotel', 'restaurant', 'taxi'),
# while True:
# goal = goal_generator.get_user_goal()
# if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
# break
# # pprint(goal)
user_goal = {'domain_ordering': ('restaurant', 'taxi'),
'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
'info': {'internet': 'yes',
'parking': 'no',
'pricerange': 'moderate',
'area': 'centre'}},
'restaurant': {'info': {'area': 'centre',
'food': 'chinese',
'pricerange': 'moderate'},
'reqt': ['address']},
'food': 'portuguese',
'pricerange': 'cheap'},
'fail_info': {'area': 'centre',
'food': 'portuguese',
'pricerange': 'expensive'},
'reqt': ['postcode']},
'taxi': {'info': {'arriveBy': '13:00'}, 'reqt': ['car type', 'phone']}}
# user_goal = goal
# # user_goal = goal
goal = Goal(goal_generator)
goal.set_user_goal(user_goal)
#
# user_policy.init_session(ini_goal=goal)
# sys_policy.init_session()
#
# goal = user_policy.get_goal()
#
# pprint(goal)
sys_response = ''
# sess.init_session(ini_goal=goal)
user_policy.init_session(ini_goal=goal)
sys_policy.init_session()
goal = user_policy.get_goal()
pprint(goal)
print('init goal:')
# pprint(user_policy.get_goal())
pprint(user_agent.policy.get_goal())
# pprint(sess.evaluator.goal)
# print('-' * 50)
# for i in range(20):
# sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
# print('user:', user_response)
# print('sys:', sys_response)
# print()
# if session_over is True:
# break
# print('task success:', sess.evaluator.task_success())
# print('book rate:', sess.evaluator.book_rate())
# print('inform precision/recall/f1:', sess.evaluator.inform_F1())
# print('-' * 50)
# print('final goal:')
# pprint(sess.evaluator.goal)
# print('=' * 100)
history = []
user_utt = user_agent.response('')
print(user_utt)
user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
# history.append(['user', user_utt])
sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
sys_utt = sys_agent.response(user_utt)
pprint(sys_agent.dst.state)
print(sys_utt)
sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
# history.append(['user', user_utt])
user_utt = user_agent.response(sys_utt)
print(user_utt)
user_utt = "It just needs to be cheap ."
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
print(user_policy.agenda)
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
user_utt = user_agent.response(sys_utt)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act.append(["Request", "Restaurant", "Price", "?"])
# sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
user_utt = user_agent.response(sys_utt)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [['Inform', 'Hotel', 'Choice', '3']]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
user_utt = user_agent.response(sys_utt)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
print(sys_act)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
#
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Reqmore", "General", "none", "none"]]
print(sys_act)
# print(user_policy.agenda)
# user_act = user_policy.predict([])
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# sys_utt = sys_nlg.generate(sys_act)
# # sys_act.append(["Request", "Restaurant", "Price", "?"])
# # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
# print(sys_act)
# print(sys_utt)
#
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Inform", "Hotel", "Parking", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Request", "Booking", "people", "?"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
# sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]]
print(sys_act)
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
# print(sys_act)
#
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
# print(sys_act)
# #
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Reqmore", "General", "none", "none"]]
# print(sys_act)
# #
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Inform", "Hotel", "Parking", "none"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Request", "Booking", "people", "?"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# sys_act = [["Request", "Taxi", "Dest", "?"], ["Request", "Taxi", "Depart", "?"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Request", "Taxi", "Destination", "?"], ["Request", "Taxi", "Departure", "?"]]
# print(sys_act)
......@@ -30,11 +30,11 @@ class RulePolicy(Policy):
"""
return self.policy.predict(state)
def init_session(self):
def init_session(self, **kwargs):
"""
Restore after one session
"""
self.policy.init_session()
self.policy.init_session(**kwargs)
def is_terminated(self):
if self.character == 'sys':
......
......@@ -225,16 +225,16 @@ class RuleBasedMultiwozBot(Policy):
slot_name = REF_USR_DA[domain].get(slot, slot)
DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]])
p = random.random()
# Ask user if he wants to change constraint
if p < 0.3:
req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
if domain + "-Request" not in DA:
DA[domain + "-Request"] = []
for i in range(req_num):
slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
DA[domain + "-Request"].append([slot_name, "?"])
# p = random.random()
# # Ask user if he wants to change constraint
# if p < 0.3:
# req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
# if domain + "-Request" not in DA:
# DA[domain + "-Request"] = []
# for i in range(req_num):
# slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
# DA[domain + "-Request"].append([slot_name, "?"])
# There's exactly one result matching user's constraint
# elif len(state['kb_results_dict']) == 1:
......
......@@ -67,7 +67,7 @@ def test_end2end():
analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
set_seed(20200202)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-RulePolicy-TemplateNLG', total_dialog=1000)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000)
if __name__ == '__main__':
test_end2end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment