From eeb903df689150b8032c077022c8b0f838f60b80 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 20 Dec 2021 09:50:55 +0000 Subject: [PATCH] add BaseDatabase class for unified datasets --- convlab2/util/unified_datasets_util.py | 32 +++++++++++++------- data/unified_datasets/README.md | 4 ++- data/unified_datasets/check.py | 2 -- data/unified_datasets/multiwoz21/database.py | 2 +- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index 341d21e0..0683e780 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -16,23 +16,29 @@ class BaseDatabase(ABC): """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" -def load_dataset(dataset_name:str) -> Tuple[List, Dict]: +def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]: """load unified datasets from `data/unified_datasets/$dataset_name` Args: dataset_name (str): unique dataset name in `data/unified_datasets` Returns: - dialogues (list): each element is a dialog in unified format + dataset (dict): keys are data splits and the values are lists of dialogues ontology (dict): dataset ontology """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) archive = ZipFile(os.path.join(data_dir, 'data.zip')) - with archive.open('data/dialogues.json') as f: - dialogues = json.loads(f.read()) with archive.open('data/ontology.json') as f: ontology = json.loads(f.read()) - return dialogues, ontology + with archive.open('data/dialogues.json') as f: + dialogues = json.loads(f.read()) + dataset = {} + for dialogue in dialogues: + if dialogue['data_split'] not in dataset: + dataset[dialogue['data_split']] = [dialogue] + else: + dataset[dialogue['data_split']].append(dialogue) + return dataset, ontology def load_database(dataset_name:str): """load 
database from `data/unified_datasets/$dataset_name` @@ -43,17 +49,21 @@ def load_database(dataset_name:str): Returns: database: an instance of BaseDatabase """ - data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) - cwd = os.getcwd() - os.chdir(data_dir) - Database = importlib.import_module('database').Database - os.chdir(cwd) + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py')) + module_spec = importlib.util.spec_from_file_location('database', data_dir) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + Database = module.Database + assert issubclass(Database, BaseDatabase) database = Database() assert isinstance(database, BaseDatabase) return database if __name__ == "__main__": - dialogues, ontology = load_dataset('multiwoz21') + # dataset, ontology = load_dataset('multiwoz21') + # print(dataset.keys()) + # print(len(dataset['train'])) + from convlab2.util.unified_datasets_util import BaseDatabase database = load_database('multiwoz21') res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) print(res[0], len(res)) diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index df077aea..7400504a 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase ```python from convlab2 import load_dataset, load_database -dialogues, ontology = load_dataset('multiwoz21') +dataset, ontology = load_dataset('multiwoz21') database = load_database('multiwoz21') ``` +`dataset` is a dict where the keys are data splits and the values are lists of dialogues. `database` is an instance of the `Database` class that has a `query` function. 
The formats of dialogue, ontology, and Database are defined below. + Each dataset contains at least these files: - `README.md`: dataset description and the **main changes** from original data to processed data. Should include the instruction on how to get the original data and transform them into the unified format. diff --git a/data/unified_datasets/check.py b/data/unified_datasets/check.py index 47e75e60..4809c857 100644 --- a/data/unified_datasets/check.py +++ b/data/unified_datasets/check.py @@ -315,10 +315,8 @@ if __name__ == '__main__': if args.preprocess: print('pre-processing') - os.chdir(name) preprocess = importlib.import_module(f'{name}.preprocess') preprocess.preprocess() - os.chdir('..') data_file = f'{name}/data.zip' if not os.path.exists(data_file): diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index dcb3d702..43ea5896 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -35,7 +35,7 @@ class Database(BaseDatabase): 'leaveAt': 'leave at' } - def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60): + def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list: """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" # query the db if domain == 'taxi': -- GitLab