Skip to content
Snippets Groups Projects
Commit ec1dd65e authored by zqwerty's avatar zqwerty
Browse files

add load_rg_data

parent ba27f92d
No related branches found
No related tags found
No related merge requests found
from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \ from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
\ No newline at end of file \ No newline at end of file
...@@ -167,6 +167,13 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si ...@@ -167,6 +167,13 @@ def load_e2e_data(dataset, data_split='all', speaker='system', context_window_si
kwargs.setdefault('dialogue_acts', True) kwargs.setdefault('dialogue_acts', True)
return load_unified_data(dataset, **kwargs) return load_unified_data(dataset, **kwargs)
def load_rg_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
kwargs.setdefault('data_split', data_split)
kwargs.setdefault('speaker', speaker)
kwargs.setdefault('use_context', True)
kwargs.setdefault('context_window_size', context_window_size)
kwargs.setdefault('utterance', True)
return load_unified_data(dataset, **kwargs)
if __name__ == "__main__": if __name__ == "__main__":
dataset = load_dataset('multiwoz21') dataset = load_dataset('multiwoz21')
......
...@@ -41,6 +41,7 @@ setup( ...@@ -41,6 +41,7 @@ setup(
'scipy', 'scipy',
'torch>=1.6', 'torch>=1.6',
'transformers>=4.0', 'transformers>=4.0',
'datasets>=1.8',
'spacy', 'spacy',
'allennlp', 'allennlp',
'simplejson', 'simplejson',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment