NOTE(review): the line breaks and indentation of this patch were reconstructed
from the hunk line counts and from Python syntax. The positions of blank context
lines and the indent of continuation lines are best guesses -- verify against
the upstream blobs named in the "index" lines before applying.

diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py
index 6cd066e64b9254813fc5684471004f1a534b100a..86c0ed7b1a9aa2667661c48cb7fc26a112fd364c 100755
--- a/convlab2/evaluator/multiwoz_eval.py
+++ b/convlab2/evaluator/multiwoz_eval.py
@@ -48,6 +48,7 @@ class MultiWozEvaluator(Evaluator):
         self.dbs = self.database.dbs
         self.check_book_constraints = check_book_constraints
         self.check_domain_success = check_domain_success
+        self.complete = 0
         self.success = 0
         self.success_strict = 0
         self.successful_domains = []
@@ -340,6 +341,23 @@ class MultiWozEvaluator(Evaluator):
         else:
             return score
 
+    def check_booking_done(self, ref2goal=True):
+        if ref2goal:
+            goal = self._expand(self.goal)
+        else:
+            goal = self._init_dict()
+            for domain in belief_domains:
+                if domain in self.goal and 'book' in self.goal[domain]:
+                    goal[domain]['book'] = self.goal[domain]['book']
+
+        # check for every domain where booking is required whether a booking has been made
+        for domain in goal:
+            if goal[domain]['book']:
+                if not self.booked[domain]:
+                    return False
+
+        return True
+
     def inform_F1(self, ref2goal=True, aggregate=True):
         if ref2goal:
             goal = self._expand(self.goal)
@@ -370,6 +388,7 @@ class MultiWozEvaluator(Evaluator):
         """
         judge if all the domains are successfully completed
         """
+        booking_done = self.check_booking_done(ref2goal)
         book_sess = self.book_rate(ref2goal)
         book_constraint_sess = self.book_rate_constrains(ref2goal)
         inform_sess = self.inform_F1(ref2goal)
@@ -379,10 +398,12 @@ class MultiWozEvaluator(Evaluator):
                 or (book_sess == 1 and inform_sess[1] is None)
                 or (book_sess is None and inform_sess[1] == 1)) \
                 and goal_sess == 1:
+            self.complete = 1
             self.success = 1
             self.success_strict = 1 if (book_constraint_sess == 1 or book_constraint_sess is None) else 0
             return self.success if not self.check_book_constraints else self.success_strict
         else:
+            self.complete = 1 if booking_done and (inform_sess[1] == 1 or inform_sess[1] is None) else 0
             self.success = 0
             self.success_strict = 0
             return 0
diff --git a/convlab2/policy/evaluate_distributed.py b/convlab2/policy/evaluate_distributed.py
index 8f4c8ae421b70c17fe10388fa61f9e5e34564498..94f8649811723c678b45a20f459fd333c1b43222 100644
--- a/convlab2/policy/evaluate_distributed.py
+++ b/convlab2/policy/evaluate_distributed.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import random
+import torch
 import sys
 from pprint import pprint
 
@@ -25,6 +26,8 @@ def sampler(pid, queue, evt, sess, seed_range):
     buff = Memory_evaluator()
 
     for seed in seed_range:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
         random.seed(seed)
         np.random.seed(seed)
         sess.init_session()
@@ -69,9 +72,9 @@ def sampler(pid, queue, evt, sess, seed_range):
 
             if session_over is True:
                 success = sess.evaluator.task_success()
+                complete = sess.evaluator.complete
                 success = sess.evaluator.success
                 success_strict = sess.evaluator.success_strict
-                complete = sess.user_agent.policy.policy.goal.task_complete()
                 break
 
         total_return_success = 80 if success_strict else -40
diff --git a/convlab2/util/custom_util.py b/convlab2/util/custom_util.py
index ef57956638fcedde8b40a0d4ec57911ddd9cd00d..21083d9964fb329eae692a6d8638b59d5a5408ae 100644
--- a/convlab2/util/custom_util.py
+++ b/convlab2/util/custom_util.py
@@ -305,10 +305,12 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
 
             if session_over is True:
                 task_succ = sess.evaluator.task_success()
+                complete = sess.evaluator.complete
                 task_succ = sess.evaluator.success
                 task_succ_strict = sess.evaluator.success_strict
                 break
         else:
+            complete = 0
             task_succ = 0
             task_succ_strict = 0
 
@@ -318,8 +320,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
             else:
                 task_success[key].append(task_succ_strict)
 
-        task_success['All_user_sim'].append(
-            int(sess.user_agent.policy.policy.goal.task_complete()))
+        task_success['All_user_sim'].append(complete)
         task_success['All_evaluator'].append(task_succ)
         task_success['All_evaluator_strict'].append(task_succ_strict)
         total_return = 80 if task_succ_strict else -40