Skip to content
Snippets Groups Projects
Commit 5c849330 authored by zqwerty's avatar zqwerty
Browse files

add function to load data for each module, #14

parent eeb903df
No related branches found
No related tags found
No related merge requests found
from copy import deepcopy
from typing import Dict, List, Tuple
from zipfile import ZipFile
import json
import os
import importlib
from abc import ABC, abstractmethod
from pprint import pprint
class BaseDatabase(ABC):
......@@ -59,11 +61,115 @@ def load_database(dataset_name:str):
assert isinstance(database, BaseDatabase)
return database
def load_unified_data(
        dataset,
        data_split='all',
        speaker='all',
        utterance=False,
        dialogue_acts=False,
        state=False,
        db_results=False,
        use_context=False,
        context_window_size=0,
        terminated=False,
        goal=False,
        active_domains=False
):
    """Flatten a unified-format dataset into per-turn samples.

    Args:
        dataset: dict mapping split name -> list of dialogues; each dialogue
            has a 'turns' list (and 'goal'/'domains' keys when requested).
        data_split: a split name, or 'all' for every split in `dataset`.
        speaker: keep only turns by this speaker ('user', 'system', or 'all').
        utterance/dialogue_acts/state/db_results: include that turn field in
            each sample when present in the turn.
        use_context: attach the preceding turns as `sample['context']`;
            requires context_window_size > 0.
        context_window_size: number of preceding turns kept in the context.
        terminated: include a bool marking the dialogue's final turn.
        goal: include the dialogue-level 'goal'.
        active_domains: include the dialogue-level 'domains' (as 'domains').

    Returns:
        dict mapping split name -> list of sample dicts.
    """
    data_splits = dataset.keys() if data_split == 'all' else [data_split]
    assert speaker in ['user', 'system', 'all']
    assert not use_context or context_window_size > 0
    # Explicit flag->field mapping. The original used `filter(eval, [...])`
    # to look the booleans up by name, which is fragile and an unsafe idiom;
    # dict insertion order preserves the original field order exactly.
    field_flags = {
        'utterance': utterance,
        'dialogue_acts': dialogue_acts,
        'state': state,
        'db_results': db_results,
    }
    info_list = [field for field, wanted in field_flags.items() if wanted]
    data_by_split = {}
    for data_split in data_splits:
        data_by_split[data_split] = []
        for dialogue in dataset[data_split]:
            context = []
            for turn in dialogue['turns']:
                sample = {'speaker': turn['speaker']}
                for ele in info_list:
                    if ele in turn:
                        sample[ele] = turn[ele]
                if use_context:
                    # Snapshot the sample before dialogue-level keys are
                    # added below, so the context holds only turn fields.
                    context.append(deepcopy(sample))
                if speaker == turn['speaker'] or speaker == 'all':
                    if use_context:
                        # The previous `context_window_size` turns,
                        # excluding the current turn itself.
                        sample['context'] = context[-context_window_size-1:-1]
                    if goal:
                        sample['goal'] = dialogue['goal']
                    if active_domains:
                        sample['domains'] = dialogue['domains']
                    if terminated:
                        sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
                    data_by_split[data_split].append(sample)
    return data_by_split
def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
    """Load NLU samples: utterances paired with their dialogue acts.

    Any extra keyword arguments are forwarded to `load_unified_data`;
    the NLU-specific fields below override same-named entries.
    """
    kwargs.update(
        data_split=data_split,
        speaker=speaker,
        use_context=use_context,
        context_window_size=context_window_size,
        utterance=True,
        dialogue_acts=True,
    )
    return load_unified_data(dataset, **kwargs)
def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
    """Load DST samples: utterances with dialogue state, context always on.

    Any extra keyword arguments are forwarded to `load_unified_data`;
    the DST-specific fields below override same-named entries.
    """
    kwargs.update(
        data_split=data_split,
        speaker=speaker,
        use_context=True,
        context_window_size=context_window_size,
        utterance=True,
        state=True,
    )
    return load_unified_data(dataset, **kwargs)
def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
    """Load policy samples: state, DB results, and system dialogue acts.

    Any extra keyword arguments are forwarded to `load_unified_data`;
    the policy-specific fields below override same-named entries.
    """
    kwargs.update(
        data_split=data_split,
        speaker=speaker,
        use_context=True,
        context_window_size=context_window_size,
        utterance=True,
        state=True,
        db_results=True,
        dialogue_acts=True,
    )
    return load_unified_data(dataset, **kwargs)
def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
    """Load NLG samples: dialogue acts paired with system utterances.

    Any extra keyword arguments are forwarded to `load_unified_data`;
    the NLG-specific fields below override same-named entries.
    """
    kwargs.update(
        data_split=data_split,
        speaker=speaker,
        use_context=use_context,
        context_window_size=context_window_size,
        utterance=True,
        dialogue_acts=True,
    )
    return load_unified_data(dataset, **kwargs)
def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
    """Load end-to-end samples: state, DB results, acts, and utterances.

    Any extra keyword arguments are forwarded to `load_unified_data`;
    the E2E-specific fields below override same-named entries.
    """
    kwargs.update(
        data_split=data_split,
        speaker=speaker,
        use_context=True,
        context_window_size=context_window_size,
        utterance=True,
        state=True,
        db_results=True,
        dialogue_acts=True,
    )
    return load_unified_data(dataset, **kwargs)
if __name__ == "__main__":
    # Smoke test for the loading utilities on MultiWOZ 2.1.
    # `load_dataset` is defined elsewhere in this module (outside this view).
    # dataset, ontology = load_dataset('multiwoz21')
    # print(dataset.keys())
    # print(len(dataset['train']))
    dataset, ontology = load_dataset('multiwoz21')
    print(dataset.keys())
    print(len(dataset['test']))
    from convlab2.util.unified_datasets_util import BaseDatabase
    # Query the MultiWOZ database; NOTE(review): the query uses a list of
    # [slot, value] pairs and returns ranked matches — confirm against the
    # dataset's Database implementation.
    database = load_database('multiwoz21')
    res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
    print(res[0], len(res))
    # Inspect one end-to-end sample from the test split.
    data_by_split = load_e2e_data(dataset, data_split='test')
    pprint(data_by_split['test'][3])
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment