Skip to content
Snippets Groups Projects
Commit 7d5388bb authored by zqwerty
Browse files

Add `split2ratio` argument to `load_dataset`

parent 9f54db83
No related branches found
No related tags found
No related merge requests found
from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \
download_unified_datasets, relative_import_module_from_unified_datasets
\ No newline at end of file
from convlab.util.unified_datasets_util import *
\ No newline at end of file
......@@ -65,12 +65,14 @@ def relative_import_module_from_unified_datasets(dataset_name, filename, names2i
variables.append(eval(f'module.{name}'))
return variables
def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
"""load unified dataset from `data/unified_datasets/$dataset_name`
Args:
dataset_name (str): unique dataset name in `data/unified_datasets`
dial_ids_order (int): idx of shuffled dial order in `data/unified_datasets/$dataset_name/shuffled_dial_ids.json`
split2ratio (dict): a dictionary that maps the data split to the ratio of the data you want to use.
For example, if you want to use only half of the training data, you can set split2ratio = {'train': 0.5}
Returns:
dataset (dict): keys are data splits and the values are lists of dialogues
......@@ -86,13 +88,17 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
dial_ids = json.load(open(data_path))[dial_ids_order]
for data_split in dial_ids:
dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
ratio = split2ratio.get(data_split, 1)
dataset[data_split] = [dialogues[i] for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
else:
for dialogue in dialogues:
if dialogue['data_split'] not in dataset:
dataset[dialogue['data_split']] = [dialogue]
else:
dataset[dialogue['data_split']].append(dialogue)
for data_split in dataset:
if data_split in split2ratio:
dataset[data_split] = dataset[data_split][:round(len(dataset[data_split])*split2ratio[data_split])]
return dataset
def load_ontology(dataset_name:str) -> Dict:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment