diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index 341d21e0d75a0327c55939891c90b0c3ad5b2642..0683e780db8be7ce8ef7b00bcd40e5b47f4fc57c 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -16,23 +16,29 @@ class BaseDatabase(ABC): """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" -def load_dataset(dataset_name:str) -> Tuple[List, Dict]: +def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]: """load unified datasets from `data/unified_datasets/$dataset_name` Args: dataset_name (str): unique dataset name in `data/unified_datasets` Returns: - dialogues (list): each element is a dialog in unified format + dataset (dict): keys are data splits and the values are lists of dialogues ontology (dict): dataset ontology """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) archive = ZipFile(os.path.join(data_dir, 'data.zip')) - with archive.open('data/dialogues.json') as f: - dialogues = json.loads(f.read()) with archive.open('data/ontology.json') as f: ontology = json.loads(f.read()) - return dialogues, ontology + with archive.open('data/dialogues.json') as f: + dialogues = json.loads(f.read()) + dataset = {} + for dialogue in dialogues: + if dialogue['data_split'] not in dataset: + dataset[dialogue['data_split']] = [dialogue] + else: + dataset[dialogue['data_split']].append(dialogue) + return dataset, ontology def load_database(dataset_name:str): """load database from `data/unified_datasets/$dataset_name` @@ -43,17 +49,21 @@ def load_database(dataset_name:str): Returns: database: an instance of BaseDatabase """ - data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) - cwd = os.getcwd() - os.chdir(data_dir) - Database = importlib.import_module('database').Database - os.chdir(cwd) + 
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py')) + module_spec = importlib.util.spec_from_file_location('database', data_dir) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + Database = module.Database + assert issubclass(Database, BaseDatabase) database = Database() assert isinstance(database, BaseDatabase) return database if __name__ == "__main__": - dialogues, ontology = load_dataset('multiwoz21') + # dataset, ontology = load_dataset('multiwoz21') + # print(dataset.keys()) + # print(len(dataset['train'])) + from convlab2.util.unified_datasets_util import BaseDatabase database = load_database('multiwoz21') res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) print(res[0], len(res)) diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index df077aeac4d82764b0220e96a8287bc3df6e4588..7400504a8fc797164f7f495bbedc6bfe4399997d 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase ```python from convlab2 import load_dataset, load_database -dialogues, ontology = load_dataset('multiwoz21') +dataset, ontology = load_dataset('multiwoz21') database = load_database('multiwoz21') ``` +`dataset` is a dict where the keys are data splits and the values are lists of dialogues. `database` is an instance of the `Database` class that has a `query` function. The formats of dialogue, ontology, and Database are defined below. + Each dataset contains at least these files: - `README.md`: dataset description and the **main changes** from original data to processed data. Should include the instruction on how to get the original data and transform them into the unified format. 
diff --git a/data/unified_datasets/check.py b/data/unified_datasets/check.py index 47e75e602b5009e0a6b9bbf53b7a1291491522a6..4809c857041a9144012880627b7f8745dd70c075 100644 --- a/data/unified_datasets/check.py +++ b/data/unified_datasets/check.py @@ -315,10 +315,8 @@ if __name__ == '__main__': if args.preprocess: print('pre-processing') - os.chdir(name) preprocess = importlib.import_module(f'{name}.preprocess') preprocess.preprocess() - os.chdir('..') data_file = f'{name}/data.zip' if not os.path.exists(data_file): diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index dcb3d70225faa4a593206c5783057884b71ec65e..43ea5896285ebf9e8e38f99c89823bf6538a2bad 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -35,7 +35,7 @@ class Database(BaseDatabase): 'leaveAt': 'leave at' } - def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60): + def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list: """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" # query the db if domain == 'taxi':