From 7ae31c8bb613273f6ee90954f5f4488b129490f1 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Wed, 15 Dec 2021 09:02:28 +0000
Subject: [PATCH] add load_dataset interface

---
 convlab2/__init__.py                   |  1 +
 convlab2/util/unified_datasets_util.py | 28 ++++++++++++++++++++++++++
 data/unified_datasets/README.md        |  9 ++++++++-
 3 files changed, 37 insertions(+), 1 deletion(-)
 create mode 100644 convlab2/util/unified_datasets_util.py

diff --git a/convlab2/__init__.py b/convlab2/__init__.py
index 87a74423..0fe7d5bf 100755
--- a/convlab2/__init__.py
+++ b/convlab2/__init__.py
@@ -6,6 +6,7 @@ from convlab2.policy import Policy
 from convlab2.nlg import NLG
 from convlab2.dialog_agent import Agent, PipelineAgent
 from convlab2.dialog_agent import Session, BiSession, DealornotSession
+from convlab2.util.unified_datasets_util import load_dataset, load_database
 
 from os.path import abspath, dirname
 
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
new file mode 100644
index 00000000..921c6c45
--- /dev/null
+++ b/convlab2/util/unified_datasets_util.py
@@ -0,0 +1,28 @@
+from zipfile import ZipFile
+import json
+import os
+import importlib
+
+def load_dataset(dataset_name):
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    with archive.open('data/dialogues.json') as f:
+        dialogues = json.loads(f.read())
+    with archive.open('data/ontology.json') as f:
+        ontology = json.loads(f.read())
+    return dialogues, ontology
+
+def load_database(dataset_name):
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    cwd = os.getcwd()
+    os.chdir(data_dir)
+    Database = importlib.import_module('database').Database
+    os.chdir(cwd)
+    database = Database()
+    return database
+
+if __name__ == "__main__":
+    dialogues, ontology = load_dataset('multiwoz21')
+    database = load_database('multiwoz21')
+    res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
+    print(res[0], len(res))
diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md
index 7406169d..615d1b52 100644
--- a/data/unified_datasets/README.md
+++ b/data/unified_datasets/README.md
@@ -1,7 +1,14 @@
 # Unified data format
 
 ## Overview
-We transform different datasets into a unified format under `data/unified_datasets` directory.
+We transform different datasets into a unified format under `data/unified_datasets` directory. To import a unified datasets:
+
+```python
+from convlab2 import load_dataset, load_database
+
+dialogues, ontology = load_dataset('multiwoz21')
+database = load_database('multiwoz21')
+```
 
 Each dataset contains at least these files:
 
-- 
GitLab