def generate(total_num=1000, seed=42, output_file='goal.json',
             goal_generator=None):
    """Sample user goals and dump them to a JSON file.

    Goals are drawn one at a time from the goal generator. The ``police``
    domain, when present, is stripped from the goal before its natural
    language description is built. Goals whose description cannot be built
    are discarded and another goal is sampled.

    Args:
        total_num: Number of goals to write.
        seed: Seed applied to both ``random`` and ``numpy.random`` before
            sampling starts.
        output_file: Path of the JSON file to write (overwritten).
        goal_generator: Optional object exposing ``get_user_goal()`` and
            ``build_message(goal)``. When ``None`` (the default) a fresh
            ``GoalGenerator()`` is created, as before.

    Returns:
        None. The goals are written to ``output_file`` as a JSON list with
        the keys ``goals``, ``ori_goals``, ``description``, ``timestamp``
        and ``ID``; the average number of domains per goal is printed.
    """
    random.seed(seed)
    np.random.seed(seed)
    if goal_generator is None:
        goal_generator = GoalGenerator()

    goals = []
    domain_counts = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()

        # Remove the police domain from both the ordering and the goal body.
        if 'police' in goal['domain_ordering']:
            goal['domain_ordering'] = tuple(
                dom for dom in goal['domain_ordering'] if dom != 'police'
            )
            del goal['police']
            # NOTE(review): a goal whose only domain was police is kept with
            # an empty domain_ordering -- confirm this is intended.

        try:
            # build_message returns (flat_messages, messages_by_domain);
            # the per-domain form is stored as the description.
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Best effort: resample when a description cannot be built.
            # Unlike a bare ``except``, this lets KeyboardInterrupt and
            # SystemExit propagate, so the loop stays interruptible.
            continue

        domain_counts.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals)
        })

    print('avg domains:', np.mean(domain_counts))  # avg domains: 1.827
    # Context manager guarantees the file is flushed and closed.
    with open(output_file, 'w') as f:
        json.dump(goals, f, indent=4)


if __name__ == '__main__':
    generate(output_file='goal20200623.json')