Skip to content
Snippets Groups Projects
Commit 50424c34 authored by zqwerty's avatar zqwerty Committed by zhuqi
Browse files

add template for interent-no, parking-no in templatenlg

parent b190c50a
No related branches found
No related tags found
No related merge requests found
......@@ -207,6 +207,11 @@ class TemplateNLG(NLG):
"I would prefer something that is {} .".format(value),
"it needs to be {} .".format(value)
])
elif slot in ['Internet', 'Parking'] and value == 'no':
sentence = random.choice([
"It does n't need to have {} .".format(slot.lower()),
"I do n't need free {} .".format(slot.lower()),
])
elif dialog_act in template and slot in template[dialog_act]:
sentence = random.choice(template[dialog_act][slot])
sentence = sentence.replace('#{}-{}#'.format(dialog_act.upper(), slot.upper()), str(value))
......@@ -248,7 +253,7 @@ class TemplateNLG(NLG):
def example():
# dialog act
dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Name', 'fds'], ['welcome', 'general', 'none', 'none']]
dialog_acts = [['Inform', 'Hotel', 'Area', 'east'],['Inform', 'Hotel', 'Internet', 'no'], ['welcome', 'general', 'none', 'none']]
print(dialog_acts)
# system model for manual, auto, auto_manual
......
......@@ -956,9 +956,9 @@ if __name__ == '__main__':
# if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
# break
# # pprint(goal)
user_goal = {'domain_ordering': ('restaurant', 'taxi'),
user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'),
'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
'info': {'internet': 'yes',
'info': {'internet': 'no',
'parking': 'no',
'pricerange': 'moderate',
'area': 'centre'}},
......
......@@ -81,6 +81,7 @@ class Analyzer:
if not os.path.exists(output_dir):
os.mkdir(output_dir)
f = open(os.path.join(output_dir, 'res.txt'), 'w')
flog = open(os.path.join(output_dir, 'log.txt'), 'w')
for j in tqdm(range(total_dialog), desc="dialogue"):
sys_response = '' if self.user_agent.nlu else []
......@@ -106,13 +107,13 @@ class Analyzer:
for i in range(40):
sys_response, user_response, session_over, reward = sess.next_turn(
sys_response)
# print('user in', sess.user_agent.get_in_da(),file=f)
# print('user out', sess.user_agent.get_out_da(),file=f)
print('user in', sess.user_agent.get_in_da(),file=flog)
print('user out', sess.user_agent.get_out_da(),file=flog)
#
# print('sys in', sess.sys_agent.get_in_da(),file=f)
# print('sys out', sess.sys_agent.get_out_da(),file=f)
# print('user:', user_response,file=f)
# print('sys:', sys_response,file=f)
print('sys in', sess.sys_agent.get_in_da(),file=flog)
print('sys out', sess.sys_agent.get_out_da(),file=flog)
print('user:', user_response,file=flog)
print('sys:', sys_response,file=flog)
step += 2
......
......@@ -67,7 +67,7 @@ def test_end2end():
analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')
set_seed(20200202)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=1000)
analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='end2end', total_dialog=100)
if __name__ == '__main__':
test_end2end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment