Skip to content
Snippets Groups Projects
Commit 9a486a52 authored by zqwerty's avatar zqwerty
Browse files

update multiwoz21 preprocess and add dbquery

parent 44c1ca94
No related branches found
No related tags found
No related merge requests found
No preview for this file type
import json
import os
import random
from fuzzywuzzy import fuzz
from itertools import chain
from zipfile import ZipFile
from copy import deepcopy
class Database:
def __init__(self):
"""extract data.zip and load the database."""
archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip'))
domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'police']
self.dbs = {}
for domain in domains:
with archive.open('data/{}_db.json'.format(domain)) as f:
self.dbs[domain] = json.loads(f.read())
# add some missing information
self.dbs['taxi'] = {
"taxi_colors": ["black","white","red","yellow","blue","grey"],
"taxi_types": ["toyota","skoda","bmw","honda","ford","audi","lexus","volvo","volkswagen","tesla"],
"taxi_phone": ["^[0-9]{10}$"]
}
self.dbs['police'][0]['postcode'] = "cb11jg"
for entity in self.dbs['hospital']:
entity['postcode'] = "cb20qq"
entity['address'] = "Hills Rd, Cambridge"
self.dbattr2slot = {
'openhours': 'open hours',
'pricerange': 'price range',
'arriveBy': 'arrive by',
'leaveAt': 'leave at'
}
def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
# query the db
if domain == 'taxi':
return [{'taxi_colors': random.choice(self.dbs[domain]['taxi_colors']),
'taxi_types': random.choice(self.dbs[domain]['taxi_types']),
'taxi_phone': ''.join([str(random.randint(1, 9)) for _ in range(11)])}]
if domain == 'police':
return deepcopy(self.dbs['police'])
if domain == 'hospital':
department = None
for key, val in state:
if key == 'department':
department = val
if not department:
return deepcopy(self.dbs['hospital'])
else:
return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()]
state = list(map(lambda ele: ele if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state))
found = []
for i, record in enumerate(self.dbs[domain]):
constraints_iterator = zip(state, [False] * len(state))
soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints))
for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator):
if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]:
pass
else:
try:
record_keys = [self.dbattr2slot.get(k, k) for k in record]
if key.lower() not in record_keys:
continue
if key == 'leave at':
val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1])
val2 = int(record['leaveAt'].split(':')[0]) * 100 + int(record['leaveAt'].split(':')[1])
if val1 > val2:
break
elif key == 'arrive by':
val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1])
val2 = int(record['arriveBy'].split(':')[0]) * 100 + int(record['arriveBy'].split(':')[1])
if val1 < val2:
break
# elif ignore_open and key in ['destination', 'departure', 'name']:
elif ignore_open and key in ['destination', 'departure']:
continue
elif record[key].strip() == '?':
# '?' matches any constraint
continue
else:
if not fuzzy_match:
if val.strip().lower() != record[key].strip().lower():
break
else:
if fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio:
break
except:
continue
else:
res = deepcopy(record)
res['Ref'] = '{0:08d}'.format(i)
found.append(res)
if len(found) == topk:
return found
return found
if __name__ == '__main__':
db = Database()
res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
print(res, len(res))
# print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']]))
import copy import copy
import re import re
from zipfile import ZipFile, ZIP_DEFLATED from zipfile import ZipFile, ZIP_DEFLATED
from shutil import copy2 from shutil import copy2, rmtree
import json import json
import os import os
from tqdm import tqdm from tqdm import tqdm
...@@ -684,7 +684,6 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): ...@@ -684,7 +684,6 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer):
}) })
# correct some value and try to give char level span # correct some value and try to give char level span
match = False match = False
ori_value = value
value = value.lower() value = value.lower()
if span and span[0] <= span[1]: if span and span[0] <= span[1]:
# use original span annotation, but tokenizations are different # use original span annotation, but tokenizations are different
...@@ -813,7 +812,7 @@ def preprocess(): ...@@ -813,7 +812,7 @@ def preprocess():
} }
for turn_id, turn in enumerate(ori_dialog['log']): for turn_id, turn in enumerate(ori_dialog['log']):
# correct some grammar error in text, mainly follow tokenization.md in MultiWOZ_2.1 # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1
text = turn['text'] text = turn['text']
text = re.sub(" Im ", " I'm ", text) text = re.sub(" Im ", " I'm ", text)
text = re.sub(" im ", " i'm ", text) text = re.sub(" im ", " i'm ", text)
...@@ -877,13 +876,15 @@ def preprocess(): ...@@ -877,13 +876,15 @@ def preprocess():
dialogues = [] dialogues = []
for split in splits: for split in splits:
dialogues += dialogues_by_split[split] dialogues += dialogues_by_split[split]
init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in init_ontology['binary_dialogue_acts']] init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(init_ontology['binary_dialogue_acts'])]
json.dump(dialogues[:10], open(f'dummy_data.json', 'w'), indent=2) json.dump(dialogues[:10], open(f'dummy_data.json', 'w'), indent=2)
json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w'), indent=2) json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w'), indent=2)
json.dump(init_ontology, open(f'{new_data_dir}/ontology.json', 'w'), indent=2) json.dump(init_ontology, open(f'{new_data_dir}/ontology.json', 'w'), indent=2)
with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
for filename in os.listdir(new_data_dir): for filename in os.listdir(new_data_dir):
zf.write(f'{new_data_dir}/{filename}') zf.write(f'{new_data_dir}/{filename}')
rmtree(original_data_dir)
rmtree(new_data_dir)
return dialogues, init_ontology return dialogues, init_ontology
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment