diff --git a/convlab/base_models/t5/key2gen/create_data.py b/convlab/base_models/t5/key2gen/create_data.py index cb4e12c8e720f3031808ec0972c38926a617ef68..138808f2a13d4b3fad71b57a2aa7977917f8143c 100644 --- a/convlab/base_models/t5/key2gen/create_data.py +++ b/convlab/base_models/t5/key2gen/create_data.py @@ -83,6 +83,9 @@ def create_personachat_data(dataset, data_dir, args): def create_wow_data(dataset, data_dir, args): data_by_split = dataset os.makedirs(data_dir, exist_ok=True) + data_by_split['test'] = data_by_split['test_seen'] + data_by_split['test_unseen'] + data_by_split.pop('test_seen') + data_by_split.pop('test_unseen') data_splits = data_by_split.keys() for data_split in data_splits: