Skip to content
Snippets Groups Projects
Commit bdc9dba7 authored by zqwerty's avatar zqwerty Committed by zhuqi
Browse files

modify build message function for goal generation

parent 6e238521
No related branches found
No related tags found
No related merge requests found
...@@ -254,7 +254,7 @@ class MultiWozEvalWorld(MTurkTaskWorld): ...@@ -254,7 +254,7 @@ class MultiWozEvalWorld(MTurkTaskWorld):
except Exception as e: except Exception as e:
print(e) print(e)
num_goal_trials += 1 num_goal_trials += 1
self.goal_message = goal_generator.build_message(self.goal) self.goal_message, _ = goal_generator.build_message(self.goal)
self.goal_text = '<ul>' self.goal_text = '<ul>'
for m in self.goal_message: for m in self.goal_message:
self.goal_text += '<li>' + m + '</li>' self.goal_text += '<li>' + m + '</li>'
......
"""
generate user goal for collecting new multiwoz data
"""
from convlab2.task.multiwoz.goal_generator import GoalGenerator
import random
import numpy as np
import json
import datetime
from pprint import pprint
def generate(total_num=1000, seed=42, output_file='goal.json'):
random.seed(seed)
np.random.seed(seed)
goal_generator = GoalGenerator()
goals = []
avg_domains = []
while len(goals) < total_num:
goal = goal_generator.get_user_goal()
# pprint(goal)
if 'police' in goal['domain_ordering']:
no_police = list(goal['domain_ordering'])
no_police.remove('police')
goal['domain_ordering'] = tuple(no_police)
del goal['police']
try:
message = goal_generator.build_message(goal)[1]
except:
continue
# print(message)
avg_domains.append(len(goal['domain_ordering']))
goals.append({
"goals": [],
"ori_goals": goal,
"description": message,
"timestamp": str(datetime.datetime.now()),
"ID": len(goals)
})
print('avg domains:', np.mean(avg_domains)) # avg domains: 1.827
json.dump(goals, open(output_file, 'w'), indent=4)
if __name__ == '__main__':
generate(output_file='goal20200623.json')
...@@ -494,6 +494,8 @@ class GoalGenerator: ...@@ -494,6 +494,8 @@ class GoalGenerator:
def build_message(self, user_goal, boldify=null_boldify): def build_message(self, user_goal, boldify=null_boldify):
message = [] message = []
message_by_domain = []
mess_ptr4domain = 0
state = deepcopy(user_goal) state = deepcopy(user_goal)
for dom in user_goal['domain_ordering']: for dom in user_goal['domain_ordering']:
...@@ -641,11 +643,15 @@ class GoalGenerator: ...@@ -641,11 +643,15 @@ class GoalGenerator:
message.append(templates[dom]['fail_book ' + adjusted_slot].format( message.append(templates[dom]['fail_book ' + adjusted_slot].format(
self.boldify(user_goal[dom]['book'][adjusted_slot]))) self.boldify(user_goal[dom]['book'][adjusted_slot])))
dm = message[mess_ptr4domain:]
mess_ptr4domain = len(message)
message_by_domain.append(' '.join(dm))
if boldify == do_boldify: if boldify == do_boldify:
for i, m in enumerate(message): for i, m in enumerate(message):
message[i] = message[i].replace('wifi', "<b>wifi</b>") message[i] = message[i].replace('wifi', "<b>wifi</b>")
message[i] = message[i].replace('internet', "<b>internet</b>") message[i] = message[i].replace('internet', "<b>internet</b>")
message[i] = message[i].replace('parking', "<b>parking</b>") message[i] = message[i].replace('parking', "<b>parking</b>")
return message return message, message_by_domain
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment