from copy import deepcopy
from typing import Dict
from zipfile import ZipFile
import json
import os
import importlib.util
from abc import ABC, abstractmethod
from pprint import pprint


class BaseDatabase(ABC):
    """Base class of a unified database. Subclasses must override `query`."""

    def __init__(self):
        """Extract data.zip and load the database."""

    @abstractmethod
    def query(self, domain: str, state: dict, topk: int, **kwargs) -> list:
        """Return a list of topk entities (dicts of slot-value pairs) for a given domain based on the dialogue state."""


def load_dataset(dataset_name: str) -> Dict:
    """Load a unified dataset from `data/unified_datasets/$dataset_name`.

    Args:
        dataset_name (str): unique dataset name in `data/unified_datasets`

    Returns:
        dataset (dict): keys are data splits and the values are lists of dialogues
    """
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__),
                                            f'../../../data/unified_datasets/{dataset_name}'))
    with ZipFile(os.path.join(data_dir, 'data.zip')) as archive:
        with archive.open('data/dialogues.json') as f:
            dialogues = json.loads(f.read())
    dataset = {}
    for dialogue in dialogues:
        # Group dialogues by their data split (e.g. train/validation/test).
        dataset.setdefault(dialogue['data_split'], []).append(dialogue)
    return dataset


def load_ontology(dataset_name: str) -> Dict:
    """Load a unified ontology from `data/unified_datasets/$dataset_name`.

    Args:
        dataset_name (str): unique dataset name in `data/unified_datasets`

    Returns:
        ontology (dict): dataset ontology
    """
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__),
                                            f'../../../data/unified_datasets/{dataset_name}'))
    with ZipFile(os.path.join(data_dir, 'data.zip')) as archive:
        with archive.open('data/ontology.json') as f:
            ontology = json.loads(f.read())
    return ontology


def load_database(dataset_name: str):
    """Load a database from `data/unified_datasets/$dataset_name`.

    Args:
        dataset_name (str): unique dataset name in `data/unified_datasets`

    Returns:
        database: an instance of BaseDatabase
    """
    database_path = os.path.abspath(os.path.join(os.path.abspath(__file__),
                                                 f'../../../data/unified_datasets/{dataset_name}/database.py'))
    # Import the dataset-specific `database.py` as a module and instantiate
    # its `Database` class, which must subclass BaseDatabase.
    module_spec = importlib.util.spec_from_file_location('database', database_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    Database = module.Database
    assert issubclass(Database, BaseDatabase)
    database = Database()
    assert isinstance(database, BaseDatabase)
    return database
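
# Illustrative sketch only: each dataset's `database.py` must define a class
# named `Database` that subclasses BaseDatabase, along the lines of the class
# below. The `db.json` file name, its layout, and the exact-match query logic
# are hypothetical assumptions for illustration, not the implementation of any
# real dataset. Note that datasets may also interpret `state` differently;
# multiwoz21's database accepts a list of [slot, value] pairs, as shown in the
# __main__ demo at the bottom of this file.
class ExampleDatabase(BaseDatabase):
    """Minimal example database backed by a JSON file of per-domain entity lists."""

    def __init__(self):
        # Hypothetical layout: {'hotel': [{'name': ..., 'area': ...}, ...], ...}
        with open('db.json') as f:
            self.dbs = json.load(f)

    def query(self, domain: str, state: dict, topk: int, **kwargs) -> list:
        """Return up to `topk` entities whose slots exactly match `state`."""
        results = []
        for entity in self.dbs.get(domain, []):
            if all(entity.get(slot) == value for slot, value in state.items()):
                results.append(entity)
                if len(results) == topk:
                    break
        return results
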
def load_unified_data(
        dataset, data_split='all', speaker='all', utterance=False, dialogue_acts=False,
        state=False, db_results=False, use_context=False, context_window_size=0,
        terminated=False, goal=False, active_domains=False
):
    """Flatten a unified dataset into per-turn samples that keep only the requested fields."""
    data_splits = dataset.keys() if data_split == 'all' else [data_split]
    assert speaker in ['user', 'system', 'all']
    assert not use_context or context_window_size > 0
    # Names of the turn-level fields the caller asked for, in a fixed order.
    info_flags = {'utterance': utterance, 'dialogue_acts': dialogue_acts,
                  'state': state, 'db_results': db_results}
    info_list = [name for name, wanted in info_flags.items() if wanted]
    data_by_split = {}
    for data_split in data_splits:
        data_by_split[data_split] = []
        for dialogue in dataset[data_split]:
            context = []
            for turn in dialogue['turns']:
                sample = {'speaker': turn['speaker']}
                for ele in info_list:
                    if ele in turn:
                        sample[ele] = turn[ele]

                if use_context:
                    sample_copy = deepcopy(sample)
                    context.append(sample_copy)

                if speaker == turn['speaker'] or speaker == 'all':
                    if use_context:
                        # Keep the last `context_window_size` turns, excluding the current one.
                        sample['context'] = context[-context_window_size-1:-1]
                    if goal:
                        sample['goal'] = dialogue['goal']
                    if active_domains:
                        sample['domains'] = dialogue['domains']
                    if terminated:
                        sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
                    data_by_split[data_split].append(sample)
    return data_by_split


def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
    kwargs.setdefault('data_split', data_split)
    kwargs.setdefault('speaker', speaker)
    kwargs.setdefault('use_context', use_context)
    kwargs.setdefault('context_window_size', context_window_size)
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('dialogue_acts', True)
    return load_unified_data(dataset, **kwargs)


def load_dst_data(dataset, data_split='all', speaker='user', context_window_size=100, **kwargs):
    kwargs.setdefault('data_split', data_split)
    kwargs.setdefault('speaker', speaker)
    kwargs.setdefault('use_context', True)
    kwargs.setdefault('context_window_size', context_window_size)
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('state', True)
    return load_unified_data(dataset, **kwargs)


def load_policy_data(dataset, data_split='all', speaker='system', context_window_size=1, **kwargs):
    kwargs.setdefault('data_split', data_split)
    kwargs.setdefault('speaker', speaker)
    kwargs.setdefault('use_context', True)
    kwargs.setdefault('context_window_size', context_window_size)
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('state', True)
    kwargs.setdefault('db_results', True)
    kwargs.setdefault('dialogue_acts', True)
    return load_unified_data(dataset, **kwargs)


def load_nlg_data(dataset, data_split='all', speaker='system', use_context=False, context_window_size=0, **kwargs):
    kwargs.setdefault('data_split', data_split)
    kwargs.setdefault('speaker', speaker)
    kwargs.setdefault('use_context', use_context)
    kwargs.setdefault('context_window_size', context_window_size)
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('dialogue_acts', True)
    return load_unified_data(dataset, **kwargs)


def load_e2e_data(dataset, data_split='all', speaker='system', context_window_size=100, **kwargs):
    kwargs.setdefault('data_split', data_split)
    kwargs.setdefault('speaker', speaker)
    kwargs.setdefault('use_context', True)
    kwargs.setdefault('context_window_size', context_window_size)
    kwargs.setdefault('utterance', True)
    kwargs.setdefault('state', True)
    kwargs.setdefault('db_results', True)
    kwargs.setdefault('dialogue_acts', True)
    return load_unified_data(dataset, **kwargs)


if __name__ == "__main__":
    dataset = load_dataset('multiwoz21')
    print(dataset.keys())
    print(len(dataset['test']))

    # Re-import BaseDatabase from its package path so the issubclass check in
    # load_database matches the class that each dataset's database.py imports.
    from convlab2.util.unified_datasets_util import BaseDatabase
    database = load_database('multiwoz21')
    res = database.query("train", [['departure', 'cambridge'], ['destination', 'peterborough'],
                                   ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
    print(res[0], len(res))

    data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
    pprint(data_by_split['test'][0])
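    # A further usage sketch: load_dst_data keeps a window of previous turns
    # (default 100) alongside the current dialogue state. The fields printed
    # below assume the multiwoz21 unified format, where user turns carry a
    # 'state' dict keyed by domain.
    dst_data = load_dst_data(dataset, data_split='test', speaker='user')
    dst_sample = dst_data['test'][0]
    print(len(dst_sample['context']), sorted(dst_sample['state'].keys()))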