diff --git a/convlab2/__init__.py b/convlab2/__init__.py
index 0fe7d5bf01a7ac06fff476149427712f8c7c7ee4..a7b4c3c2d972cbb2dab7b091a099e8591b50895c 100755
--- a/convlab2/__init__.py
+++ b/convlab2/__init__.py
@@ -6,7 +6,8 @@ from convlab2.policy import Policy
 from convlab2.nlg import NLG
 from convlab2.dialog_agent import Agent, PipelineAgent
 from convlab2.dialog_agent import Session, BiSession, DealornotSession
-from convlab2.util.unified_datasets_util import load_dataset, load_database
+from convlab2.util.unified_datasets_util import load_dataset, load_database, load_unified_data, \
+    load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data
 
 from os.path import abspath, dirname
 
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index 921c6c451042c754450cf93afd7e2a06055026b2..d6d637aa1c60257e861abea37d3f06e52c2f3c71 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -1,28 +1,175 @@
+from copy import deepcopy
+from typing import Dict, List, Tuple
 from zipfile import ZipFile
 import json
 import os
 import importlib
+from abc import ABC, abstractmethod
+from pprint import pprint
 
-def load_dataset(dataset_name):
+
+class BaseDatabase(ABC):
+    """Base class of unified database. Should override the query function."""
+    def __init__(self):
+        """extract data.zip and load the database."""
+
+    @abstractmethod
+    def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
+        """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
+
+
+def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
+    """load unified dataset from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        dataset (dict): keys are data splits and the values are lists of dialogues
+        ontology (dict): dataset ontology
+    """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
-    with archive.open('data/dialogues.json') as f:
-        dialogues = json.loads(f.read())
     with archive.open('data/ontology.json') as f:
         ontology = json.loads(f.read())
-    return dialogues, ontology
+    with archive.open('data/dialogues.json') as f:
+        dialogues = json.loads(f.read())
+    dataset = {}
+    for dialogue in dialogues:
+        if dialogue['data_split'] not in dataset:
+            dataset[dialogue['data_split']] = [dialogue]
+        else:
+            dataset[dialogue['data_split']].append(dialogue)
+    return dataset, ontology
 
-def load_database(dataset_name):
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    cwd = os.getcwd()
-    os.chdir(data_dir)
-    Database = importlib.import_module('database').Database
-    os.chdir(cwd)
+def load_database(dataset_name:str):
+    """load database from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        database: an instance of BaseDatabase
+    """
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
+    module_spec = importlib.util.spec_from_file_location('database', data_dir)
+    module = importlib.util.module_from_spec(module_spec)
+    module_spec.loader.exec_module(module)
+    Database = module.Database
+    assert issubclass(Database, BaseDatabase)
     database = Database()
+    assert isinstance(database, BaseDatabase)
     return database
 
+def load_unified_data(
+        dataset,
+        data_split='all',
+        speaker='all',
+        utterance=False,
+        dialogue_acts=False,
+        state=False,
+        db_results=False,
+        use_context=False,
+        context_window_size=0,
+        terminated=False,
+        goal=False,
+        active_domains=False
+):
+    data_splits = dataset.keys() if data_split == 'all' else [data_split]
+    assert speaker in ['user', 'system', 'all']
+    assert not use_context or context_window_size > 0
+    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
+    data_by_split = {}
+    for data_split in data_splits:
+        data_by_split[data_split] = []
+        for dialogue in dataset[data_split]:
+            context = []
+            for turn in dialogue['turns']:
+                sample = {'speaker': turn['speaker']}
+                for ele in info_list:
+                    if ele in turn:
+                        sample[ele] = turn[ele]
+
+                if use_context:
+                    sample_copy = deepcopy(sample)
+                    context.append(sample_copy)
+
+                if speaker == turn['speaker'] or speaker == 'all':
+                    if use_context:
+                        sample['context'] = context[-context_window_size-1:-1]
+                    if goal:
+                        sample['goal'] = dialogue['goal']
+                    if active_domains:
+                        sample['domains'] = dialogue['domains']
+                    if terminated:
+                        sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
+                    data_by_split[data_split].append(sample)
+    return data_by_split
+
+def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
+    kwargs['data_split'] = data_split
+    kwargs['speaker'] = speaker
+    kwargs['use_context'] = use_context
+    kwargs['context_window_size'] = context_window_size
+    kwargs['utterance'] = True
+    kwargs['dialogue_acts'] = True
+    data_by_split = load_unified_data(dataset, **kwargs)
+    return data_by_split
+
+def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
+    kwargs['data_split'] = data_split
+    kwargs['speaker'] = speaker
+    kwargs['use_context'] = True
+    kwargs['context_window_size'] = context_window_size
+    kwargs['utterance'] = True
+    kwargs['state'] = True
+    data_by_split = load_unified_data(dataset, **kwargs)
+    return data_by_split
+
+def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
+    kwargs['data_split'] = data_split
+    kwargs['speaker'] = speaker
+    kwargs['use_context'] = True
+    kwargs['context_window_size'] = context_window_size
+    kwargs['utterance'] = True
+    kwargs['state'] = True
+    kwargs['db_results'] = True
+    kwargs['dialogue_acts'] = True
+    data_by_split = load_unified_data(dataset, **kwargs)
+    return data_by_split
+
+def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
+    kwargs['data_split'] = data_split
+    kwargs['speaker'] = speaker
+    kwargs['use_context'] = use_context
+    kwargs['context_window_size'] = context_window_size
+    kwargs['utterance'] = True
+    kwargs['dialogue_acts'] = True
+    data_by_split = load_unified_data(dataset, **kwargs)
+    return data_by_split
+
+def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
+    kwargs['data_split'] = data_split
+    kwargs['speaker'] = speaker
+    kwargs['use_context'] = True
+    kwargs['context_window_size'] = context_window_size
+    kwargs['utterance'] = True
+    kwargs['state'] = True
+    kwargs['db_results'] = True
+    kwargs['dialogue_acts'] = True
+    data_by_split = load_unified_data(dataset, **kwargs)
+    return data_by_split
+
+
 if __name__ == "__main__":
-    dialogues, ontology = load_dataset('multiwoz21')
+    dataset, ontology = load_dataset('multiwoz21')
+    print(dataset.keys())
+    print(len(dataset['test']))
+
+    from convlab2.util.unified_datasets_util import BaseDatabase
     database = load_database('multiwoz21')
     res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res[0], len(res))
+
+    data_by_split = load_e2e_data(dataset, data_split='test')
+    pprint(data_by_split['test'][3])
diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md
index 615d1b52f68168351cb7541339ebdf9670cb28fb..7400504a8fc797164f7f495bbedc6bfe4399997d 100644
--- a/data/unified_datasets/README.md
+++ b/data/unified_datasets/README.md
@@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase
 
 ```python
 from convlab2 import load_dataset, load_database
 
-dialogues, ontology = load_dataset('multiwoz21')
+dataset, ontology = load_dataset('multiwoz21')
 database = load_database('multiwoz21')
 ```
 
+`dataset` is a dict where the keys are data splits and the values are lists of dialogues. `database` is an instance of the `Database` class that has a `query` function. The formats of dialogues, ontology, and `Database` are defined below.
+
 Each dataset contains at least these files:
 
 - `README.md`: dataset description and the **main changes** from original data to processed data. Should include the instruction on how to get the original data and transform them into the unified format.
@@ -31,7 +33,9 @@ if __name__ == '__main__':
 Datasets that require database interaction should also include the following file:
 - `database.py`: load the database and define the query function:
 ```python
-class Database:
+from convlab2.util.unified_datasets_util import BaseDatabase
+
+class Database(BaseDatabase):
     def __init__(self):
         """extract data.zip and load the database."""
 
diff --git a/data/unified_datasets/check.py b/data/unified_datasets/check.py
index 47e75e602b5009e0a6b9bbf53b7a1291491522a6..4809c857041a9144012880627b7f8745dd70c075 100644
--- a/data/unified_datasets/check.py
+++ b/data/unified_datasets/check.py
@@ -315,10 +315,8 @@ if __name__ == '__main__':
 
     if args.preprocess:
         print('pre-processing')
-        os.chdir(name)
         preprocess = importlib.import_module(f'{name}.preprocess')
         preprocess.preprocess()
-        os.chdir('..')
 
     data_file = f'{name}/data.zip'
     if not os.path.exists(data_file):
diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py
index 0dbf50c85d4808d96503cb1861bd84a8054f0965..43ea5896285ebf9e8e38f99c89823bf6538a2bad 100644
--- a/data/unified_datasets/multiwoz21/database.py
+++ b/data/unified_datasets/multiwoz21/database.py
@@ -5,9 +5,10 @@ from fuzzywuzzy import fuzz
 from itertools import chain
 from zipfile import ZipFile
 from copy import deepcopy
+from convlab2.util.unified_datasets_util import BaseDatabase
 
 
-class Database:
+class Database(BaseDatabase):
     def __init__(self):
         """extract data.zip and load the database."""
         archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip'))
@@ -34,7 +35,7 @@ class Database:
             'leaveAt': 'leave at'
         }
 
-    def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
+    def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
         # query the db
         if domain == 'taxi':
@@ -102,6 +103,8 @@ class Database:
 
 if __name__ == '__main__':
     db = Database()
+    assert issubclass(Database, BaseDatabase)
+    assert isinstance(db, BaseDatabase)
     res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res, len(res))
     # print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']]))
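
Taken together, the new helpers stack as `load_dataset` → task-specific wrapper → `load_unified_data`. Below is a minimal usage sketch (not part of the diff), assuming `multiwoz21`'s `data.zip` has been built locally; the exact split names depend on each dialogue's `data_split` field.

```python
from convlab2 import load_dataset, load_nlu_data

# splits are grouped by each dialogue's `data_split` field,
# e.g. dict_keys(['train', 'validation', 'test'])
dataset, ontology = load_dataset('multiwoz21')
print(dataset.keys())

# user turns with `utterance` and `dialogue_acts`, plus up to 3 turns of context
data_by_split = load_nlu_data(dataset, data_split='train',
                              use_context=True, context_window_size=3)
sample = data_by_split['train'][0]
print(sample['speaker'], sample['utterance'])
print(len(sample['context']), 'context turns')  # 0 for the first turn
```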
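The `BaseDatabase` ABC pins down the contract that `load_database` now enforces with its `issubclass`/`isinstance` asserts. A hypothetical `database.py` for a new dataset could look like the sketch below; the in-memory `dbs` table and the exact-match logic are illustrative assumptions, not part of this diff. Because `load_database` now imports `database.py` by absolute path via `importlib.util.spec_from_file_location`, no working-directory switching is required, which is what makes the `os.chdir` removal in `check.py` safe.

```python
from convlab2.util.unified_datasets_util import BaseDatabase

class Database(BaseDatabase):
    def __init__(self):
        """extract data.zip and load the database."""
        # illustrative stand-in: a real dataset would unpack its own data.zip here
        self.dbs = {'restaurant': [{'name': 'clowns cafe', 'area': 'centre'}]}

    def query(self, domain: str, state: dict, topk: int, **kwargs) -> list:
        """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
        # `state` may be a dict or [slot, value] pairs (as in the multiwoz21
        # example above); dict() normalizes both forms
        constraints = dict(state)
        matched = [entity for entity in self.dbs.get(domain, [])
                   if all(entity.get(slot) == value
                          for slot, value in constraints.items())]
        return matched[:topk]

if __name__ == '__main__':
    db = Database()
    print(db.query('restaurant', [['area', 'centre']], topk=1))
```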