Skip to content
Snippets Groups Projects
Commit dfed1f2d authored by truthless11's avatar truthless11
Browse files

fix sequicity

parent 9d0061d3
No related branches found
No related tags found
No related merge requests found
......@@ -120,7 +120,7 @@ Performance (the first row is the default config for each module. Empty entries
| BERTNLU | RuleDST | **HDSA** | None | 35.6 | 27.5 | 5.4 | 47.8/57.2/48.8 | 13.0/31.5 |
| BERTNLU | RuleDST | **LaRL** | None | 40.6 | 34.0 | 45.6 | 47.8/54.1/47.6 | 15.0/28.6 |
| None | **SUMBT** | **LaRL** | None | 39.4| 33.1| 39.5 | 48.5/56.0/48.8| 15.5/28.7|
| None | None | **Sequicity*** | None | 13.1 | 10.5 | 5.1 | 41.4/30.8/31.3 | 12.9/38.3 |
| None | None | **Sequicity*** | None | 21.7 | 14.0 | 4.9 | 36.3/35.1/32.0 | 18.2/35.2 |
| None | None | **DAMD*** | None | 38.5 | 33.6 | 50.9 | 62.1/60.7/57.4 | 10.4/28.2 |
*: end-to-end models used as sys_agent directly.
......
import os
import argparse
import csv
import functools
......@@ -124,6 +125,8 @@ class GenericEvaluator:
self.metric_dict = {}
self.entity_dict = {}
filename = result_path.split('/')[-1]
if not os.path.exists('./sheets/'):
os.makedirs('./sheets/')
dump_dir = './sheets/' + filename.replace('.csv','.report.txt')
self.dump_file = open(dump_dir,'w')
......
......@@ -61,7 +61,7 @@ In terms of `success F1`, Sequicity by order shows the (F1, Precision, Recall)
| BLEU | Match | Success (F1, Prec., Rec.) |
| - | - | - |
| 0.0691 | 0.4994 |(0.5059, 0.5925, 0.4414)|
| 0.1506 | 0.6057 |(0.5662, 0.6701, 0.4902)|
## Reference
......
......@@ -19,7 +19,8 @@
"multiwoz/data/hotel_db.json",
"multiwoz/data/restaurant_db.json",
"multiwoz/data/hospital_db.json",
"multiwoz/data/train_db.json"
"multiwoz/data/train_db.json",
"multiwoz/data/police_db.json"
],
"glove_path": "multiwoz/data/glove.6B.50d.txt",
"batch_size": 32,
......
......@@ -96,7 +96,7 @@ class Sequicity(Agent):
in constraint_request else constraint_request
for j, ent in enumerate(constraints):
constraints[j] = ent.replace('_', ' ')
degree = self.m.reader.db_search(constraints)
degree = self.m.reader.db_search(constraints[1:], constraints[0] if constraints else 'restaurant')
degree_input_list = self.m.reader._degree_vec_mapping(len(degree))
degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0)))
return degree, degree_input
......
......@@ -767,30 +767,25 @@ class MultiWozReader(_ReaderBase):
self.result_file = ''
def _get_tokenized_data(self, raw_data, db_data, construct_vocab):
requestable_keys = ['addr', 'area', 'fee', 'name', 'phone', 'post', 'price', 'type', 'department', 'internet', 'parking', 'stars', 'food', 'arrive', 'day', 'depart', 'dest', 'leave', 'ticket', 'id']
tokenized_data = []
vk_map = self._value_key_map(db_data)
for dial_id, dial in enumerate(raw_data):
tokenized_dial = []
for turn in dial['dial']:
turn_num = turn['turn']
constraint = []
constraint = [turn['domain']]
requested = []
for slot_act in turn['usr']['slu']:
if slot_act == 'inform':
slot_values = turn['usr']['slu'][slot_act]
for v in slot_values:
s = v[1]
if s not in ['dont_care', 'none']:
constraint.append(s)
if v[1] not in ['do_nt_care', 'none']:
constraint.append(v[1])
elif slot_act == 'request':
slot_values = turn['usr']['slu'][slot_act]
for v in slot_values:
s = v[0]
if s in requestable_keys:
requested.append(s)
degree = len(self.db_search(constraint))
requested.append(v[0])
degree = len(self.db_search(constraint[1:], constraint[0]))
requested = sorted(requested)
constraint.append('EOS_Z1')
requested.append('EOS_Z2')
......@@ -837,7 +832,7 @@ class MultiWozReader(_ReaderBase):
string = re.sub(r'_+', '_', string)
string = re.sub(r'children', 'child_-s', string)
return string
requestable_dict = {'address':'addr',
slot_dict = {'address':'addr',
'area':'area',
'entrance fee':'fee',
'name':'name',
......@@ -856,12 +851,15 @@ class MultiWozReader(_ReaderBase):
'destination':'dest',
'leaveAt':'leave',
'price':'ticket',
'trainId':'id'}
'trainId':'id',
'time':'time',
'ref':'ref'}
value_key = {}
for db_entry in db_data:
for domain in db_data:
for db_entry in db_data[domain]:
for k, v in db_entry.items():
if k in requestable_dict:
value_key[normal(v)] = requestable_dict[k]
if k in slot_dict:
value_key[normal(v)] = slot_dict[k]
return value_key
def _get_encoded_data(self, tokenized_data):
......@@ -897,7 +895,8 @@ class MultiWozReader(_ReaderBase):
return encoded_data
def _get_clean_db(self, raw_db_data):
for entry in raw_db_data:
for domain in raw_db_data:
for entry in raw_db_data[domain]:
for k, v in list(entry.items()):
if not isinstance(v, str) or v == '?':
entry.pop(k)
......@@ -921,13 +920,14 @@ class MultiWozReader(_ReaderBase):
dev_raw_data = json.loads(f.read().lower())
with open(test_json_path) as f:
test_raw_data = json.loads(f.read().lower())
db_data = list()
db_data = dict()
for domain_db_json_path in db_json_path:
with open(domain_db_json_path) as f:
db_data_domain = json.loads(f.read().lower())
for i, item in enumerate(db_data_domain):
item['ref'] = f'{i:08d}'
db_data += db_data_domain
domain=domain_db_json_path.split('/')[-1][:-8]
db_data[domain] = db_data_domain
self._get_clean_db(db_data)
self.db = db_data
......@@ -946,9 +946,14 @@ class MultiWozReader(_ReaderBase):
random.shuffle(self.dev)
random.shuffle(self.test)
def db_search(self, constraints):
def db_search(self, constraints, domain):
if domain == 'taxi':
match_results = [{'phone':'0123456789','car':'black toyota'}]
return match_results
elif domain not in self.db:
return []
match_results = []
for entry in self.db:
for entry in self.db[domain]:
entry_values = ' '.join(entry.values())
match = True
for c in constraints:
......@@ -993,7 +998,7 @@ class MultiWozReader(_ReaderBase):
in constraint_request else constraint_request
for j, ent in enumerate(constraints):
constraints[j] = ent.replace('_', ' ')
degree = self.db_search(constraints)
degree = self.db_search(constraints[1:], constraints[0] if constraints else 'restaurant')
#print('constraints',constraints)
#print('degree',degree)
venue = random.sample(degree, 1)[0] if degree else dict()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment