From 4d4d0ac8f0ba734c30b4191f2af8f2e2d4cc2068 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 20 Dec 2021 07:47:06 +0000
Subject: [PATCH] add BaseDatabase class for unified dataset

---
 convlab2/util/unified_datasets_util.py       | 35 ++++++++++++++++++--
 data/unified_datasets/README.md              |  4 ++-
 data/unified_datasets/multiwoz21/database.py |  5 ++-
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index 921c6c45..341d21e0 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -1,9 +1,31 @@
+from typing import Dict, List, Tuple
 from zipfile import ZipFile
 import json
 import os
 import importlib
+from abc import ABC, abstractmethod
 
-def load_dataset(dataset_name):
+
+class BaseDatabase(ABC):
+    """Base class of unified database. Should override the query function."""
+    def __init__(self):
+        """extract data.zip and load the database."""
+
+    @abstractmethod
+    def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
+        """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
+
+
+def load_dataset(dataset_name:str) -> Tuple[List, Dict]:
+    """load unified datasets from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        dialogues (list): each element is a dialog in unified format
+        ontology (dict): dataset ontology
+    """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
     with archive.open('data/dialogues.json') as f:
@@ -12,13 +34,22 @@ def load_dataset(dataset_name):
         ontology = json.loads(f.read())
     return dialogues, ontology
 
-def load_database(dataset_name):
+def load_database(dataset_name:str):
+    """load database from `data/unified_datasets/$dataset_name`
+
+    Args:
+        dataset_name (str): unique dataset name in `data/unified_datasets`
+
+    Returns:
+        database: an instance of BaseDatabase
+    """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     cwd = os.getcwd()
     os.chdir(data_dir)
     Database = importlib.import_module('database').Database
     os.chdir(cwd)
     database = Database()
+    assert isinstance(database, BaseDatabase)
     return database
 
 if __name__ == "__main__":
diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md
index 615d1b52..df077aea 100644
--- a/data/unified_datasets/README.md
+++ b/data/unified_datasets/README.md
@@ -31,7 +31,9 @@ if __name__ == '__main__':
 Datasets that require database interaction should also include the following file:
 - `database.py`: load the database and define the query function:
 ```python
-class Database:
+from convlab2.util.unified_datasets_util import BaseDatabase
+
+class Database(BaseDatabase):
     def __init__(self):
         """extract data.zip and load the database."""
 
diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py
index 0dbf50c8..dcb3d702 100644
--- a/data/unified_datasets/multiwoz21/database.py
+++ b/data/unified_datasets/multiwoz21/database.py
@@ -5,9 +5,10 @@ from fuzzywuzzy import fuzz
 from itertools import chain
 from zipfile import ZipFile
 from copy import deepcopy
+from convlab2.util.unified_datasets_util import BaseDatabase
 
 
-class Database:
+class Database(BaseDatabase):
     def __init__(self):
         """extract data.zip and load the database."""
         archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip'))
@@ -102,6 +103,8 @@ class Database:
 
 if __name__ == '__main__':
     db = Database()
+    assert issubclass(Database, BaseDatabase)
+    assert isinstance(db, BaseDatabase)
     res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res, len(res))
     # print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']]))
-- 
GitLab