Skip to content
Snippets Groups Projects
Commit 9d85c597 authored by truthless11, committed by zhuqi
Browse files

fix sequicity

parent acfef62a
No related branches found
No related tags found
No related merge requests found
...@@ -120,7 +120,7 @@ Performance (the first row is the default config for each module. Empty entries ...@@ -120,7 +120,7 @@ Performance (the first row is the default config for each module. Empty entries
| BERTNLU | RuleDST | **HDSA** | None | 35.6 | 27.5 | 5.4 | 47.8/57.2/48.8 | 13.0/31.5 | | BERTNLU | RuleDST | **HDSA** | None | 35.6 | 27.5 | 5.4 | 47.8/57.2/48.8 | 13.0/31.5 |
| BERTNLU | RuleDST | **LaRL** | None | 40.6 | 34.0 | 45.6 | 47.8/54.1/47.6 | 15.0/28.6 | | BERTNLU | RuleDST | **LaRL** | None | 40.6 | 34.0 | 45.6 | 47.8/54.1/47.6 | 15.0/28.6 |
| None | **SUMBT** | **LaRL** | None | 39.4| 33.1| 39.5 | 48.5/56.0/48.8| 15.5/28.7| | None | **SUMBT** | **LaRL** | None | 39.4| 33.1| 39.5 | 48.5/56.0/48.8| 15.5/28.7|
| None | None | **Sequicity*** | None | 13.1 | 10.5 | 5.1 | 41.4/30.8/31.3 | 12.9/38.3 | | None | None | **Sequicity*** | None | 21.7 | 14.0 | 4.9 | 36.3/35.1/32.0 | 18.2/35.2 |
| None | None | **DAMD*** | None | 38.5 | 33.6 | 50.9 | 62.1/60.7/57.4 | 10.4/28.2 | | None | None | **DAMD*** | None | 38.5 | 33.6 | 50.9 | 62.1/60.7/57.4 | 10.4/28.2 |
*: end-to-end models used as sys_agent directly. *: end-to-end models used as sys_agent directly.
......
import os
import argparse import argparse
import csv import csv
import functools import functools
...@@ -124,6 +125,8 @@ class GenericEvaluator: ...@@ -124,6 +125,8 @@ class GenericEvaluator:
self.metric_dict = {} self.metric_dict = {}
self.entity_dict = {} self.entity_dict = {}
filename = result_path.split('/')[-1] filename = result_path.split('/')[-1]
if not os.path.exists('./sheets/'):
os.makedirs('./sheets/')
dump_dir = './sheets/' + filename.replace('.csv','.report.txt') dump_dir = './sheets/' + filename.replace('.csv','.report.txt')
self.dump_file = open(dump_dir,'w') self.dump_file = open(dump_dir,'w')
......
...@@ -61,7 +61,7 @@ In terms of `success F1`, Sequicity by order shows the (F1, Precision, Recall) ...@@ -61,7 +61,7 @@ In terms of `success F1`, Sequicity by order shows the (F1, Precision, Recall)
| BLEU | Match | Success (F1, Prec., Rec.) | | BLEU | Match | Success (F1, Prec., Rec.) |
| - | - | - | | - | - | - |
| 0.0691 | 0.4994 |(0.5059, 0.5925, 0.4414)| | 0.1506 | 0.6057 |(0.5662, 0.6701, 0.4902)|
## Reference ## Reference
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
"multiwoz/data/hotel_db.json", "multiwoz/data/hotel_db.json",
"multiwoz/data/restaurant_db.json", "multiwoz/data/restaurant_db.json",
"multiwoz/data/hospital_db.json", "multiwoz/data/hospital_db.json",
"multiwoz/data/train_db.json" "multiwoz/data/train_db.json",
"multiwoz/data/police_db.json"
], ],
"glove_path": "multiwoz/data/glove.6B.50d.txt", "glove_path": "multiwoz/data/glove.6B.50d.txt",
"batch_size": 32, "batch_size": 32,
......
...@@ -96,7 +96,7 @@ class Sequicity(Agent): ...@@ -96,7 +96,7 @@ class Sequicity(Agent):
in constraint_request else constraint_request in constraint_request else constraint_request
for j, ent in enumerate(constraints): for j, ent in enumerate(constraints):
constraints[j] = ent.replace('_', ' ') constraints[j] = ent.replace('_', ' ')
degree = self.m.reader.db_search(constraints) degree = self.m.reader.db_search(constraints[1:], constraints[0] if constraints else 'restaurant')
degree_input_list = self.m.reader._degree_vec_mapping(len(degree)) degree_input_list = self.m.reader._degree_vec_mapping(len(degree))
degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0))) degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0)))
return degree, degree_input return degree, degree_input
......
...@@ -767,30 +767,25 @@ class MultiWozReader(_ReaderBase): ...@@ -767,30 +767,25 @@ class MultiWozReader(_ReaderBase):
self.result_file = '' self.result_file = ''
def _get_tokenized_data(self, raw_data, db_data, construct_vocab): def _get_tokenized_data(self, raw_data, db_data, construct_vocab):
requestable_keys = ['addr', 'area', 'fee', 'name', 'phone', 'post', 'price', 'type', 'department', 'internet', 'parking', 'stars', 'food', 'arrive', 'day', 'depart', 'dest', 'leave', 'ticket', 'id']
tokenized_data = [] tokenized_data = []
vk_map = self._value_key_map(db_data) vk_map = self._value_key_map(db_data)
for dial_id, dial in enumerate(raw_data): for dial_id, dial in enumerate(raw_data):
tokenized_dial = [] tokenized_dial = []
for turn in dial['dial']: for turn in dial['dial']:
turn_num = turn['turn'] turn_num = turn['turn']
constraint = [] constraint = [turn['domain']]
requested = [] requested = []
for slot_act in turn['usr']['slu']: for slot_act in turn['usr']['slu']:
if slot_act == 'inform': if slot_act == 'inform':
slot_values = turn['usr']['slu'][slot_act] slot_values = turn['usr']['slu'][slot_act]
for v in slot_values: for v in slot_values:
s = v[1] if v[1] not in ['do_nt_care', 'none']:
if s not in ['dont_care', 'none']: constraint.append(v[1])
constraint.append(s)
elif slot_act == 'request': elif slot_act == 'request':
slot_values = turn['usr']['slu'][slot_act] slot_values = turn['usr']['slu'][slot_act]
for v in slot_values: for v in slot_values:
s = v[0] requested.append(v[0])
if s in requestable_keys: degree = len(self.db_search(constraint[1:], constraint[0]))
requested.append(s)
degree = len(self.db_search(constraint))
requested = sorted(requested) requested = sorted(requested)
constraint.append('EOS_Z1') constraint.append('EOS_Z1')
requested.append('EOS_Z2') requested.append('EOS_Z2')
...@@ -837,7 +832,7 @@ class MultiWozReader(_ReaderBase): ...@@ -837,7 +832,7 @@ class MultiWozReader(_ReaderBase):
string = re.sub(r'_+', '_', string) string = re.sub(r'_+', '_', string)
string = re.sub(r'children', 'child_-s', string) string = re.sub(r'children', 'child_-s', string)
return string return string
requestable_dict = {'address':'addr', slot_dict = {'address':'addr',
'area':'area', 'area':'area',
'entrance fee':'fee', 'entrance fee':'fee',
'name':'name', 'name':'name',
...@@ -856,12 +851,15 @@ class MultiWozReader(_ReaderBase): ...@@ -856,12 +851,15 @@ class MultiWozReader(_ReaderBase):
'destination':'dest', 'destination':'dest',
'leaveAt':'leave', 'leaveAt':'leave',
'price':'ticket', 'price':'ticket',
'trainId':'id'} 'trainId':'id',
'time':'time',
'ref':'ref'}
value_key = {} value_key = {}
for db_entry in db_data: for domain in db_data:
for db_entry in db_data[domain]:
for k, v in db_entry.items(): for k, v in db_entry.items():
if k in requestable_dict: if k in slot_dict:
value_key[normal(v)] = requestable_dict[k] value_key[normal(v)] = slot_dict[k]
return value_key return value_key
def _get_encoded_data(self, tokenized_data): def _get_encoded_data(self, tokenized_data):
...@@ -897,7 +895,8 @@ class MultiWozReader(_ReaderBase): ...@@ -897,7 +895,8 @@ class MultiWozReader(_ReaderBase):
return encoded_data return encoded_data
def _get_clean_db(self, raw_db_data): def _get_clean_db(self, raw_db_data):
for entry in raw_db_data: for domain in raw_db_data:
for entry in raw_db_data[domain]:
for k, v in list(entry.items()): for k, v in list(entry.items()):
if not isinstance(v, str) or v == '?': if not isinstance(v, str) or v == '?':
entry.pop(k) entry.pop(k)
...@@ -921,13 +920,14 @@ class MultiWozReader(_ReaderBase): ...@@ -921,13 +920,14 @@ class MultiWozReader(_ReaderBase):
dev_raw_data = json.loads(f.read().lower()) dev_raw_data = json.loads(f.read().lower())
with open(test_json_path) as f: with open(test_json_path) as f:
test_raw_data = json.loads(f.read().lower()) test_raw_data = json.loads(f.read().lower())
db_data = list() db_data = dict()
for domain_db_json_path in db_json_path: for domain_db_json_path in db_json_path:
with open(domain_db_json_path) as f: with open(domain_db_json_path) as f:
db_data_domain = json.loads(f.read().lower()) db_data_domain = json.loads(f.read().lower())
for i, item in enumerate(db_data_domain): for i, item in enumerate(db_data_domain):
item['ref'] = f'{i:08d}' item['ref'] = f'{i:08d}'
db_data += db_data_domain domain=domain_db_json_path.split('/')[-1][:-8]
db_data[domain] = db_data_domain
self._get_clean_db(db_data) self._get_clean_db(db_data)
self.db = db_data self.db = db_data
...@@ -946,9 +946,14 @@ class MultiWozReader(_ReaderBase): ...@@ -946,9 +946,14 @@ class MultiWozReader(_ReaderBase):
random.shuffle(self.dev) random.shuffle(self.dev)
random.shuffle(self.test) random.shuffle(self.test)
def db_search(self, constraints): def db_search(self, constraints, domain):
if domain == 'taxi':
match_results = [{'phone':'0123456789','car':'black toyota'}]
return match_results
elif domain not in self.db:
return []
match_results = [] match_results = []
for entry in self.db: for entry in self.db[domain]:
entry_values = ' '.join(entry.values()) entry_values = ' '.join(entry.values())
match = True match = True
for c in constraints: for c in constraints:
...@@ -993,7 +998,7 @@ class MultiWozReader(_ReaderBase): ...@@ -993,7 +998,7 @@ class MultiWozReader(_ReaderBase):
in constraint_request else constraint_request in constraint_request else constraint_request
for j, ent in enumerate(constraints): for j, ent in enumerate(constraints):
constraints[j] = ent.replace('_', ' ') constraints[j] = ent.replace('_', ' ')
degree = self.db_search(constraints) degree = self.db_search(constraints[1:], constraints[0] if constraints else 'restaurant')
#print('constraints',constraints) #print('constraints',constraints)
#print('degree',degree) #print('degree',degree)
venue = random.sample(degree, 1)[0] if degree else dict() venue = random.sample(degree, 1)[0] if degree else dict()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment