diff --git a/convlab/util/__init__.py b/convlab/util/__init__.py
index 6a84b7db276389d9bbcd6ba097a0b7bb00440a48..1688e21b80c08562a3ca1e5ca45fd181b57cbc98 100755
--- a/convlab/util/__init__.py
+++ b/convlab/util/__init__.py
@@ -1,3 +1 @@
-from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \
-    download_unified_datasets, relative_import_module_from_unified_datasets
\ No newline at end of file
+from convlab.util.unified_datasets_util import *
\ No newline at end of file
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 1e3b0c20bd959ea3098b07b813ed98189aac840f..e24658410738b290da97149382c8c89030936679 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -65,12 +65,14 @@ def relative_import_module_from_unified_datasets(dataset_name, filename, names2i
             variables.append(eval(f'module.{name}'))
         return variables
 
-def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
+def load_dataset(dataset_name:str, dial_ids_order=None, split2ratio={}) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
 
     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`
         dial_ids_order (int): idx of shuffled dial order in `data/unified_datasets/$dataset_name/shuffled_dial_ids.json`
+        split2ratio (dict): maps a data split name to the fraction of that split to load; splits that are not listed are loaded in full.
+            For example, split2ratio = {'train': 0.5} keeps only the first round(0.5 * N) of the N training dialogues.
 
     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
@@ -86,13 +88,17 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
         dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
-            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
+            ratio = split2ratio.get(data_split, 1)
+            dataset[data_split] = [dialogues[i] for i in dial_ids[data_split][:round(len(dial_ids[data_split])*ratio)]]
     else:
         for dialogue in dialogues:
             if dialogue['data_split'] not in dataset:
                 dataset[dialogue['data_split']] = [dialogue]
             else:
                 dataset[dialogue['data_split']].append(dialogue)
+        for data_split in dataset:
+            if data_split in split2ratio:
+                dataset[data_split] = dataset[data_split][:round(len(dataset[data_split])*split2ratio[data_split])]
     return dataset
 
 def load_ontology(dataset_name:str) -> Dict: