Skip to content
Snippets Groups Projects
Commit 5779d158 authored by newRuntieException's avatar newRuntieException Committed by zhuqi
Browse files

update Evaluator: check whether final goal satisfies constraints

parent 50424c34
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,8 @@ class MultiWozEvaluator(Evaluator):
self.goal = {}
self.cur_domain = ''
self.booked = {}
self.dbs = Database().dbs
self.database = Database()
self.dbs = self.database.dbs
def _init_dict(self):
dic = {}
......@@ -294,10 +295,12 @@ class MultiWozEvaluator(Evaluator):
"""
book_sess = self.book_rate(ref2goal)
inform_sess = self.inform_F1(ref2goal)
goal_sess = self.final_goal_analyze()
# book rate == 1 & inform recall == 1
if (book_sess == 1 and inform_sess[1] == 1) \
if ((book_sess == 1 and inform_sess[1] == 1) \
or (book_sess == 1 and inform_sess[1] is None) \
or (book_sess is None and inform_sess[1] == 1):
or (book_sess is None and inform_sess[1] == 1)) \
and goal_sess == 1:
return 1
else:
return 0
......@@ -366,3 +369,29 @@ class MultiWozEvaluator(Evaluator):
return 1
else:
return 0
def _final_goal_analyze(self):
"""whether the final goal satisfies constraints"""
match = mismatch = 0
for domain, dom_goal_dict in self.goal.items():
constraints = []
if 'reqt' in dom_goal_dict:
constraints += list(dom_goal_dict['reqt'].items())
if 'info' in dom_goal_dict:
constraints += list(dom_goal_dict['info'].items())
query_result = self.database.query(domain, constraints)
if not query_result:
mismatch += 1
else:
if self.booked[domain] is not None and not self.database.query(domain, list(self.booked[domain].items())):
mismatch += 1
else:
match += 1
return match, mismatch
def final_goal_analyze(self):
match, mismatch = self._final_goal_analyze()
if match == mismatch == 0:
return 1
else:
return match / (match + mismatch)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment