Commit 30b584aa authored by zqwerty

modify load_dataset interface

parent 4f4290da
@@ -6,8 +6,6 @@ from convlab2.policy import Policy
 from convlab2.nlg import NLG
 from convlab2.dialog_agent import Agent, PipelineAgent
 from convlab2.dialog_agent import Session, BiSession, DealornotSession
-from convlab2.util.unified_datasets_util import load_dataset, load_database, load_unified_data, \
-    load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
 from os.path import abspath, dirname
......
+from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
\ No newline at end of file
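
The net effect of these two hunks is that the unified-dataset loaders now live under `convlab2.util` rather than the top-level `convlab2` package. The updated import, mirroring the README change further down:

```python
from convlab2.util import load_dataset, load_ontology, load_database
```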
@@ -18,20 +18,17 @@ class BaseDatabase(ABC):
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""

-def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
-    """load unified datasets from `data/unified_datasets/$dataset_name`
+def load_dataset(dataset_name:str) -> Dict:
+    """load unified dataset from `data/unified_datasets/$dataset_name`

     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`

     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
-        ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
-    with archive.open('data/ontology.json') as f:
-        ontology = json.loads(f.read())
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
@@ -40,7 +37,22 @@ def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
             dataset[dialogue['data_split']] = [dialogue]
         else:
             dataset[dialogue['data_split']].append(dialogue)
-    return dataset, ontology
+    return dataset
+
+def load_ontology(dataset_name:str) -> Dict:
+    """load unified ontology from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        ontology (dict): dataset ontology
+    """
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    with archive.open('data/ontology.json') as f:
+        ontology = json.loads(f.read())
+    return ontology

 def load_database(dataset_name:str):
     """load database from `data/unified_datasets/$dataset_name`
@@ -107,62 +119,57 @@ def load_unified_data(
     return data_by_split

 def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    return load_unified_data(dataset, **kwargs)

 def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 if __name__ == "__main__":
-    dataset, ontology = load_dataset('multiwoz21')
+    dataset = load_dataset('multiwoz21')
     print(dataset.keys())
     print(len(dataset['test']))
@@ -171,5 +178,5 @@ if __name__ == "__main__":
     res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res[0], len(res))
-    data_by_split = load_e2e_data(dataset, data_split='test')
-    pprint(data_by_split['test'][3])
+    data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
+    pprint(data_by_split['test'][0])
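
The switch from `kwargs['key'] = value` to `kwargs.setdefault(key, value)` is behavioral, not cosmetic: the old assignments silently overwrote any flag a caller passed through `**kwargs`, whereas `setdefault` only fills in keys the caller left unset. A sketch of the difference, again assuming local `multiwoz21` data:

```python
from convlab2.util import load_dataset, load_dst_data

dataset = load_dataset('multiwoz21')

# 'use_context' is not a named parameter of load_dst_data, so it travels
# through **kwargs. Before this commit, kwargs['use_context'] = True
# clobbered it; with setdefault, the caller's False wins.
data_by_split = load_dst_data(dataset, data_split='test', use_context=False)
```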
@@ -4,9 +4,10 @@
 We transform different datasets into a unified format under the `data/unified_datasets` directory. To load a unified dataset:

 ```python
-from convlab2 import load_dataset, load_database
+from convlab2.util import load_dataset, load_ontology, load_database

-dataset, ontology = load_dataset('multiwoz21')
+dataset = load_dataset('multiwoz21')
+ontology = load_ontology('multiwoz21')
 database = load_database('multiwoz21')
 ```
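
Building on the README snippet, a short sketch of consuming the returned dict (the split names depend on the dataset; `'test'` is the one exercised in the module's `__main__` block):

```python
for split, dialogues in dataset.items():
    # each entry is a dialogue dict carrying fields such as 'data_split'
    print(split, len(dialogues))
```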