Skip to content
Snippets Groups Projects
Unverified Commit 6a84bb10 authored by zhuqi's avatar zhuqi Committed by GitHub
Browse files

Fix goal generator and dbquery for multiwoz (#32)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* replace attraction type 'mutliple sports' to 'multiple sports', involving only one entity

* add depart and destination constraints for searching db (ignore=False), modify goal generator to draw the values of these two slots from database
parent bdc9dba7
No related branches found
No related tags found
No related merge requests found
......@@ -37,9 +37,10 @@ def generate(total_num=1000, seed=42, output_file='goal.json'):
"timestamp": str(datetime.datetime.now()),
"ID": len(goals)
})
print('avg domains:', np.mean(avg_domains)) # avg domains: 1.827
print('avg domains:', np.mean(avg_domains)) # avg domains: 1.846
json.dump(goals, open(output_file, 'w'), indent=4)
if __name__ == '__main__':
generate(output_file='goal20200623.json')
generate(output_file='goal20200629.json')
\ No newline at end of file
......@@ -35,7 +35,7 @@ templates = {
'intro': 'You are looking for information in Cambridge.',
'restaurant': {
'intro': 'You are looking forward to trying local restaurants.',
'request': 'Once you find a restaurnat, make sure you get {}.',
'request': 'Once you find a restaurant, make sure you get {}.',
'area': 'The restaurant should be in the {}.',
'food': 'The restaurant should serve {} food.',
'name': 'You are looking for a particular restaurant. Its name is called {}.',
......@@ -148,6 +148,7 @@ class GoalGenerator:
self.corpus_path = corpus_path
self.db = Database()
self.boldify = do_boldify if boldify else null_boldify
self.train_database = self.db.query('train',[])
if os.path.exists(self.goal_model_path):
self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist = pickle.load(
open(self.goal_model_path, 'rb'))
......@@ -305,13 +306,18 @@ class GoalGenerator:
else:
domain_goal['info']['leaveAt'] = nomial_sample(cnt_slot_value['info']['leaveAt'])
if domain in ['taxi', 'train'] and 'departure' not in domain_goal['info']:
if domain in ['train']:
random_train = random.choice(self.train_database)
domain_goal['info']['departure'] = random_train['departure']
domain_goal['info']['destination'] = random_train['destination']
if domain in ['taxi'] and 'departure' not in domain_goal['info']:
domain_goal['info']['departure'] = nomial_sample(cnt_slot_value['info']['departure'])
if domain in ['taxi', 'train'] and 'destination' not in domain_goal['info']:
if domain in ['taxi'] and 'destination' not in domain_goal['info']:
domain_goal['info']['destination'] = nomial_sample(cnt_slot_value['info']['destination'])
if domain in ['taxi', 'train'] and \
if domain in ['taxi'] and \
'departure' in domain_goal['info'] and \
'destination' in domain_goal['info'] and \
domain_goal['info']['departure'] == domain_goal['info']['destination']:
......@@ -515,7 +521,8 @@ class GoalGenerator:
info in user_goal['attraction'].keys() and
'area' in user_goal['restaurant'][info] and
'area' in user_goal['attraction'][info] and
user_goal['restaurant'][info]['area'] == user_goal['attraction'][info]['area']):
user_goal['restaurant'][info]['area'] ==
user_goal['attraction'][info]['area']):
return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot]))
else:
restaurant_index = user_goal['domain_ordering'].index('restaurant')
......@@ -539,17 +546,21 @@ class GoalGenerator:
if 'arriveBy' in state[info]:
m.append('The taxi should arrive at the {} from the {} by {}.'.format(self.boldify(places[0]),
self.boldify(places[1]),
self.boldify(state[info]['arriveBy'])))
self.boldify(state[info][
'arriveBy'])))
elif 'leaveAt' in state[info]:
m.append('The taxi should leave from the {} to the {} after {}.'.format(self.boldify(places[0]),
self.boldify(places[1]),
self.boldify(state[info]['leaveAt'])))
self.boldify(
state[info][
'leaveAt'])))
message.append(' '.join(m))
else:
while len(state[info]) > 0:
num_acts = random.randint(1, min(len(state[info]), 3))
slots = random.sample(list(state[info].keys()), num_acts)
sents = [fill_info_template(user_goal, dom, slot, info) for slot in slots if slot not in ['parking', 'internet']]
sents = [fill_info_template(user_goal, dom, slot, info) for slot in slots if
slot not in ['parking', 'internet']]
if 'parking' in slots:
sents.append(templates[dom]['parking ' + state[info]['parking']])
if 'internet' in slots:
......@@ -564,11 +575,14 @@ class GoalGenerator:
if 'fail_info' in user_goal[dom]:
# if 'fail_info' in user_goal[dom]:
adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1],
zip(user_goal[dom]['info'].items(), user_goal[dom]['fail_info'].items())))[0][0][0]
zip(user_goal[dom]['info'].items(), user_goal[dom]['fail_info'].items())))[
0][0][0]
if adjusted_slot in ['internet', 'parking']:
message.append(templates[dom]['fail_info ' + adjusted_slot + ' ' + user_goal[dom]['info'][adjusted_slot]])
message.append(
templates[dom]['fail_info ' + adjusted_slot + ' ' + user_goal[dom]['info'][adjusted_slot]])
else:
message.append(templates[dom]['fail_info ' + adjusted_slot].format(self.boldify(user_goal[dom]['info'][adjusted_slot])))
message.append(templates[dom]['fail_info ' + adjusted_slot].format(
self.boldify(user_goal[dom]['info'][adjusted_slot])))
# reqt
if 'reqt' in state:
......@@ -634,7 +648,8 @@ class GoalGenerator:
# fail_book
if 'fail_book' in user_goal[dom]:
adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1], zip(user_goal[dom]['book'].items(),
user_goal[dom]['fail_book'].items())))[0][0][0]
user_goal[dom]['fail_book'].items())))[0][
0][0]
if adjusted_slot in ['internet', 'parking']:
message.append(
......
......@@ -17,7 +17,7 @@ class Database(object):
'data/multiwoz/db/{}_db.json'.format(domain))) as f:
self.dbs[domain] = json.load(f)
def query(self, domain, constraints, ignore_open=True):
def query(self, domain, constraints, ignore_open=False):
"""Returns the list of entities for a given domain
based on the annotation of the belief state"""
# query the db
......@@ -29,6 +29,9 @@ class Database(object):
return self.dbs['police']
if domain == 'hospital':
return self.dbs['hospital']
for ele in constraints:
if ele[0] == 'area' and ele[1] == 'center':
ele[1] = 'centre'
found = []
for i, record in enumerate(self.dbs[domain]):
......@@ -54,7 +57,7 @@ class Database(object):
elif ignore_open and key in ['destination', 'departure']:
continue
else:
if val.strip() != record[key].strip():
if val.strip().lower() != record[key].strip().lower():
break
except:
continue
......@@ -63,3 +66,6 @@ class Database(object):
found.append(record)
return found
if __name__ == '__main__':
db = Database()
print(db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arriveBy', '11:15']]))
......@@ -1085,7 +1085,7 @@
"phone": "01223576412",
"postcode": "cb19ej",
"pricerange": "?",
"type": "mutliple sports"
"type": "multiple sports"
},
{
"address": "8 market passage",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment