From ec1dd65e151c3d06078252e150a540b526fb678f Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 27 Dec 2021 02:33:02 +0000 Subject: [PATCH] add load_rg_data --- convlab2/util/__init__.py | 2 +- convlab2/util/unified_datasets_util.py | 7 +++++++ setup.py | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/convlab2/util/__init__.py b/convlab2/util/__init__.py index 66c72332..8c48bc11 100755 --- a/convlab2/util/__init__.py +++ b/convlab2/util/__init__.py @@ -1,2 +1,2 @@ from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \ - load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data \ No newline at end of file + load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data \ No newline at end of file diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index b81fd177..165b29af 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -167,6 +167,13 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si kwargs.setdefault('dialogue_acts', True) return load_unified_data(dataset, **kwargs) +def load_rg_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs): + kwargs.setdefault('data_split', data_split) + kwargs.setdefault('speaker', speaker) + kwargs.setdefault('use_context', True) + kwargs.setdefault('context_window_size', context_window_size) + kwargs.setdefault('utterance', True) + return load_unified_data(dataset, **kwargs) if __name__ == "__main__": dataset = load_dataset('multiwoz21') diff --git a/setup.py b/setup.py index bb24ca30..3fe3b3c9 100755 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ setup( 'scipy', 'torch>=1.6', 'transformers>=4.0', + 'datasets>=1.8', 'spacy', 'allennlp', 'simplejson', -- GitLab