Skip to content
Snippets Groups Projects
Commit e38286f1 authored by zqwerty's avatar zqwerty Committed by zhuqi
Browse files

can manually set user goal in agenda now

parent db2b6aab
Branches
No related tags found
No related merge requests found
......@@ -67,10 +67,13 @@ class UserPolicyAgendaMultiWoz(Policy):
def reset_turn(self):
self.__turn = 0
def init_session(self):
def init_session(self, ini_goal=None):
""" Build new Goal and Agenda for next session """
self.reset_turn()
if not ini_goal:
self.goal = Goal(self.goal_generator)
else:
self.goal = ini_goal
self.domain_goals = self.goal.domain_goals
self.agenda = Agenda(self.goal)
......@@ -310,7 +313,7 @@ class Goal(object):
"""
create new Goal by random
Args:
goal_generator (GoalGenerator): Goal Gernerator.
goal_generator (GoalGenerator): Goal Generator.
"""
self.domain_goals = goal_generator.get_user_goal()
......@@ -324,6 +327,25 @@ class Goal(object):
if 'book' in self.domain_goals[domain].keys():
self.domain_goals[domain]['booked'] = DEF_VAL_UNK
def set_user_goal(self, user_goal):
"""
set new Goal given user goal generated by goal_generator.get_user_goal()
Args:
user_goal : user goal generated by GoalGenerator.
"""
self.domain_goals = user_goal
self.domains = list(self.domain_goals['domain_ordering'])
del self.domain_goals['domain_ordering']
for domain in self.domains:
if 'reqt' in self.domain_goals[domain].keys():
self.domain_goals[domain]['reqt'] = {slot: DEF_VAL_UNK for slot in self.domain_goals[domain]['reqt']}
if 'book' in self.domain_goals[domain].keys():
self.domain_goals[domain]['booked'] = DEF_VAL_UNK
def task_complete(self):
"""
Check that all requests have been met
......@@ -861,9 +883,11 @@ class Agenda(object):
if __name__ == '__main__':
import numpy as np
import torch
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
from pprint import pprint
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
user_policy = UserPolicyAgendaMultiWoz()
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
......@@ -873,72 +897,92 @@ if __name__ == '__main__':
sys_nlg = TemplateNLG(is_user=False, mode='manual')
from convlab2.util.multiwoz.state import default_state
user_policy.init_session()
goal_generator = GoalGenerator()
goal = goal_generator.get_user_goal()
pprint(goal)
user_goal = {'domain_ordering': ('restaurant', 'hotel', 'attraction`'),
'attraction': {'info': {'area': 'west', 'type': 'museum'}, 'reqt': ['phone']},
'hotel': {'book': {'day': 'saturday', 'people': '5', 'stay': '4'},
'fail_info': {'area': 'north',
'internet': 'yes',
'pricerange': 'expensive',
'stars': '3'},
'info': {'area': 'north',
'internet': 'yes',
'pricerange': 'moderate',
'stars': '3'}},
'restaurant': {'book': {'day': 'saturday', 'people': '5', 'time': '18:00'},
'info': {'area': 'centre', 'food': 'british'},
'reqt': ['address']}}
goal = Goal(goal_generator)
goal.set_user_goal(user_goal)
user_policy.init_session(ini_goal=goal)
sys_policy.init_session()
goal = user_policy.get_goal()
print(goal)
pprint(goal)
print(user_policy.agenda)
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = default_state()
state['user_action'] = user_act
sys_act = sys_policy.predict(state)
sys_act.append(["Request", "Restaurant", "Price", "?"])
sys_act = [['Inform', 'Hotel', 'Choice', '0']]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
print(goal)
sys_act = sys_policy.predict(state)
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Reqmore", "General", "none", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Inform", "Hotel", "Parking", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Request", "Booking", "people", "?"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
print(sys_act)
# user_act = user_policy.predict([])
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# state = default_state()
# state['user_action'] = user_act
# sys_act = sys_policy.predict(state)
# sys_act.append(["Request", "Restaurant", "Price", "?"])
# sys_act = [['Inform', 'Hotel', 'Choice', '0']]
# print(sys_act)
#
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
#
# print(goal)
#
# sys_act = sys_policy.predict(state)
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = sys_policy.predict(state)
# sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = sys_policy.predict(state)
# sys_act = [["Reqmore", "General", "none", "none"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = sys_policy.predict(state)
# sys_act = [["Inform", "Hotel", "Parking", "none"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = sys_policy.predict(state)
# sys_act = [["Request", "Booking", "people", "?"]]
# print(sys_act)
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = sys_policy.predict(state)
# sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
# print(sys_act)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment