diff --git a/convlab2/__init__.py b/convlab2/__init__.py
index a7b4c3c2d972cbb2dab7b091a099e8591b50895c..87a7442310d0d5bad9dbeae9b1b29041d4490067 100755
--- a/convlab2/__init__.py
+++ b/convlab2/__init__.py
@@ -6,8 +6,6 @@ from convlab2.policy import Policy
 from convlab2.nlg import NLG
 from convlab2.dialog_agent import Agent, PipelineAgent
 from convlab2.dialog_agent import Session, BiSession, DealornotSession
-from convlab2.util.unified_datasets_util import load_dataset, load_database, load_unified_data, \
-    load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data

 from os.path import abspath, dirname

diff --git a/convlab2/util/__init__.py b/convlab2/util/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..66c7233217086bea51d8fe8b5952fe70dca0b4bc 100755
--- a/convlab2/util/__init__.py
+++ b/convlab2/util/__init__.py
@@ -0,0 +1,2 @@
+from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
\ No newline at end of file
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index d6d637aa1c60257e861abea37d3f06e52c2f3c71..b81fd17792ad212bc3ff8832106e5b3707c7de0c 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -18,20 +18,17 @@ class BaseDatabase(ABC):
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""


-def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
-    """load unified datasets from `data/unified_datasets/$dataset_name`
+def load_dataset(dataset_name:str) -> Dict:
+    """load unified dataset from `data/unified_datasets/$dataset_name`

     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`

     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
-        ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
-    with archive.open('data/ontology.json') as f:
-        ontology = json.loads(f.read())
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
@@ -40,7 +37,22 @@
             dataset[dialogue['data_split']] = [dialogue]
         else:
             dataset[dialogue['data_split']].append(dialogue)
-    return dataset, ontology
+    return dataset
+
+def load_ontology(dataset_name:str) -> Dict:
+    """load unified ontology from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        ontology (dict): dataset ontology
+    """
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    with archive.open('data/ontology.json') as f:
+        ontology = json.loads(f.read())
+    return ontology

 def load_database(dataset_name:str):
     """load database from `data/unified_datasets/$dataset_name`
@@ -107,62 +119,57 @@
     return data_by_split

 def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    return load_unified_data(dataset, **kwargs)

 def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)


 if __name__ == "__main__":
-    dataset, ontology = load_dataset('multiwoz21')
+    dataset = load_dataset('multiwoz21')
     print(dataset.keys())
     print(len(dataset['test']))

@@ -171,5 +178,5 @@ if __name__ == "__main__":
     res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res[0], len(res))

-    data_by_split = load_e2e_data(dataset, data_split='test')
-    pprint(data_by_split['test'][3])
+    data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
+    pprint(data_by_split['test'][0])
diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md
index 7400504a8fc797164f7f495bbedc6bfe4399997d..082eb2a684662028cade7934799ba947c407ea70 100644
--- a/data/unified_datasets/README.md
+++ b/data/unified_datasets/README.md
@@ -4,9 +4,10 @@
 We transform different datasets into a unified format under `data/unified_datasets` directory. To import a unified datasets:

 ```python
-from convlab2 import load_dataset, load_database
+from convlab2.util import load_dataset, load_ontology, load_database

-dataset, ontology = load_dataset('multiwoz21')
+dataset = load_dataset('multiwoz21')
+ontology = load_ontology('multiwoz21')
 database = load_database('multiwoz21')
 ```
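For reference, a minimal usage sketch of the API after this patch, mirroring the `__main__` demo and the README snippet above. It assumes the `multiwoz21` archive exists under `data/unified_datasets/multiwoz21/data.zip`; it is an illustration, not part of the patch.

```python
from pprint import pprint
from convlab2.util import load_dataset, load_ontology, load_nlu_data

# dataset and ontology are now loaded by two separate calls
dataset = load_dataset('multiwoz21')      # dict: data split -> list of dialogues
ontology = load_ontology('multiwoz21')    # dict: dataset ontology

print(dataset.keys())
print(len(dataset['test']))

# task-specific view: utterances plus dialogue acts for user turns of the test split
nlu_data = load_nlu_data(dataset, data_split='test', speaker='user')
pprint(nlu_data['test'][0])
```

Because the task-specific loaders now use `kwargs.setdefault(...)`, a keyword passed explicitly (e.g. `context_window_size=3`) is no longer silently overwritten by the loader's defaults.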