diff --git a/convlab2/util/__init__.py b/convlab2/util/__init__.py
index 66c7233217086bea51d8fe8b5952fe70dca0b4bc..8c48bc11d3892be2c8a0511a6b833e9e9c8b24ff 100755
--- a/convlab2/util/__init__.py
+++ b/convlab2/util/__init__.py
@@ -1,2 +1,2 @@
 from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
\ No newline at end of file
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
\ No newline at end of file
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index b81fd17792ad212bc3ff8832106e5b3707c7de0c..165b29af217bf28f4e3d993859ccbec3989a518b 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -167,6 +167,21 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si
     kwargs.setdefault('dialogue_acts', True)
     return load_unified_data(dataset, **kwargs)
 
+def load_rg_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
+    """Load data via ``load_unified_data`` with utterances and dialogue context enabled.
+
+    ``data_split``, ``speaker`` and ``context_window_size`` are forwarded as given.
+    ``use_context`` and ``utterance`` default to True but can be overridden through
+    ``kwargs``. "rg" presumably stands for response generation (not stated in the code).
+
+    Returns whatever ``load_unified_data(dataset, **kwargs)`` returns.
+    """
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    return load_unified_data(dataset, **kwargs)
 
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21')
diff --git a/setup.py b/setup.py
index bb24ca3037a6016759e470fff5764eff9e3e219d..3fe3b3c914f71c0db87bf5217930bfd3b6831c45 100755
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@ setup(
         'scipy',
         'torch>=1.6',
         'transformers>=4.0',
+        'datasets>=1.8',
         'spacy',
         'allennlp',
         'simplejson',