Skip to content
Snippets Groups Projects
Unverified Commit 4787ac63 authored by Christian Geishauser's avatar Christian Geishauser Committed by GitHub
Browse files

Merge pull request #24 from ConvLab/complete_rate_refactor

Complete rate evaluated in multiwoz_evaluator
parents ca649ca4 1cb5d869
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,7 @@ class MultiWozEvaluator(Evaluator):
self.dbs = self.database.dbs
self.check_book_constraints = check_book_constraints
self.check_domain_success = check_domain_success
self.complete = 0
self.success = 0
self.success_strict = 0
self.successful_domains = []
......@@ -340,6 +341,23 @@ class MultiWozEvaluator(Evaluator):
else:
return score
def check_booking_done(self, ref2goal=True):
if ref2goal:
goal = self._expand(self.goal)
else:
goal = self._init_dict()
for domain in belief_domains:
if domain in self.goal and 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book']
# check for every domain where booking is required whether a booking has been made
for domain in goal:
if goal[domain]['book']:
if not self.booked[domain]:
return False
return True
def inform_F1(self, ref2goal=True, aggregate=True):
if ref2goal:
goal = self._expand(self.goal)
......@@ -370,6 +388,7 @@ class MultiWozEvaluator(Evaluator):
"""
judge if all the domains are successfully completed
"""
booking_done = self.check_booking_done(ref2goal)
book_sess = self.book_rate(ref2goal)
book_constraint_sess = self.book_rate_constrains(ref2goal)
inform_sess = self.inform_F1(ref2goal)
......@@ -379,10 +398,12 @@ class MultiWozEvaluator(Evaluator):
or (book_sess == 1 and inform_sess[1] is None)
or (book_sess is None and inform_sess[1] == 1)) \
and goal_sess == 1:
self.complete = 1
self.success = 1
self.success_strict = 1 if (book_constraint_sess == 1 or book_constraint_sess is None) else 0
return self.success if not self.check_book_constraints else self.success_strict
else:
self.complete = 1 if booking_done and (inform_sess[1] == 1 or inform_sess[1] is None) else 0
self.success = 0
self.success_strict = 0
return 0
......
# -*- coding: utf-8 -*-
import random
import torch
import sys
from pprint import pprint
......@@ -25,6 +26,8 @@ def sampler(pid, queue, evt, sess, seed_range):
buff = Memory_evaluator()
for seed in seed_range:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
sess.init_session()
......@@ -69,9 +72,9 @@ def sampler(pid, queue, evt, sess, seed_range):
if session_over is True:
success = sess.evaluator.task_success()
complete = sess.evaluator.complete
success = sess.evaluator.success
success_strict = sess.evaluator.success_strict
complete = sess.user_agent.policy.policy.goal.task_complete()
break
total_return_success = 80 if success_strict else -40
......
......@@ -305,10 +305,12 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
if session_over is True:
task_succ = sess.evaluator.task_success()
complete = sess.evaluator.complete
task_succ = sess.evaluator.success
task_succ_strict = sess.evaluator.success_strict
break
else:
complete = 0
task_succ = 0
task_succ_strict = 0
......@@ -318,8 +320,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
else:
task_success[key].append(task_succ_strict)
task_success['All_user_sim'].append(
int(sess.user_agent.policy.policy.goal.task_complete()))
task_success['All_user_sim'].append(complete)
task_success['All_evaluator'].append(task_succ)
task_success['All_evaluator_strict'].append(task_succ_strict)
total_return = 80 if task_succ_strict else -40
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment