From 6592823138d351b4052d6d8a9b8914e8f912b33f Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Fri, 15 Jul 2022 09:42:01 +0800
Subject: [PATCH] add reddit preprocess, update wikidialog preprocess, rm dials
 that have len(utterance)>256

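Both preprocess scripts now drop any dialogue that contains an
utterance longer than 256 characters. A minimal sketch of the rule,
with keep_dialogue and max_len as illustrative names (the actual
scripts inline the check in process_dial and return None instead):

    def keep_dialogue(turns, max_len=256):
        # hypothetical helper: keep a dialogue only if every
        # utterance fits within max_len characters
        return all(len(utterance) <= max_len for utterance in turns)

The wikidialog script also now reads all 99 training shards (capped
at 1M dialogues) instead of a single shard, assigns dial_id before
building the dialogue, and no longer truncates the validation split
to a tenth of the training split.
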
---
 data/unified_datasets/reddit/preprocess.py    | 75 +++++++++++++++++++
 .../unified_datasets/wikidialog/preprocess.py | 19 +++--
 2 files changed, 88 insertions(+), 6 deletions(-)
 create mode 100644 data/unified_datasets/reddit/preprocess.py

diff --git a/data/unified_datasets/reddit/preprocess.py b/data/unified_datasets/reddit/preprocess.py
new file mode 100644
index 00000000..4aa1f03f
--- /dev/null
+++ b/data/unified_datasets/reddit/preprocess.py
@@ -0,0 +1,75 @@
+import json
+from zipfile import ZipFile, ZIP_DEFLATED
+import os
+from shutil import rmtree
+import io
+
+def preprocess():
+    original_data_dir = 'dstc8-reddit-corpus'
+    new_data_dir = 'data'
+    os.makedirs(new_data_dir, exist_ok=True)
+
+    dataset = 'reddit'
+    splits = ['train', 'validation']
+    dialogues_by_split = {split:[] for split in splits}
+
+    ontology = {
+        'domains': {},
+        'intents': {},
+        'state': {},
+        "dialogue_acts": {
+            "categorical": {},
+            "non-categorical": {},
+            "binary": {}
+        }
+    }
+
+    def process_dial(line, dial_id, data_split):
+        item = json.loads(line)
+        dialogue = {
+            'dataset': dataset,
+            'data_split': data_split,
+            'dialogue_id': dial_id,
+            'original_id': item['id'],
+            'topic': item['domain'],
+            'turns': []
+        }
+        for i, utterance in enumerate(item['turns']):
+            if len(utterance) > 256:
+                # drop the whole dialogue if any utterance exceeds 256 characters
+                return None
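+            # turns alternate between user and system, with the user speaking first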
+            speaker = 'system' if i % 2 == 1 else 'user'
+            turn = {
+                'speaker': speaker,
+                'utterance': utterance.strip(),
+                'utt_idx': len(dialogue['turns']),
+            }
+            dialogue['turns'].append(turn)
+        return dialogue
+
+    for data_split, filename in zip(splits, ['training', 'validation_date_out_domain_out']):
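+        # each zip member is a jsonl file; every line holds one dialogue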
+        with ZipFile(os.path.join(original_data_dir, f'{filename}.zip')) as zip_file:
+            for file in zip_file.namelist():
+                with io.TextIOWrapper(zip_file.open(file), encoding="utf-8") as f:
+                    for line in f:
+                        dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
+                        dialogue = process_dial(line, dial_id, data_split)
+                        if dialogue:
+                            dialogues_by_split[data_split].append(dialogue)
+    
+    dialogues = dialogues_by_split['train']+dialogues_by_split['validation']
+    json.dump(dialogues[:10], open('dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
+        for filename in os.listdir(new_data_dir):
+            zf.write(f'{new_data_dir}/{filename}')
+    # rmtree(original_data_dir)  # intentionally kept; uncomment to delete the raw corpus
+    rmtree(new_data_dir)
+    return dialogues, ontology
+
+
+if __name__ == '__main__':
+    preprocess()
diff --git a/data/unified_datasets/wikidialog/preprocess.py b/data/unified_datasets/wikidialog/preprocess.py
index fc2b0b73..82e05e11 100644
--- a/data/unified_datasets/wikidialog/preprocess.py
+++ b/data/unified_datasets/wikidialog/preprocess.py
@@ -36,6 +36,9 @@ def preprocess():
             'turns': []
         }
         for speaker, utterance in zip(item['author_num'], item['utterances']):
+            if len(utterance) > 256:
+                # drop the whole dialogue if any utterance exceeds 256 characters
+                return None
             speaker = 'system' if speaker == 0 else 'user'
             turn = {
                 'speaker': speaker,
@@ -46,21 +49,25 @@ def preprocess():
         return dialogue
             
     data_split = 'train'
-    for shard in tqdm(range(1)):
+    for shard in tqdm(range(99)):  # read all 99 training shards
         with gzip.open(f'{original_data_dir}/data_train.jsonl-000{shard:02}-of-00099.gz','r') as fin:
             for line in fin:
                 dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
                 dialogue = process_dial(line, dial_id, data_split)
-                dialogues_by_split[data_split].append(dialogue)
+                if dialogue:
+                    dialogues_by_split[data_split].append(dialogue)
+                if len(dialogues_by_split[data_split]) >= 1e6:  # cap the train split at 1M dialogues
+                    break
+        if len(dialogues_by_split[data_split]) >= 1e6:
+            break
 
     data_split = 'validation'
     with gzip.open(f'{original_data_dir}/data_validation.jsonl.gz','r') as fin:
         for line in fin:
+            dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
             dialogue = process_dial(line, dial_id, data_split)
-            dialogue['dialogue_id'] = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
-            dialogues_by_split[data_split].append(dialogue)
-            if len(dialogues_by_split[data_split]) >= len(dialogues_by_split['train']) // 10:
-                break
+            if dialogue:
+                dialogues_by_split[data_split].append(dialogue)
     
     dialogues = dialogues_by_split['train']+dialogues_by_split['validation']
     json.dump(dialogues[:10], open(f'dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
-- 
GitLab