diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py index abb3e83be02c0c101d40e4dbf800598ec8429bbc..ffe70de297fcf9c8aa11c47ce1e880fe50db5da9 100755 --- a/convlab2/dialog_agent/session.py +++ b/convlab2/dialog_agent/session.py @@ -144,6 +144,8 @@ class BiSession(Session): self.user_agent.init_session(**kwargs) if self.evaluator: self.evaluator.add_goal(self.user_agent.policy.get_goal()) + self.dialog_history = [] + self.__turn_indicator = 0 class DealornotSession(Session): @@ -198,3 +200,5 @@ class DealornotSession(Session): self.__turn_indicator = random.choice([0, 1]) self.alice.init_session() self.bob.init_session() + self.current_agent = None + self.dialog_history = [] diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py index 65816a5f236718325fcc3c0bf1520bb3c67c0175..e7303c56d204eed52bf9deb3d55a5a258f9a1ff1 100755 --- a/convlab2/evaluator/multiwoz_eval.py +++ b/convlab2/evaluator/multiwoz_eval.py @@ -33,6 +33,7 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na time_re = re.compile(r'^(([01]\d|2[0-4]):([0-5]\d)|24:00)$') NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"] + class MultiWozEvaluator(Evaluator): def __init__(self): self.sys_da_array = [] @@ -101,10 +102,12 @@ class MultiWozEvaluator(Evaluator): if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \ len(self.dbs[self.cur_domain]) > int(value): - self.booked[self.cur_domain] = self.dbs[self.cur_domain][int(value)] + self.booked[self.cur_domain] = self.dbs[self.cur_domain][int(value)].copy() + self.booked[self.cur_domain]['Ref'] = value elif da == 'train-offerbooked-ref' or da == 'train-inform-ref': if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value): - self.booked['train'] = self.dbs['train'][int(value)] + self.booked['train'] = self.dbs['train'][int(value)].copy() + self.booked['train']['Ref'] = value elif da == 'taxi-inform-car': if not self.booked['taxi']: self.booked['taxi'] = 'booked' @@ -182,7 +185,7 @@ class MultiWozEvaluator(Evaluator): for domain in domains: inform_slot[domain] = set() TP, FP, FN = 0, 0, 0 - + inform_not_reqt = set() reqt_not_inform = set() bad_inform = set() @@ -198,7 +201,7 @@ class MultiWozEvaluator(Evaluator): else: bad_inform.add((intent, domain, key)) FP += 1 - + for domain in domains: for k in goal[domain]['reqt']: if k in inform_slot[domain]: @@ -225,7 +228,7 @@ class MultiWozEvaluator(Evaluator): return time_re.match(value) elif key == "day": return value.lower() in ["monday", "tuesday", "wednesday", "thursday", "friday", - "saturday", "sunday"] + "saturday", "sunday"] elif key == "duration": return 'minute' in value elif key == "internet" or key == "parking": @@ -298,9 +301,9 @@ class MultiWozEvaluator(Evaluator): goal_sess = self.final_goal_analyze() # book rate == 1 & inform recall == 1 if ((book_sess == 1 and inform_sess[1] == 1) \ - or (book_sess == 1 and inform_sess[1] is None) \ - or (book_sess is None and inform_sess[1] == 1)) \ - and goal_sess == 1: + or (book_sess == 1 and inform_sess[1] is None) \ + or (book_sess is None and inform_sess[1] == 1)) \ + and goal_sess == 1: return 1 else: return 0 @@ -328,7 +331,6 @@ class MultiWozEvaluator(Evaluator): inform = self._inform_F1_goal(goal, self.sys_da_array, [domain]) return inform - def domain_success(self, domain, ref2goal=True): """ @@ -376,23 +378,31 @@ class MultiWozEvaluator(Evaluator): for domain, dom_goal_dict in self.goal.items(): constraints = [] if 'reqt' in dom_goal_dict: - constraints += list(dom_goal_dict['reqt'].items()) + reqt_constraints = list(dom_goal_dict['reqt'].items()) + constraints += reqt_constraints + else: + reqt_constraints = [] if 'info' in dom_goal_dict: - constraints += list(dom_goal_dict['info'].items()) - query_result = self.database.query(domain, constraints) + info_constraints = list(dom_goal_dict['info'].items()) + constraints += info_constraints + else: + info_constraints = [] + query_result = self.database.query(domain, info_constraints, soft_contraints=reqt_constraints) if not query_result: mismatch += 1 - else: - booked = self.booked[domain] - if booked is None: + continue + + booked = self.booked[domain] + if not self.goal[domain].get('book'): + match += 1 + elif isinstance(booked, dict): + ref = booked['Ref'] + if any(found['Ref'] == ref for found in query_result): match += 1 - elif isinstance(booked, dict): - if all(booked.get(k, object()) == v for k, v in constraints): - match += 1 - else: - mismatch += 1 else: - match += 1 + mismatch += 1 + else: + match += 1 return match, mismatch def final_goal_analyze(self): diff --git a/convlab2/util/multiwoz/dbquery.py b/convlab2/util/multiwoz/dbquery.py index b84bf8a7cd42c3ed001ef8989d291128164bcbe3..80aab6df92701112b533ad24de2728368bcb7b5c 100755 --- a/convlab2/util/multiwoz/dbquery.py +++ b/convlab2/util/multiwoz/dbquery.py @@ -3,6 +3,8 @@ import json import os import random +from fuzzywuzzy import fuzz +from itertools import chain from copy import deepcopy @@ -18,7 +20,7 @@ class Database(object): 'data/multiwoz/db/{}_db.json'.format(domain))) as f: self.dbs[domain] = json.load(f) - def query(self, domain, constraints, ignore_open=False): + def query(self, domain, constraints, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60): """Returns the list of entities for a given domain based on the annotation of the belief state""" # query the db @@ -43,7 +45,9 @@ class Database(object): found = [] for i, record in enumerate(self.dbs[domain]): - for key, val in constraints: + constraints_iterator = zip(constraints, [False] * len(constraints)) + soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints)) + for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator): if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care": pass else: @@ -64,9 +68,16 @@ class Database(object): # elif ignore_open and key in ['destination', 'departure', 'name']: elif ignore_open and key in ['destination', 'departure']: continue + elif record[key].strip() == '?': + # '?' matches any constraint + continue else: - if val.strip().lower() != record[key].strip().lower(): - break + if not fuzzy_match: + if val.strip().lower() != record[key].strip().lower(): + break + else: + if fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio: + break except: continue else: