From 7ae31c8bb613273f6ee90954f5f4488b129490f1 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Wed, 15 Dec 2021 09:02:28 +0000 Subject: [PATCH] add load_dataset interface --- convlab2/__init__.py | 1 + convlab2/util/unified_datasets_util.py | 28 ++++++++++++++++++++++++++ data/unified_datasets/README.md | 9 ++++++++- 3 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 convlab2/util/unified_datasets_util.py diff --git a/convlab2/__init__.py b/convlab2/__init__.py index 87a74423..0fe7d5bf 100755 --- a/convlab2/__init__.py +++ b/convlab2/__init__.py @@ -6,6 +6,7 @@ from convlab2.policy import Policy from convlab2.nlg import NLG from convlab2.dialog_agent import Agent, PipelineAgent from convlab2.dialog_agent import Session, BiSession, DealornotSession +from convlab2.util.unified_datasets_util import load_dataset, load_database from os.path import abspath, dirname diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py new file mode 100644 index 00000000..921c6c45 --- /dev/null +++ b/convlab2/util/unified_datasets_util.py @@ -0,0 +1,28 @@ +from zipfile import ZipFile +import json +import os +import importlib + +def load_dataset(dataset_name): + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) + archive = ZipFile(os.path.join(data_dir, 'data.zip')) + with archive.open('data/dialogues.json') as f: + dialogues = json.loads(f.read()) + with archive.open('data/ontology.json') as f: + ontology = json.loads(f.read()) + return dialogues, ontology + +def load_database(dataset_name): + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) + cwd = os.getcwd() + os.chdir(data_dir) + Database = importlib.import_module('database').Database + os.chdir(cwd) + database = Database() + return database + +if __name__ == "__main__": + dialogues, ontology = load_dataset('multiwoz21') + database = load_database('multiwoz21') + res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) + print(res[0], len(res)) diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index 7406169d..615d1b52 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -1,7 +1,14 @@ # Unified data format ## Overview -We transform different datasets into a unified format under `data/unified_datasets` directory. +We transform different datasets into a unified format under `data/unified_datasets` directory. To import a unified datasets: + +```python +from convlab2 import load_dataset, load_database + +dialogues, ontology = load_dataset('multiwoz21') +database = load_database('multiwoz21') +``` Each dataset contains at least these files: -- GitLab