diff --git a/convlab2/__init__.py b/convlab2/__init__.py index 87a7442310d0d5bad9dbeae9b1b29041d4490067..0fe7d5bf01a7ac06fff476149427712f8c7c7ee4 100755 --- a/convlab2/__init__.py +++ b/convlab2/__init__.py @@ -6,6 +6,7 @@ from convlab2.policy import Policy from convlab2.nlg import NLG from convlab2.dialog_agent import Agent, PipelineAgent from convlab2.dialog_agent import Session, BiSession, DealornotSession +from convlab2.util.unified_datasets_util import load_dataset, load_database from os.path import abspath, dirname diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py new file mode 100644 index 0000000000000000000000000000000000000000..921c6c451042c754450cf93afd7e2a06055026b2 --- /dev/null +++ b/convlab2/util/unified_datasets_util.py @@ -0,0 +1,28 @@ +from zipfile import ZipFile +import json +import os +import importlib + +def load_dataset(dataset_name): + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) + archive = ZipFile(os.path.join(data_dir, 'data.zip')) + with archive.open('data/dialogues.json') as f: + dialogues = json.loads(f.read()) + with archive.open('data/ontology.json') as f: + ontology = json.loads(f.read()) + return dialogues, ontology + +def load_database(dataset_name): + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) + cwd = os.getcwd() + os.chdir(data_dir) + Database = importlib.import_module('database').Database + os.chdir(cwd) + database = Database() + return database + +if __name__ == "__main__": + dialogues, ontology = load_dataset('multiwoz21') + database = load_database('multiwoz21') + res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) + print(res[0], len(res)) diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index 7406169db869929dda862b04794a7e8412e3a227..615d1b52f68168351cb7541339ebdf9670cb28fb 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -1,7 +1,14 @@ # Unified data format ## Overview -We transform different datasets into a unified format under `data/unified_datasets` directory. +We transform different datasets into a unified format under `data/unified_datasets` directory. To import a unified datasets: + +```python +from convlab2 import load_dataset, load_database + +dialogues, ontology = load_dataset('multiwoz21') +database = load_database('multiwoz21') +``` Each dataset contains at least these files: