From 4d4d0ac8f0ba734c30b4191f2af8f2e2d4cc2068 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 20 Dec 2021 07:47:06 +0000 Subject: [PATCH] add BaseDatabase class for unified dataset --- convlab2/util/unified_datasets_util.py | 35 ++++++++++++++++++-- data/unified_datasets/README.md | 4 ++- data/unified_datasets/multiwoz21/database.py | 5 ++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index 921c6c45..341d21e0 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -1,9 +1,31 @@ +from typing import Dict, List, Tuple from zipfile import ZipFile import json import os import importlib +from abc import ABC, abstractmethod -def load_dataset(dataset_name): + +class BaseDatabase(ABC): + """Base class of unified database. Should override the query function.""" + def __init__(self): + """extract data.zip and load the database.""" + + @abstractmethod + def query(self, domain:str, state:dict, topk:int, **kwargs)->list: + """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" + + +def load_dataset(dataset_name:str) -> Tuple[List, Dict]: + """load unified datasets from `data/unified_datasets/$dataset_name` + + Args: + dataset_name (str): unique dataset name in `data/unified_datasets` + + Returns: + dialogues (list): each element is a dialog in unified format + ontology (dict): dataset ontology + """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) archive = ZipFile(os.path.join(data_dir, 'data.zip')) with archive.open('data/dialogues.json') as f: @@ -12,13 +34,22 @@ def load_dataset(dataset_name): ontology = json.loads(f.read()) return dialogues, ontology -def load_database(dataset_name): +def load_database(dataset_name:str): + """load database from `data/unified_datasets/$dataset_name` + + Args: + dataset_name (str): unique dataset name in `data/unified_datasets` + + Returns: + database: an instance of BaseDatabase + """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) cwd = os.getcwd() os.chdir(data_dir) Database = importlib.import_module('database').Database os.chdir(cwd) database = Database() + assert isinstance(database, BaseDatabase) return database if __name__ == "__main__": diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md index 615d1b52..df077aea 100644 --- a/data/unified_datasets/README.md +++ b/data/unified_datasets/README.md @@ -31,7 +31,9 @@ if __name__ == '__main__': Datasets that require database interaction should also include the following file: - `database.py`: load the database and define the query function: ```python -class Database: +from convlab2.util.unified_datasets_util import BaseDatabase + +class Database(BaseDatabase): def __init__(self): """extract data.zip and load the database.""" diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index 0dbf50c8..dcb3d702 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -5,9 +5,10 @@ from fuzzywuzzy import fuzz from itertools import chain from zipfile import ZipFile from copy import deepcopy +from convlab2.util.unified_datasets_util import BaseDatabase -class Database: +class Database(BaseDatabase): def __init__(self): """extract data.zip and load the database.""" archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip')) @@ -102,6 +103,8 @@ class Database: if __name__ == '__main__': db = Database() + assert issubclass(Database, BaseDatabase) + assert isinstance(db, BaseDatabase) res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) print(res, len(res)) # print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']])) -- GitLab