diff --git a/convlab/util/__init__.py b/convlab/util/__init__.py
index 6a84b7db276389d9bbcd6ba097a0b7bb00440a48..1688e21b80c08562a3ca1e5ca45fd181b57cbc98 100755
--- a/convlab/util/__init__.py
+++ b/convlab/util/__init__.py
@@ -1,3 +1 @@
-from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \
-    download_unified_datasets, relative_import_module_from_unified_datasets
\ No newline at end of file
+from convlab.util.unified_datasets_util import *
\ No newline at end of file
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 1e3b0c20bd959ea3098b07b813ed98189aac840f..e24658410738b290da97149382c8c89030936679 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -65,12 +65,14 @@ def relative_import_module_from_unified_datasets(dataset_name, filename, names2i
         variables.append(eval(f'module.{name}'))
     return variables
 
-def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
+def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
 
     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`
         dial_ids_order (int): idx of shuffled dial order in `data/unified_datasets/$dataset_name/shuffled_dial_ids.json`
+        split2ratio (dict): a dictionary that maps the data split to the ratio of the data you want to use.
+            For example, if you want to use only half of the training data, you can set split2ratio = {'train': 0.5}
 
     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
@@ -86,13 +88,17 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
         dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
-            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
+            ratio = split2ratio.get(data_split, 1)
+            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
     else:
         for dialogue in dialogues:
             if dialogue['data_split'] not in dataset:
                 dataset[dialogue['data_split']] = [dialogue]
             else:
                 dataset[dialogue['data_split']].append(dialogue)
+    for data_split in dataset:
+        if data_split in split2ratio:
+            dataset[data_split] = dataset[data_split][:round(len(dataset[data_split])*split2ratio[data_split])]
     return dataset
 
 def load_ontology(dataset_name:str) -> Dict:
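
A minimal usage sketch of the new `split2ratio` parameter, assuming ConvLab-3 is installed and a unified dataset such as `multiwoz21` (example name) is available under `data/unified_datasets`; it is illustrative, not part of the diff itself:

```python
# Sketch: loading a unified dataset with a per-split ratio (assumes the
# `multiwoz21` unified dataset can be downloaded/loaded by load_dataset).
from convlab.util import load_dataset

# Full dataset: keys are data splits, values are lists of dialogues.
dataset = load_dataset('multiwoz21')

# Keep only the first half of the training dialogues; splits not listed
# in split2ratio (validation/test) are left untouched.
half_train = load_dataset('multiwoz21', split2ratio={'train': 0.5})

print(len(dataset['train']), len(half_train['train']))
```

Note that the ratio is applied by truncating the (possibly shuffled, via `dial_ids_order`) list of dialogue ids, so the subset is deterministic for a given `dial_ids_order`.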