From ec1dd65e151c3d06078252e150a540b526fb678f Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 27 Dec 2021 02:33:02 +0000
Subject: [PATCH] add load_rg_data

---
 convlab2/util/__init__.py              | 2 +-
 convlab2/util/unified_datasets_util.py | 7 +++++++
 setup.py                               | 1 +
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/convlab2/util/__init__.py b/convlab2/util/__init__.py
index 66c72332..8c48bc11 100755
--- a/convlab2/util/__init__.py
+++ b/convlab2/util/__init__.py
@@ -1,2 +1,2 @@
 from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
\ No newline at end of file
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
\ No newline at end of file
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index b81fd177..165b29af 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -167,6 +167,13 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si
     kwargs.setdefault('dialogue_acts', True)
     return load_unified_data(dataset, **kwargs)
 
+def load_rg_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    return load_unified_data(dataset, **kwargs)
 
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21')
diff --git a/setup.py b/setup.py
index bb24ca30..3fe3b3c9 100755
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@ setup(
         'scipy',
         'torch>=1.6',
         'transformers>=4.0',
+        'datasets>=1.8',
         'spacy',
         'allennlp',
         'simplejson',
-- 
GitLab