From 7d5388bb8b357120aaf20149c86aa8d2a5b678d7 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Wed, 29 Jun 2022 10:04:24 +0800
Subject: [PATCH] add split2ratio arg in load_dataset

---
 convlab/util/__init__.py              |  4 +---
 convlab/util/unified_datasets_util.py | 10 ++++++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/convlab/util/__init__.py b/convlab/util/__init__.py
index 6a84b7db..1688e21b 100755
--- a/convlab/util/__init__.py
+++ b/convlab/util/__init__.py
@@ -1,3 +1 @@
-from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \
-    download_unified_datasets, relative_import_module_from_unified_datasets
\ No newline at end of file
+from convlab.util.unified_datasets_util import *
\ No newline at end of file
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 1e3b0c20..e2465841 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -65,12 +65,14 @@ def relative_import_module_from_unified_datasets(dataset_name, filename, names2i
             variables.append(eval(f'module.{name}'))
         return variables
 
-def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
+def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
 
     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`
         dial_ids_order (int): idx of shuffled dial order in `data/unified_datasets/$dataset_name/shuffled_dial_ids.json`
+        split2ratio (dict): a dictionary that maps the data split to the ratio of the data you want to use. 
+            For example, if you want to use only half of the training data, you can set split2ratio = {'train': 0.5}
 
     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
@@ -86,13 +88,17 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
         dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
-            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
+            ratio = split2ratio.get(data_split, 1)
+            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
     else:
         for dialogue in dialogues:
             if dialogue['data_split'] not in dataset:
                 dataset[dialogue['data_split']] = [dialogue]
             else:
                 dataset[dialogue['data_split']].append(dialogue)
+        for data_split in dataset:
+            if data_split in split2ratio:
+                dataset[data_split] = dataset[data_split][:round(len(dataset[data_split])*split2ratio[data_split])]
     return dataset
 
 def load_ontology(dataset_name:str) -> Dict:
-- 
GitLab