Commit 30b584aa authored by zqwerty

modify load_dataset interface

parent 4f4290da
@@ -6,8 +6,6 @@ from convlab2.policy import Policy
 from convlab2.nlg import NLG
 from convlab2.dialog_agent import Agent, PipelineAgent
 from convlab2.dialog_agent import Session, BiSession, DealornotSession
-from convlab2.util.unified_datasets_util import load_dataset, load_database, load_unified_data, \
-    load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
 from os.path import abspath, dirname
...
+from convlab2.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
\ No newline at end of file
@@ -18,20 +18,17 @@ class BaseDatabase(ABC):
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""

-def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
-    """load unified datasets from `data/unified_datasets/$dataset_name`
+def load_dataset(dataset_name:str) -> Dict:
+    """load unified dataset from `data/unified_datasets/$dataset_name`

     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`

     Returns:
         dataset (dict): keys are data splits and the values are lists of dialogues
-        ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
-    with archive.open('data/ontology.json') as f:
-        ontology = json.loads(f.read())
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
@@ -40,7 +37,22 @@ def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
             dataset[dialogue['data_split']] = [dialogue]
         else:
             dataset[dialogue['data_split']].append(dialogue)
-    return dataset, ontology
+    return dataset
+
+def load_ontology(dataset_name:str) -> Dict:
+    """load unified ontology from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        ontology (dict): dataset ontology
+    """
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    with archive.open('data/ontology.json') as f:
+        ontology = json.loads(f.read())
+    return ontology

 def load_database(dataset_name:str):
     """load database from `data/unified_datasets/$dataset_name`
@@ -107,62 +119,57 @@ def load_unified_data(
     return data_by_split

 def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    return load_unified_data(dataset, **kwargs)

 def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = use_context
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', use_context)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
-    kwargs['data_split'] = data_split
-    kwargs['speaker'] = speaker
-    kwargs['use_context'] = True
-    kwargs['context_window_size'] = context_window_size
-    kwargs['utterance'] = True
-    kwargs['state'] = True
-    kwargs['db_results'] = True
-    kwargs['dialogue_acts'] = True
-    data_by_split = load_unified_data(dataset, **kwargs)
-    return data_by_split
+    kwargs.setdefault('data_split', data_split)
+    kwargs.setdefault('speaker', speaker)
+    kwargs.setdefault('use_context', True)
+    kwargs.setdefault('context_window_size', context_window_size)
+    kwargs.setdefault('utterance', True)
+    kwargs.setdefault('state', True)
+    kwargs.setdefault('db_results', True)
+    kwargs.setdefault('dialogue_acts', True)
+    return load_unified_data(dataset, **kwargs)

 if __name__ == "__main__":
-    dataset, ontology = load_dataset('multiwoz21')
+    dataset = load_dataset('multiwoz21')
     print(dataset.keys())
     print(len(dataset['test']))
@@ -171,5 +178,5 @@ if __name__ == "__main__":
     res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res[0], len(res))
-    data_by_split = load_e2e_data(dataset, data_split='test')
-    pprint(data_by_split['test'][3])
+    data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
+    pprint(data_by_split['test'][0])
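The switch from direct assignment to `dict.setdefault` in the loaders is a behavioral fix, not just style: flags such as `utterance` and `dialogue_acts` are collected by `**kwargs`, so the old `kwargs['utterance'] = True` silently overwrote whatever the caller passed. A standalone sketch of the difference, with a hypothetical `load_task_data` standing in for the loaders above:

```python
def load_task_data(dataset, **kwargs):
    # Old behavior: kwargs['utterance'] = True clobbered any caller-supplied value.
    # New behavior: the hard-coded default applies only when the caller is silent.
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('dialogue_acts', True)
    return kwargs  # the real loaders forward these flags to load_unified_data

print(load_task_data(None, dialogue_acts=False))
# {'dialogue_acts': False, 'utterance': True}
```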
@@ -4,9 +4,10 @@
 We transform different datasets into a unified format under the `data/unified_datasets` directory. To import a unified dataset:
 ```python
-from convlab2 import load_dataset, load_database
-dataset, ontology = load_dataset('multiwoz21')
+from convlab2.util import load_dataset, load_ontology, load_database
+dataset = load_dataset('multiwoz21')
+ontology = load_ontology('multiwoz21')
 database = load_database('multiwoz21')
 ```
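Loading task-specific views follows the same pattern as the updated `__main__` block above; for example, NLU samples (utterance plus dialogue acts) for the test split, assuming `load_nlu_data` is re-exported from `convlab2.util` as in the new `__init__`:

```python
from pprint import pprint
from convlab2.util import load_dataset, load_nlu_data

dataset = load_dataset('multiwoz21')
# One sample per user turn, keyed by data split.
data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
pprint(data_by_split['test'][0])
```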