Skip to content
Snippets Groups Projects
Commit 37f2df82 authored by zqwerty's avatar zqwerty Committed by zhuqi
Browse files

Improve agenda policy (#31); the ordering of dialog acts for NLG could be refined further in TemplateNLG.sorted_dialog_act

parent cd692f19
Branches
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
import random import random
import os import os
from pprint import pprint from pprint import pprint
import collections
from convlab2.nlg import NLG from convlab2.nlg import NLG
...@@ -67,6 +68,32 @@ class TemplateNLG(NLG): ...@@ -67,6 +68,32 @@ class TemplateNLG(NLG):
self.manual_user_template = read_json(os.path.join(template_dir, 'manual_user_template_nlg.json')) self.manual_user_template = read_json(os.path.join(template_dir, 'manual_user_template_nlg.json'))
self.manual_system_template = read_json(os.path.join(template_dir, 'manual_system_template_nlg.json')) self.manual_system_template = read_json(os.path.join(template_dir, 'manual_system_template_nlg.json'))
def sorted_dialog_act(self, dialog_acts):
    """Regroup dialog acts by domain and intent category before generation.

    Each act is placed in one of five per-domain buckets: ``nooffer``,
    ``inform-name`` (an ``Inform`` on the ``Name`` slot), ``inform-other``
    (any other ``Inform``), ``request`` and ``other``.

    Args:
        dialog_acts: iterable of ``[intent, domain, slot, value]`` items.

    Returns:
        list: a new list holding the same items. Every bucket is
        *prepended* to the result, so the output reads: domains in reverse
        order of first appearance, each domain laid out as ``other``,
        ``request``, ``inform-other``, ``inform-name``, ``nooffer``,
        followed by all ``general``-domain acts in their input order.
        The input list is not modified.
    """
    bucket_order = ('nooffer', 'inform-name', 'inform-other', 'request', 'other')
    general_acts = []
    grouped = {}
    for item in dialog_acts:
        intent, domain, slot, _value = item
        if domain == 'general':
            # Keep every general act, whatever its intent; bucketing them
            # and then keeping only one bucket would silently drop acts.
            general_acts.append(item)
            continue
        buckets = grouped.setdefault(domain, {name: [] for name in bucket_order})
        if intent == 'NoOffer':
            bucket = 'nooffer'
        elif intent == 'Inform':
            bucket = 'inform-name' if slot == 'Name' else 'inform-other'
        elif intent in ('Request', 'request'):
            # Intents are capitalised like 'Inform' / 'NoOffer'; the
            # lowercase spelling is still accepted for compatibility.
            bucket = 'request'
        else:
            bucket = 'other'
        buckets[bucket].append(item)

    new_action = general_acts
    # NOTE(review): prepending reverses bucket_order in the final list;
    # kept as-is to preserve the existing output order - confirm intended.
    for buckets in grouped.values():
        for name in bucket_order:
            new_action = buckets[name] + new_action
    return new_action
def generate(self, dialog_acts): def generate(self, dialog_acts):
"""NLG for Multiwoz dataset """NLG for Multiwoz dataset
...@@ -75,7 +102,8 @@ class TemplateNLG(NLG): ...@@ -75,7 +102,8 @@ class TemplateNLG(NLG):
Returns: Returns:
generated sentence generated sentence
""" """
action = {} dialog_acts = self.sorted_dialog_act(dialog_acts)
action = collections.OrderedDict()
for intent, domain, slot, value in dialog_acts: for intent, domain, slot, value in dialog_acts:
k = '-'.join([domain, intent]) k = '-'.join([domain, intent])
action.setdefault(k, []) action.setdefault(k, [])
...@@ -165,9 +193,21 @@ class TemplateNLG(NLG): ...@@ -165,9 +193,21 @@ class TemplateNLG(NLG):
for slot, value in slot_value_pairs: for slot, value in slot_value_pairs:
if value in ["do nt care", "do n't care", "dontcare"]: if value in ["do nt care", "do n't care", "dontcare"]:
sentence = 'I don\'t care about the {} of the {}'.format(slot.lower(), dialog_act.split('-')[0].lower()) sentence = 'I don\'t care about the {} of the {}'.format(slot.lower(), dialog_act.split('-')[0].lower())
elif self.is_user and dialog_act.split('-')[1] == 'Inform' and slot == 'Choice' and value == 'any':
# user have no preference, any choice is ok
sentence = random.choice([
"Please pick one for me. ",
"Anyone would be ok. ",
"Just select one for me. "
])
elif dialog_act in template and slot in template[dialog_act]: elif dialog_act in template and slot in template[dialog_act]:
sentence = random.choice(template[dialog_act][slot]) sentence = random.choice(template[dialog_act][slot])
sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value)) sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
elif slot == 'NotBook':
sentence = random.choice([
"I do not need to book. ",
"I 'm not looking to make a booking at the moment."
])
else: else:
if slot in slot2word: if slot in slot2word:
sentence = 'The {} is {} . '.format(slot2word[slot], str(value)) sentence = 'The {} is {} . '.format(slot2word[slot], str(value))
...@@ -201,7 +241,7 @@ class TemplateNLG(NLG): ...@@ -201,7 +241,7 @@ class TemplateNLG(NLG):
def example(): def example():
# dialog act # dialog act
dialog_acts = [['Inform', 'Train', 'Day', 'wednesday'], ['Inform', 'Train', 'Leave', '10:15']] dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Name', 'fds'], ['welcome', 'general', 'none', 'none']]
print(dialog_acts) print(dialog_acts)
# system model for manual, auto, auto_manual # system model for manual, auto, auto_manual
......
...@@ -74,11 +74,11 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -74,11 +74,11 @@ class UserPolicyAgendaMultiWoz(Policy):
self.domain_goals = self.goal.domain_goals self.domain_goals = self.goal.domain_goals
self.agenda = Agenda(self.goal) self.agenda = Agenda(self.goal)
def predict(self, state): def predict(self, sys_dialog_act):
""" """
Predict an user act based on state and preorder system action. Predict an user act based on state and preorder system action.
Args: Args:
state (tuple): Dialog state. sys_dialog_act (list): system dialogue act: [[intent, domain, slot, value],...].
Returns: Returns:
action (tuple): User act. action (tuple): User act.
session_over (boolean): True to terminate session, otherwise session continues. session_over (boolean): True to terminate session, otherwise session continues.
...@@ -86,10 +86,10 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -86,10 +86,10 @@ class UserPolicyAgendaMultiWoz(Policy):
""" """
self.__turn += 2 self.__turn += 2
assert isinstance(state, list) assert isinstance(sys_dialog_act, list)
sys_action = {} sys_action = {}
for intent, domain, slot, value in state: for intent, domain, slot, value in sys_dialog_act:
k = '-'.join([domain, intent]) k = '-'.join([domain, intent])
sys_action.setdefault(k,[]) sys_action.setdefault(k,[])
sys_action[k].append([slot, value]) sys_action[k].append([slot, value])
...@@ -143,29 +143,35 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -143,29 +143,35 @@ class UserPolicyAgendaMultiWoz(Policy):
@classmethod @classmethod
def _transform_usract_out(cls, action): def _transform_usract_out(cls, action):
# print('before transform', action)
new_action = {} new_action = {}
for act in action.keys(): for act in action.keys():
if '-' in act: if '-' in act:
if 'general' not in act: if 'general' not in act:
(dom, intent) = act.split('-') (dom, intent) = act.split('-')
new_act = dom.capitalize() + '-' + intent.capitalize() new_act = dom.capitalize() + '-' + intent.capitalize()
if action[act] == [['none', 'none']]:
new_action[new_act] = [['none', 'none']]
continue
new_action[new_act] = [] new_action[new_act] = []
for pairs in action[act]: for pairs in action[act]:
slot = REF_USR_DA_M[dom.capitalize()].get(pairs[0], None) slot = REF_USR_DA_M[dom.capitalize()].get(pairs[0], None)
if slot is not None: if pairs[0] == 'none' and pairs[1] == 'none':
new_action[new_act].append(['none', 'none'])
elif pairs[0] == 'choice' and pairs[1] == 'any':
new_action[new_act].append(['Choice', 'any'])
elif pairs[0] == 'NotBook' and pairs[1] == 'none':
new_action[new_act].append(['NotBook', 'none'])
elif slot is not None:
new_action[new_act].append([slot, pairs[1]]) new_action[new_act].append([slot, pairs[1]])
# new_action[new_act] = [[REF_USR_DA_M[dom.capitalize()].get(pairs[0], pairs[0]), pairs[1]] for pairs in action[act]] # new_action[new_act] = [[REF_USR_DA_M[dom.capitalize()].get(pairs[0], pairs[0]), pairs[1]] for pairs in action[act]]
else: else:
new_action[act] = action[act] new_action[act] = action[act]
else: else:
pass pass
# print('after transform', new_action)
return new_action return new_action
@classmethod @classmethod
def _transform_sysact_in(cls, action): def _transform_sysact_in(cls, action):
# print("sys in", action)
new_action = {} new_action = {}
if not isinstance(action, dict): if not isinstance(action, dict):
logging.warning('illegal da: {}'.format(action)) logging.warning('illegal da: {}'.format(action))
...@@ -195,7 +201,7 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -195,7 +201,7 @@ class UserPolicyAgendaMultiWoz(Policy):
new_action[act.lower()] = new_list new_action[act.lower()] = new_list
else: else:
new_action[act.lower()] = action[act] new_action[act.lower()] = action[act]
# print("sys in transformed", new_action)
return new_action return new_action
@classmethod @classmethod
...@@ -209,6 +215,9 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -209,6 +215,9 @@ class UserPolicyAgendaMultiWoz(Policy):
if slot not in cls.stand_value_dict[domain]: if slot not in cls.stand_value_dict[domain]:
return value return value
if slot in ['parking', 'internet'] and value == 'none':
return 'yes'
value_list = cls.stand_value_dict[domain][slot] value_list = cls.stand_value_dict[domain][slot]
low_value_list = [item.lower() for item in value_list] low_value_list = [item.lower() for item in value_list]
value_list = sorted(list(set(value_list)|set(low_value_list))) value_list = sorted(list(set(value_list)|set(low_value_list)))
...@@ -386,6 +395,8 @@ class Agenda(object): ...@@ -386,6 +395,8 @@ class Agenda(object):
len(goal.domain_goals[domain]['info'])): len(goal.domain_goals[domain]['info'])):
self.__push(domain + '-inform', slot, goal.domain_goals[domain]['info'][slot]) self.__push(domain + '-inform', slot, goal.domain_goals[domain]['info'][slot])
self.__push(domain + '-inform', "none", "none")
self.cur_domain = None self.cur_domain = None
def update(self, sys_action, goal: Goal): def update(self, sys_action, goal: Goal):
...@@ -422,6 +433,21 @@ class Agenda(object): ...@@ -422,6 +433,21 @@ class Agenda(object):
if self.update_domain(diaact, slot_vals, goal): if self.update_domain(diaact, slot_vals, goal):
return return
for diaact in sys_action.keys():
if 'inform' in diaact or 'recommend' in diaact:
for slot, val in sys_action[diaact]:
if slot == 'name':
self._remove_item(diaact.split('-')[0]+'-inform', 'choice')
if 'booking' in diaact:
g_book = self._get_goal_infos(self.cur_domain, goal)[-2]
if len(g_book) == 0:
self._push_item(self.cur_domain + '-inform', "NotBook", "none")
if 'OfferBook' in diaact:
domain = diaact.split('-')[0]
g_book = self._get_goal_infos(domain, goal)[-2]
if len(g_book) == 0:
self._push_item(domain + '-inform', "NotBook", "none")
self.post_process(goal) self.post_process(goal)
def post_process(self, goal: Goal): def post_process(self, goal: Goal):
...@@ -517,6 +543,7 @@ class Agenda(object): ...@@ -517,6 +543,7 @@ class Agenda(object):
Returns: Returns:
action (dict): user diaact action (dict): user diaact
""" """
# print(self)
diaacts, slots, values = self.__pop(initiative) diaacts, slots, values = self.__pop(initiative)
action = {} action = {}
for (diaact, slot, value) in zip(diaacts, slots, values): for (diaact, slot, value) in zip(diaacts, slots, values):
...@@ -563,6 +590,10 @@ class Agenda(object): ...@@ -563,6 +590,10 @@ class Agenda(object):
logging.warning('illegal booking slot: {}, domain: {}'.format(slot, domain)) logging.warning('illegal booking slot: {}, domain: {}'.format(slot, domain))
continue continue
# For multiple choices, add new intent to select one:
if slot == 'choice':
self._push_item(domain + '-inform', "choice", "any")
if slot in g_reqt: if slot in g_reqt:
if not self._check_reqt_info(domain): if not self._check_reqt_info(domain):
self._remove_item(domain + '-request', slot) self._remove_item(domain + '-request', slot)
...@@ -628,10 +659,13 @@ class Agenda(object): ...@@ -628,10 +659,13 @@ class Agenda(object):
goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS: goal.domain_goals[places[-2]]['reqt']['address'] not in NOT_SURE_VALS:
self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address']) self._push_item(domain + '-inform', slot, goal.domain_goals[places[-2]]['reqt']['address'])
elif random.random() < 0.5: # elif random.random() < 0.5:
self._push_item(domain + '-inform', slot, DEF_VAL_DNC) # self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
elif random.random() < 0.5: # elif random.random() < 0.5:
# self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
# for those sys requests that are not in user goal
self._push_item(domain + '-inform', slot, DEF_VAL_DNC) self._push_item(domain + '-inform', slot, DEF_VAL_DNC)
return False return False
...@@ -668,6 +702,9 @@ class Agenda(object): ...@@ -668,6 +702,9 @@ class Agenda(object):
def _handle_select(self, domain, intent, slot_vals, goal: Goal): def _handle_select(self, domain, intent, slot_vals, goal: Goal):
g_reqt, g_info, g_fail_info, g_book, g_fail_book = self._get_goal_infos(domain, goal) g_reqt, g_info, g_fail_info, g_book, g_fail_book = self._get_goal_infos(domain, goal)
# delete Choice # delete Choice
for slot, val in slot_vals:
if slot == 'choice':
self._push_item(domain + '-inform', "choice", "any")
slot_vals = [[slot, val] for [slot, val] in slot_vals if slot != 'choice'] slot_vals = [[slot, val] for [slot, val] in slot_vals if slot != 'choice']
if slot_vals: if slot_vals:
...@@ -814,3 +851,83 @@ class Agenda(object): ...@@ -814,3 +851,83 @@ class Agenda(object):
text += '<stack btm>\n' text += '<stack btm>\n'
text += '-----agenda-----\n' text += '-----agenda-----\n'
return text return text
# Manual smoke test: drives the agenda user policy against the rule-based
# system policy for a few turns and prints the acts and user utterances.
if __name__ == '__main__':
import numpy as np
import torch
# Seed every RNG in use so repeated runs start from the same state.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
user_policy = UserPolicyAgendaMultiWoz()
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
sys_policy = RuleBasedMultiwozBot()
from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
user_nlg = TemplateNLG(is_user=True, mode='manual')
sys_nlg = TemplateNLG(is_user=False, mode='manual')
from convlab2.util.multiwoz.state import default_state
user_policy.init_session()
sys_policy.init_session()
print(user_policy.goal)
print(user_policy.agenda)
# Opening user turn: no system act has been produced yet, so pass an empty list.
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
state = default_state()
state['user_action'] = user_act
sys_act = sys_policy.predict(state)
# Append an extra hand-written request to the system policy's output.
sys_act.append(["Request", "Restaurant", "Price", "?"])
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
# From here on the system policy is still called, but its result is
# immediately overwritten with a hand-written system act.
sys_act = sys_policy.predict(state)
sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Reqmore", "General", "none", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Inform", "Hotel", "Parking", "none"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Request", "Booking", "people", "?"]]
print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
sys_act = sys_policy.predict(state)
sys_act = [["Inform", "Hotel", "Post", "233"], ["Book", "Booking", "none", "none"]]
print(sys_act)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment