diff --git a/data/unified_datasets/reddit/preprocess.py b/data/unified_datasets/reddit/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa1f03fe584590645bf2b0d7b9548581baedff4
--- /dev/null
+++ b/data/unified_datasets/reddit/preprocess.py
@@ -0,0 +1,75 @@
+import gzip
+import json
+from zipfile import ZipFile, ZIP_DEFLATED
+import os
+from shutil import rmtree
+from tqdm import tqdm
+import io
+
+def preprocess():
+    original_data_dir = 'dstc8-reddit-corpus'
+    new_data_dir = 'data'
+    os.makedirs(new_data_dir, exist_ok=True)
+
+    dataset = 'reddit'
+    splits = ['train', 'validation']
+    dialogues_by_split = {split:[] for split in splits}
+
+    ontology = {
+        'domains': {},
+        'intents': {},
+        'state': {},
+        "dialogue_acts": {
+            "categorical": {},
+            "non-categorical": {},
+            "binary": {}
+        }
+    }
+
+    def process_dial(line, dial_id, data_split):
+        item = json.loads(line)
+        dialogue = {
+            'dataset': dataset,
+            'data_split': data_split,
+            'dialogue_id': dial_id,
+            'original_id': item['id'],
+            'topic': item['domain'],
+            'turns': []
+        }
+        for i, utterance in enumerate(item['turns']):
+            if len(utterance) > 256:
+                # remove dialogs that contain too long utterances
+                return None
+            speaker = 'system' if i % 2 == 1 else 'user'
+            turn = {
+                'speaker': speaker,
+                'utterance': utterance.strip(),
+                'utt_idx': len(dialogue['turns']),
+            }
+            dialogue['turns'].append(turn)
+        return dialogue
+
+    for data_split, filename in zip(['train', 'validation'], ['training', 'validation_date_out_domain_out']):
+        with ZipFile(os.path.join(original_data_dir, f'{filename}.zip')) as zip_file:
+            for file in zip_file.namelist():
+                with io.TextIOWrapper(zip_file.open(file), encoding="utf-8") as f:
+                    for line in f:
+                        dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
+                        dialogue = process_dial(line, dial_id, data_split)
+                        if dialogue:
+                            dialogues_by_split[data_split].append(dialogue)
+
+    dialogues = dialogues_by_split['train']+dialogues_by_split['validation']
+    json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
+        for filename in os.listdir(new_data_dir):
+            zf.write(f'{new_data_dir}/{filename}')
+    # rmtree(original_data_dir)
+    rmtree(new_data_dir)
+    return dialogues, ontology
+
+
+if __name__ == '__main__':
+    preprocess()
diff --git a/data/unified_datasets/wikidialog/preprocess.py b/data/unified_datasets/wikidialog/preprocess.py
index fc2b0b73bbb52a941f041de4bd7194b1b79d9103..82e05e110f1b9a07f6f090302784d0f60d032f16 100644
--- a/data/unified_datasets/wikidialog/preprocess.py
+++ b/data/unified_datasets/wikidialog/preprocess.py
@@ -36,6 +36,9 @@ def preprocess():
             'turns': []
         }
         for speaker, utterance in zip(item['author_num'], item['utterances']):
+            if len(utterance) > 256:
+                # remove dialogs that contain too long utterances
+                return None
             speaker = 'system' if speaker == 0 else 'user'
             turn = {
                 'speaker': speaker,
@@ -46,21 +49,25 @@ def preprocess():
         return dialogue
 
     data_split = 'train'
-    for shard in tqdm(range(1)):
+    for shard in tqdm(range(99)):
         with gzip.open(f'{original_data_dir}/data_train.jsonl-000{shard:02}-of-00099.gz','r') as fin:
             for line in fin:
                 dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
                 dialogue = process_dial(line, dial_id, data_split)
-                dialogues_by_split[data_split].append(dialogue)
+                if dialogue:
+                    dialogues_by_split[data_split].append(dialogue)
+                if len(dialogues_by_split[data_split]) >= 1e6:
+                    break
+        if len(dialogues_by_split[data_split]) >= 1e6:
+            break
 
     data_split = 'validation'
     with gzip.open(f'{original_data_dir}/data_validation.jsonl.gz','r') as fin:
         for line in fin:
+            dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
            dialogue = process_dial(line, dial_id, data_split)
-            dialogue['dialogue_id'] = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
-            dialogues_by_split[data_split].append(dialogue)
-            if len(dialogues_by_split[data_split]) >= len(dialogues_by_split['train']) // 10:
-                break
+            if dialogue:
+                dialogues_by_split[data_split].append(dialogue)
 
     dialogues = dialogues_by_split['train']+dialogues_by_split['validation']
     json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
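
Reviewer note: both scripts now apply the same rule, dropping a dialogue outright once any single utterance exceeds 256 characters (process_dial returns None, and the new "if dialogue:" guards skip it). Below is a minimal standalone sketch of that rule for quick verification; the helper name keep_dialogue and the sample turns are hypothetical, not part of the patch:

    def keep_dialogue(utterances, max_len=256):
        # One over-long utterance discards the whole dialogue,
        # mirroring process_dial returning None on the first hit.
        return all(len(u) <= max_len for u in utterances)

    assert keep_dialogue(['hi', 'hello there'])   # every turn short -> kept
    assert not keep_dialogue(['hi', 'x' * 300])   # one long turn -> dropped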