Commit fb67c3bb authored by zqwerty

Merge branch 'unified_dataset' into pre-training

parents 3dd7ae72 69c810c3
@@ -19,11 +19,12 @@ class BaseDatabase(ABC):
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""


-def load_dataset(dataset_name:str) -> Dict:
+def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`

     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`
+        dial_ids_order (int): idx of shuffled dial order in `data/unified_datasets/$dataset_name/shuffled_dial_ids.json`

     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
@@ -33,6 +34,11 @@ def load_dataset(dataset_name:str) -> Dict:
         with archive.open('data/dialogues.json') as f:
             dialogues = json.loads(f.read())
     dataset = {}
-    for dialogue in dialogues:
-        if dialogue['data_split'] not in dataset:
-            dataset[dialogue['data_split']] = [dialogue]
+    if dial_ids_order is not None:
+        dial_ids = json.load(open(os.path.join(data_dir, 'shuffled_dial_ids.json')))[dial_ids_order]
+        for data_split in dial_ids:
+            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
+    else:
+        for dialogue in dialogues:
+            if dialogue['data_split'] not in dataset:
+                dataset[dialogue['data_split']] = [dialogue]
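For orientation, `shuffled_dial_ids.json` (generated by `check.py`, shown further down) is a JSON list of dicts, one per order, each mapping a data split to a permutation of indices into the full dialogue list. A minimal self-contained sketch of what the new branch above computes, using hypothetical toy data rather than a real dataset:

```python
# Toy stand-ins (hypothetical); the real files live under
# data/unified_datasets/$dataset_name/.
dialogues = [
    {'data_split': 'train', 'dialogue_id': 'd0'},
    {'data_split': 'train', 'dialogue_id': 'd1'},
    {'data_split': 'test', 'dialogue_id': 'd2'},
]
shuffled_dial_ids = [{'train': [1, 0], 'test': [2]}]  # one order, as stored in the JSON

dial_ids = shuffled_dial_ids[0]  # i.e. dial_ids_order=0
dataset = {split: [dialogues[i] for i in dial_ids[split]] for split in dial_ids}
assert [d['dialogue_id'] for d in dataset['train']] == ['d1', 'd0']
```

Because the order is fixed for a given `dial_ids_order`, slicing e.g. the first 10% of `dataset['train']` gives the same low-resource subset on every run, which is exactly what the `__main__` block further down does.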
@@ -187,8 +193,12 @@ def load_rg_data(dataset, data_split='all', speaker='system', context_window_siz
     return load_unified_data(dataset, **kwargs)


-def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_values=['yes', 'no']):
-    # add delex_utterance to the dataset according to dialogue acts and belief_state
+def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore_values=['yes', 'no']):
+    """add delex_utterance to the dataset according to dialogue acts and belief_state
+
+    delex_func: function that returns the placeholder (e.g. "[(domain_name)-(slot_name)]") given (domain, slot, value)
+    ignore_values: values to ignore when delexicalizing using the categorical acts and states
+    """
     #
     def delex_inplace(texts_placeholders, value_pattern):
         res = []
@@ -226,7 +236,7 @@ def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_valu
                 assert utt[start:end] == value
                 # make sure no word/number characters directly precede or follow the span, and that it does not overlap other spans
                 if start >= last_end and (start == 0 or re.match('\W', utt[start-1])) and (end == len(utt) or re.match('\W', utt[end])):
-                    placeholder = delex_format.format(domain=domain, slot=slot, value=value)
+                    placeholder = delex_func(domain, slot, value)
                     delex_vocab.add(placeholder)
                     delex_utt.append((utt[last_end:start], False))
                     delex_utt.append((placeholder, True))
@@ -237,7 +247,7 @@ def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_valu
             for da in sorted(turn['dialogue_acts']['categorical'], key=lambda x: len(x['value'])):
                 domain, slot, value = da['domain'], da['slot'], da['value']
                 if value.lower() not in ignore_values:
-                    placeholder = delex_format.format(domain=domain, slot=slot, value=value)
+                    placeholder = delex_func(domain, slot, value)
                     pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
                     if delex_inplace(delex_utt, pattern):
                         delex_vocab.add(placeholder)
@@ -251,7 +261,7 @@ def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_valu
                     # has value
                     for value in values.split('|'):
                         if value.lower() not in ignore_values:
-                            placeholder = delex_format.format(domain=domain, slot=slot, value=value)
+                            placeholder = delex_func(domain, slot, value)
                             pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
                             if delex_inplace(delex_utt, pattern):
                                 delex_vocab.add(placeholder)
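All three delexicalization paths above rely on `\b({value})\b` word-boundary patterns, so only whole-token occurrences of a value are replaced. A small standalone illustration (not the library's code; the placeholder string is just an example):

```python
import re

utt = 'I want a cheap hotel, not the cheapest one.'
# Same pattern shape as above; note the code interpolates the raw value,
# so values containing regex metacharacters would need re.escape.
pattern = re.compile(r'\b({})\b'.format('cheap'), flags=re.I)

# Only the standalone 'cheap' matches; 'cheapest' survives thanks to \b.
print(pattern.sub('[(hotel)-(pricerange)]', utt))
# -> I want a [(hotel)-(pricerange)] hotel, not the cheapest one.
```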
@@ -262,7 +272,10 @@ def create_delex_data(dataset, delex_format='[({domain})-({slot})]', ignore_valu
 if __name__ == "__main__":
-    dataset = load_dataset('multiwoz21')
+    dataset = load_dataset('multiwoz21', dial_ids_order=0)
+    train_ratio = 0.1
+    dataset['train'] = dataset['train'][:round(len(dataset['train'])*train_ratio)]
+    print(len(dataset['train']))
     print(dataset.keys())
     print(len(dataset['test']))
@@ -274,7 +287,11 @@ if __name__ == "__main__":
     data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
     pprint(data_by_split['test'][0])

-    dataset, delex_vocab = create_delex_data(dataset)
+    def delex_slot(domain, slot, value):
+        # only use slot name for delexicalization
+        return f'[{slot}]'
+
+    dataset, delex_vocab = create_delex_data(dataset, delex_slot)
     json.dump(dataset['test'], open('new_delex_multiwoz21_test.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
     json.dump(delex_vocab, open('new_delex_vocab.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
     with open('new_delex_cmp.txt', 'w') as f:
@@ -42,6 +42,7 @@
   - `dialogues.json`: a list of all dialogues in the dataset.
   - other necessary files such as databases.
 - `dummy_data.json`: a list of 10 dialogues from `dialogues.json` for illustration.
+- `shuffled_dial_ids.json`: 10 randomly shuffled dialogue-id orders created by `check.py` for experiment reproducibility; pass a `dial_ids_order` in [0, 9] to the `load_dataset` function to use one of them.

 Datasets that require database interaction should also include the following file:
 - `database.py`: load the database and define the query function:
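The README's actual code example is elided by this diff (only its `class Database(BaseDatabase):` line survives as hunk context below). A hedged sketch of what such a `database.py` could look like, inferring the `query` signature from the `BaseDatabase` docstring at the top of this diff; the import path, data layout, and file names here are assumptions, not the repo's implementation:

```python
import json
from zipfile import ZipFile

# Hypothetical import path; BaseDatabase is the ABC shown at the top of this diff.
from unified_datasets_util import BaseDatabase


class Database(BaseDatabase):
    def __init__(self):
        # Assumed layout: one JSON list of entities per domain inside data.zip.
        archive = ZipFile('data.zip')
        with archive.open('data/hotel_db.json') as f:
            self.dbs = {'hotel': json.loads(f.read())}

    def query(self, domain, state, topk):
        """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
        constraints = state.get(domain, {})
        matched = [entity for entity in self.dbs[domain]
                   if all(entity.get(slot) == value
                          for slot, value in constraints.items() if value)]
        return matched[:topk]
```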
@@ -60,7 +61,7 @@ class Database(BaseDatabase):
 We first introduce the unified format of `ontology` and `dialogues`. To transform a new dataset into the unified format:

 1. Create a `data/unified_datasets/$dataset` folder, where `$dataset` is the name of the dataset.
 2. Write `preprocess.py` to transform the original dataset into the unified format, producing `data.zip` and `dummy_data.json`.
-3. Run `python check.py $dataset` in the `data/unified_datasets` directory to validate the processed dataset and get data statistics.
+3. Run `python check.py $dataset` in the `data/unified_datasets` directory to validate the processed dataset and get data statistics and shuffled dialogue ids.
 4. Write `README.md` to describe the data following [How to create dataset README](#how-to-create-dataset-readme).
### Ontology
@@ -120,7 +121,7 @@ Note that multiple descriptions/values are separated by `"|"`.
 Other attributes are optional.

-> **Necessary**: Run `python check.py $dataset` in the `data/unified_datasets` directory to validate the processed dataset and get data statistics in `data/unified_datasets/$dataset/stat.txt`.
+> **Necessary**: Run `python check.py $dataset` in the `data/unified_datasets` directory to validate the processed dataset and get data statistics in `data/unified_datasets/$dataset/stat.txt` as well as shuffled dialogue ids in `data/unified_datasets/$dataset/shuffled_dial_ids.json`.

### How to create dataset README

Each dataset has a README.md to describe the original and transformed data. Please follow `README_TEMPLATE.md` and make sure that you:
@@ -4,6 +4,7 @@ from copy import deepcopy
 from zipfile import ZipFile
 import importlib
 from tabulate import tabulate
+import random

 special_values = ['', 'dontcare', None, '?']
@@ -279,6 +280,20 @@ def check_dialogues(name, dialogues, ontology):
     return tabulate(table, headers='keys', tablefmt='github')

+
+def create_shuffled_dial_ids(dialogues, rng=random.Random(42), num_orders=10):
+    dial_ids = {}
+    for i, dialogue in enumerate(dialogues):
+        dial_ids.setdefault(dialogue['data_split'], [])
+        dial_ids[dialogue['data_split']].append(i)
+
+    id_orders = []
+    for _ in range(num_orders):
+        for data_split in dial_ids:
+            rng.shuffle(dial_ids[data_split])
+        id_orders.append(deepcopy(dial_ids))
+    return id_orders
+
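A note on `create_shuffled_dial_ids`: it returns `num_orders` dicts, each mapping a data split to a permutation of that split's indices into the full dialogue list, and the default `random.Random(42)` makes the output deterministic per run of `check.py` (the default RNG is created once at definition time, so repeated calls in one process continue the same random stream). A hedged sanity check with hypothetical toy data, assuming the function above is in scope:

```python
# Toy input: indices in the output are positions in the full `dialogues` list.
toy_dialogues = [{'data_split': 'train'} for _ in range(5)] + \
                [{'data_split': 'test'} for _ in range(2)]
id_orders = create_shuffled_dial_ids(toy_dialogues)

assert len(id_orders) == 10  # num_orders deep-copied snapshots
assert sorted(id_orders[0]['train']) == [0, 1, 2, 3, 4]  # a permutation per split
assert sorted(id_orders[0]['test']) == [5, 6]
```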
 if __name__ == '__main__':
     from argparse import ArgumentParser
@@ -339,6 +354,10 @@ if __name__ == '__main__':
             dialogues = json.load(f)
         stat = check_dialogues(name, dialogues, ontology)
         print('pass')
+        print('creating shuffled_dial_ids')
+        id_orders = create_shuffled_dial_ids(dialogues)
+        with open(os.path.join(name, 'shuffled_dial_ids.json'), 'w', encoding='utf-8') as f:
+            json.dump(id_orders, f, ensure_ascii=False)
         print(f'Please copy and paste the statistics in {name}/stat.txt to dataset README.md->Data Splits section\n')
         with open(f'{name}/stat.txt', 'w') as f: