diff --git a/README.md b/README.md index 5c31245f80c959e6c09e91398bed57de0bbd6b58..5433d94db5996b48368df4354bcd795ee9461c24 100755 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ Performance (the first row is the default config for each module. Empty entries | BERTNLU | RuleDST | **HDSA** | None | 35.6 | 27.5 | 5.4 | 47.8/57.2/48.8 | 13.0/31.5 | | BERTNLU | RuleDST | **LaRL** | None | 40.6 | 34.0 | 45.6 | 47.8/54.1/47.6 | 15.0/28.6 | | None | **SUMBT** | **LaRL** | None | 39.4| 33.1| 39.5 | 48.5/56.0/48.8| 15.5/28.7| -| None | None | **Sequicity*** | None | 13.1 | 10.5 | 5.1 | 41.4/30.8/31.3 | 12.9/38.3 | +| None | None | **Sequicity*** | None | 21.7 | 14.0 | 4.9 | 36.3/35.1/32.0 | 18.2/35.2 | | None | None | **DAMD*** | None | 38.5 | 33.6 | 50.9 | 62.1/60.7/57.4 | 10.4/28.2 | *: end-to-end models used as sys_agent directly. diff --git a/convlab2/e2e/sequicity/metric.py b/convlab2/e2e/sequicity/metric.py index 7df1cdc03d43c4ff9b047300046e0a6c1e4d9e14..e52c88507d2015f0e8d87ee29ef8d7cc252b94a1 100755 --- a/convlab2/e2e/sequicity/metric.py +++ b/convlab2/e2e/sequicity/metric.py @@ -1,3 +1,4 @@ +import os import argparse import csv import functools @@ -124,6 +125,8 @@ class GenericEvaluator: self.metric_dict = {} self.entity_dict = {} filename = result_path.split('/')[-1] + if not os.path.exists('./sheets/'): + os.makedirs('./sheets/') dump_dir = './sheets/' + filename.replace('.csv','.report.txt') self.dump_file = open(dump_dir,'w') diff --git a/convlab2/e2e/sequicity/multiwoz/README.md b/convlab2/e2e/sequicity/multiwoz/README.md index 31e6a9923695ead4de3308efae265c94035ea9bb..4204c038ed34d9530b9d96b8d5e2091620e43c04 100755 --- a/convlab2/e2e/sequicity/multiwoz/README.md +++ b/convlab2/e2e/sequicity/multiwoz/README.md @@ -61,7 +61,7 @@ In terms of `success F1`, Sequicity by order shows the (F1, Precision, Recall) | BLEU | Match | Success (F1, Prec., Rec.) | | - | - | - | -| 0.0691 | 0.4994 |(0.5059, 0.5925, 0.4414)| +| 0.1506 | 0.6057 |(0.5662, 0.6701, 0.4902)| ## Reference diff --git a/convlab2/e2e/sequicity/multiwoz/configs/multiwoz.json b/convlab2/e2e/sequicity/multiwoz/configs/multiwoz.json index cb6fb21cda12ecdf37dc0556b9a161508322ec7c..2e82d42f0a487246cfb40bf7b9664c877d13be85 100755 --- a/convlab2/e2e/sequicity/multiwoz/configs/multiwoz.json +++ b/convlab2/e2e/sequicity/multiwoz/configs/multiwoz.json @@ -19,7 +19,8 @@ "multiwoz/data/hotel_db.json", "multiwoz/data/restaurant_db.json", "multiwoz/data/hospital_db.json", - "multiwoz/data/train_db.json" + "multiwoz/data/train_db.json", + "multiwoz/data/police_db.json" ], "glove_path": "multiwoz/data/glove.6B.50d.txt", "batch_size": 32, diff --git a/convlab2/e2e/sequicity/multiwoz/sequicity.py b/convlab2/e2e/sequicity/multiwoz/sequicity.py index d62982f0e8e47a80bbd333cfcb124d248395983b..be27609ea877f92b29bccffe0713f03a0e14f634 100755 --- a/convlab2/e2e/sequicity/multiwoz/sequicity.py +++ b/convlab2/e2e/sequicity/multiwoz/sequicity.py @@ -96,7 +96,7 @@ class Sequicity(Agent): in constraint_request else constraint_request for j, ent in enumerate(constraints): constraints[j] = ent.replace('_', ' ') - degree = self.m.reader.db_search(constraints) + degree = self.m.reader.db_search(constraints[1:], constraints[0] if constraints else 'restaurant') degree_input_list = self.m.reader._degree_vec_mapping(len(degree)) degree_input = cuda_(Variable(torch.Tensor(degree_input_list).unsqueeze(0))) return degree, degree_input diff --git a/convlab2/e2e/sequicity/reader.py b/convlab2/e2e/sequicity/reader.py index f6f94a27f77fb9ec6b9727a8b921a057399dc7ce..d35a07070fdad2cd32c9c83dbddcee3418af8862 100755 --- a/convlab2/e2e/sequicity/reader.py +++ b/convlab2/e2e/sequicity/reader.py @@ -767,30 +767,25 @@ class MultiWozReader(_ReaderBase): self.result_file = '' def _get_tokenized_data(self, raw_data, db_data, construct_vocab): - requestable_keys = ['addr', 'area', 'fee', 'name', 'phone', 'post', 'price', 'type', 'department', 'internet', 'parking', 'stars', 'food', 'arrive', 'day', 'depart', 'dest', 'leave', 'ticket', 'id'] - tokenized_data = [] vk_map = self._value_key_map(db_data) for dial_id, dial in enumerate(raw_data): tokenized_dial = [] for turn in dial['dial']: turn_num = turn['turn'] - constraint = [] + constraint = [turn['domain']] requested = [] for slot_act in turn['usr']['slu']: if slot_act == 'inform': slot_values = turn['usr']['slu'][slot_act] for v in slot_values: - s = v[1] - if s not in ['dont_care', 'none']: - constraint.append(s) + if v[1] not in ['do_nt_care', 'none']: + constraint.append(v[1]) elif slot_act == 'request': slot_values = turn['usr']['slu'][slot_act] for v in slot_values: - s = v[0] - if s in requestable_keys: - requested.append(s) - degree = len(self.db_search(constraint)) + requested.append(v[0]) + degree = len(self.db_search(constraint[1:], constraint[0])) requested = sorted(requested) constraint.append('EOS_Z1') requested.append('EOS_Z2') @@ -837,31 +832,34 @@ class MultiWozReader(_ReaderBase): string = re.sub(r'_+', '_', string) string = re.sub(r'children', 'child_-s', string) return string - requestable_dict = {'address':'addr', - 'area':'area', - 'entrance fee':'fee', - 'name':'name', - 'phone':'phone', - 'postcode':'post', - 'pricerange':'price', - 'type':'type', - 'department':'department', - 'internet':'internet', - 'parking':'parking', - 'stars':'stars', - 'food':'food', - 'arriveBy':'arrive', - 'day':'day', - 'departure':'depart', - 'destination':'dest', - 'leaveAt':'leave', - 'price':'ticket', - 'trainId':'id'} + slot_dict = {'address':'addr', + 'area':'area', + 'entrance fee':'fee', + 'name':'name', + 'phone':'phone', + 'postcode':'post', + 'pricerange':'price', + 'type':'type', + 'department':'department', + 'internet':'internet', + 'parking':'parking', + 'stars':'stars', + 'food':'food', + 'arriveBy':'arrive', + 'day':'day', + 'departure':'depart', + 'destination':'dest', + 'leaveAt':'leave', + 'price':'ticket', + 'trainId':'id', + 'time':'time', + 'ref':'ref'} value_key = {} - for db_entry in db_data: - for k, v in db_entry.items(): - if k in requestable_dict: - value_key[normal(v)] = requestable_dict[k] + for domain in db_data: + for db_entry in db_data[domain]: + for k, v in db_entry.items(): + if k in slot_dict: + value_key[normal(v)] = slot_dict[k] return value_key def _get_encoded_data(self, tokenized_data): @@ -897,10 +895,11 @@ class MultiWozReader(_ReaderBase): return encoded_data def _get_clean_db(self, raw_db_data): - for entry in raw_db_data: - for k, v in list(entry.items()): - if not isinstance(v, str) or v == '?': - entry.pop(k) + for domain in raw_db_data: + for entry in raw_db_data[domain]: + for k, v in list(entry.items()): + if not isinstance(v, str) or v == '?': + entry.pop(k) def _construct(self, train_json_path, dev_json_path, test_json_path, db_json_path): """ @@ -921,13 +920,14 @@ class MultiWozReader(_ReaderBase): dev_raw_data = json.loads(f.read().lower()) with open(test_json_path) as f: test_raw_data = json.loads(f.read().lower()) - db_data = list() + db_data = dict() for domain_db_json_path in db_json_path: with open(domain_db_json_path) as f: db_data_domain = json.loads(f.read().lower()) for i, item in enumerate(db_data_domain): item['ref'] = f'{i:08d}' - db_data += db_data_domain + domain=domain_db_json_path.split('/')[-1][:-8] + db_data[domain] = db_data_domain self._get_clean_db(db_data) self.db = db_data @@ -946,9 +946,14 @@ class MultiWozReader(_ReaderBase): random.shuffle(self.dev) random.shuffle(self.test) - def db_search(self, constraints): + def db_search(self, constraints, domain): + if domain == 'taxi': + match_results = [{'phone':'0123456789','car':'black toyota'}] + return match_results + elif domain not in self.db: + return [] match_results = [] - for entry in self.db: + for entry in self.db[domain]: entry_values = ' '.join(entry.values()) match = True for c in constraints: @@ -993,7 +998,7 @@ class MultiWozReader(_ReaderBase): in constraint_request else constraint_request for j, ent in enumerate(constraints): constraints[j] = ent.replace('_', ' ') - degree = self.db_search(constraints) + degree = self.db_search(constraints[1:], constraints[0] if constraints else 'restaurant') #print('constraints',constraints) #print('degree',degree) venue = random.sample(degree, 1)[0] if degree else dict()