Skip to content
Snippets Groups Projects
Unverified Commit 67f9dddf authored by aaa123git's avatar aaa123git Committed by GitHub
Browse files

update dbquery and session (#99)


* update dbquery: ? matches all; fix bug in init_session

* update multiwoz_eval, check Ref of booked

* filter domain in final_goal_analyze

Co-authored-by: default avatarnewRuntieException <wdz15@mails.tsinghua.edu.cn>
parent b82732ea
Branches
No related tags found
No related merge requests found
...@@ -144,6 +144,8 @@ class BiSession(Session): ...@@ -144,6 +144,8 @@ class BiSession(Session):
self.user_agent.init_session(**kwargs) self.user_agent.init_session(**kwargs)
if self.evaluator: if self.evaluator:
self.evaluator.add_goal(self.user_agent.policy.get_goal()) self.evaluator.add_goal(self.user_agent.policy.get_goal())
self.dialog_history = []
self.__turn_indicator = 0
class DealornotSession(Session): class DealornotSession(Session):
...@@ -198,3 +200,5 @@ class DealornotSession(Session): ...@@ -198,3 +200,5 @@ class DealornotSession(Session):
self.__turn_indicator = random.choice([0, 1]) self.__turn_indicator = random.choice([0, 1])
self.alice.init_session() self.alice.init_session()
self.bob.init_session() self.bob.init_session()
self.current_agent = None
self.dialog_history = []
...@@ -33,6 +33,7 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na ...@@ -33,6 +33,7 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na
time_re = re.compile(r'^(([01]\d|2[0-4]):([0-5]\d)|24:00)$') time_re = re.compile(r'^(([01]\d|2[0-4]):([0-5]\d)|24:00)$')
NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"] NUL_VALUE = ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]
class MultiWozEvaluator(Evaluator): class MultiWozEvaluator(Evaluator):
def __init__(self): def __init__(self):
self.sys_da_array = [] self.sys_da_array = []
...@@ -101,10 +102,12 @@ class MultiWozEvaluator(Evaluator): ...@@ -101,10 +102,12 @@ class MultiWozEvaluator(Evaluator):
if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']:
if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \ if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \
len(self.dbs[self.cur_domain]) > int(value): len(self.dbs[self.cur_domain]) > int(value):
self.booked[self.cur_domain] = self.dbs[self.cur_domain][int(value)] self.booked[self.cur_domain] = self.dbs[self.cur_domain][int(value)].copy()
self.booked[self.cur_domain]['Ref'] = value
elif da == 'train-offerbooked-ref' or da == 'train-inform-ref': elif da == 'train-offerbooked-ref' or da == 'train-inform-ref':
if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value): if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value):
self.booked['train'] = self.dbs['train'][int(value)] self.booked['train'] = self.dbs['train'][int(value)].copy()
self.booked['train']['Ref'] = value
elif da == 'taxi-inform-car': elif da == 'taxi-inform-car':
if not self.booked['taxi']: if not self.booked['taxi']:
self.booked['taxi'] = 'booked' self.booked['taxi'] = 'booked'
...@@ -329,7 +332,6 @@ class MultiWozEvaluator(Evaluator): ...@@ -329,7 +332,6 @@ class MultiWozEvaluator(Evaluator):
inform = self._inform_F1_goal(goal, self.sys_da_array, [domain]) inform = self._inform_F1_goal(goal, self.sys_da_array, [domain])
return inform return inform
def domain_success(self, domain, ref2goal=True): def domain_success(self, domain, ref2goal=True):
""" """
judge if the domain (subtask) is successfully completed judge if the domain (subtask) is successfully completed
...@@ -376,18 +378,26 @@ class MultiWozEvaluator(Evaluator): ...@@ -376,18 +378,26 @@ class MultiWozEvaluator(Evaluator):
for domain, dom_goal_dict in self.goal.items(): for domain, dom_goal_dict in self.goal.items():
constraints = [] constraints = []
if 'reqt' in dom_goal_dict: if 'reqt' in dom_goal_dict:
constraints += list(dom_goal_dict['reqt'].items()) reqt_constraints = list(dom_goal_dict['reqt'].items())
constraints += reqt_constraints
else:
reqt_constraints = []
if 'info' in dom_goal_dict: if 'info' in dom_goal_dict:
constraints += list(dom_goal_dict['info'].items()) info_constraints = list(dom_goal_dict['info'].items())
query_result = self.database.query(domain, constraints) constraints += info_constraints
else:
info_constraints = []
query_result = self.database.query(domain, info_constraints, soft_contraints=reqt_constraints)
if not query_result: if not query_result:
mismatch += 1 mismatch += 1
else: continue
booked = self.booked[domain] booked = self.booked[domain]
if booked is None: if not self.goal[domain].get('book'):
match += 1 match += 1
elif isinstance(booked, dict): elif isinstance(booked, dict):
if all(booked.get(k, object()) == v for k, v in constraints): ref = booked['Ref']
if any(found['Ref'] == ref for found in query_result):
match += 1 match += 1
else: else:
mismatch += 1 mismatch += 1
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import json import json
import os import os
import random import random
from fuzzywuzzy import fuzz
from itertools import chain
from copy import deepcopy from copy import deepcopy
...@@ -18,7 +20,7 @@ class Database(object): ...@@ -18,7 +20,7 @@ class Database(object):
'data/multiwoz/db/{}_db.json'.format(domain))) as f: 'data/multiwoz/db/{}_db.json'.format(domain))) as f:
self.dbs[domain] = json.load(f) self.dbs[domain] = json.load(f)
def query(self, domain, constraints, ignore_open=False): def query(self, domain, constraints, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
"""Returns the list of entities for a given domain """Returns the list of entities for a given domain
based on the annotation of the belief state""" based on the annotation of the belief state"""
# query the db # query the db
...@@ -43,7 +45,9 @@ class Database(object): ...@@ -43,7 +45,9 @@ class Database(object):
found = [] found = []
for i, record in enumerate(self.dbs[domain]): for i, record in enumerate(self.dbs[domain]):
for key, val in constraints: constraints_iterator = zip(constraints, [False] * len(constraints))
soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints))
for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator):
if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care": if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care":
pass pass
else: else:
...@@ -64,9 +68,16 @@ class Database(object): ...@@ -64,9 +68,16 @@ class Database(object):
# elif ignore_open and key in ['destination', 'departure', 'name']: # elif ignore_open and key in ['destination', 'departure', 'name']:
elif ignore_open and key in ['destination', 'departure']: elif ignore_open and key in ['destination', 'departure']:
continue continue
elif record[key].strip() == '?':
# '?' matches any constraint
continue
else: else:
if not fuzzy_match:
if val.strip().lower() != record[key].strip().lower(): if val.strip().lower() != record[key].strip().lower():
break break
else:
if fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio:
break
except: except:
continue continue
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment