Skip to content
Snippets Groups Projects
Unverified Commit 61719a28 authored by zhuqi's avatar zhuqi Committed by GitHub
Browse files

Fix simulator (#83)

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* remove police and hospital domain in goal generator

* update multiwoz evaluator: adding 'internet/parking-none, 24:** to valid value
parent d0461611
Branches
No related tags found
No related merge requests found
...@@ -30,7 +30,7 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na ...@@ -30,7 +30,7 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na
'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'}, 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}} 'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
time_re = re.compile(r'^(([01]\d|2[0-3]):([0-5]\d)|24:00)$') time_re = re.compile(r'^(([01]\d|2[0-4]):([0-5]\d)|24:00)$')
NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"] NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]
class MultiWozEvaluator(Evaluator): class MultiWozEvaluator(Evaluator):
...@@ -229,7 +229,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -229,7 +229,7 @@ class MultiWozEvaluator(Evaluator):
elif key == "duration": elif key == "duration":
return 'minute' in value return 'minute' in value
elif key == "internet" or key == "parking": elif key == "internet" or key == "parking":
return value in ["yes", "no"] return value in ["yes", "no", "none"]
elif key == "phone": elif key == "phone":
return re.match(r'^\d{11}$', value) or domain == "restaurant" return re.match(r'^\d{11}$', value) or domain == "restaurant"
elif key == "price": elif key == "price":
......
...@@ -135,7 +135,7 @@ class GoalGenerator: ...@@ -135,7 +135,7 @@ class GoalGenerator:
"""User goal generator.""" """User goal generator."""
def __init__(self, def __init__(self,
goal_model_path=os.path.join(get_root_path(), 'data/multiwoz/goal/new_goal_model.pkl'), goal_model_path=os.path.join(get_root_path(), 'data/multiwoz/goal/new_goal_model_no_police_hospital.pkl'),
corpus_path=None, corpus_path=None,
boldify=False, boldify=False,
sample_info_from_trainset=True, sample_info_from_trainset=True,
...@@ -163,13 +163,13 @@ class GoalGenerator: ...@@ -163,13 +163,13 @@ class GoalGenerator:
self._build_goal_model() self._build_goal_model()
print('Building goal model is done') print('Building goal model is done')
# remove some slot # remove some slot (now no police and hospital domains)
del self.ind_slot_dist['police']['reqt']['postcode'] # del self.ind_slot_dist['police']['reqt']['postcode']
del self.ind_slot_value_dist['police']['reqt']['postcode'] # del self.ind_slot_value_dist['police']['reqt']['postcode']
del self.ind_slot_dist['hospital']['reqt']['postcode'] # del self.ind_slot_dist['hospital']['reqt']['postcode']
del self.ind_slot_value_dist['hospital']['reqt']['postcode'] # del self.ind_slot_value_dist['hospital']['reqt']['postcode']
del self.ind_slot_dist['hospital']['reqt']['address'] # del self.ind_slot_dist['hospital']['reqt']['address']
del self.ind_slot_value_dist['hospital']['reqt']['address'] # del self.ind_slot_value_dist['hospital']['reqt']['address']
# print(self.slots_combination_dist['police']) # print(self.slots_combination_dist['police'])
# print(self.slots_combination_dist['hospital']) # print(self.slots_combination_dist['hospital'])
...@@ -187,6 +187,8 @@ class GoalGenerator: ...@@ -187,6 +187,8 @@ class GoalGenerator:
domain_orderings = [] domain_orderings = []
for d in dialogs: for d in dialogs:
d_domains = _get_dialog_domains(dialogs[d]) d_domains = _get_dialog_domains(dialogs[d])
if 'police' in d_domains or 'hospital' in d_domains:
continue
first_index = [] first_index = []
for domain in d_domains: for domain in d_domains:
message = [dialogs[d]['goal']['message']] if type(dialogs[d]['goal']['message']) == str else \ message = [dialogs[d]['goal']['message']] if type(dialogs[d]['goal']['message']) == str else \
...@@ -209,6 +211,9 @@ class GoalGenerator: ...@@ -209,6 +211,9 @@ class GoalGenerator:
self.slots_num_dist = {domain: {} for domain in domains} self.slots_num_dist = {domain: {} for domain in domains}
for d in dialogs: for d in dialogs:
d_domains = _get_dialog_domains(dialogs[d])
if 'police' in d_domains or 'hospital' in d_domains:
continue
for domain in domains: for domain in domains:
if dialogs[d]['goal'][domain] != {}: if dialogs[d]['goal'][domain] != {}:
domain_cnt[domain] += 1 domain_cnt[domain] += 1
...@@ -735,4 +740,5 @@ class GoalGenerator: ...@@ -735,4 +740,5 @@ class GoalGenerator:
if __name__ == '__main__': if __name__ == '__main__':
goal_generator = GoalGenerator(corpus_path=os.path.join(get_root_path(), 'data/multiwoz/train.json'), sample_reqt_from_trainset=True) goal_generator = GoalGenerator(corpus_path=os.path.join(get_root_path(), 'data/multiwoz/train.json'), sample_reqt_from_trainset=True)
# goal_generator._build_goal_model()
pprint(goal_generator.get_user_goal()) pprint(goal_generator.get_user_goal())
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment