From eeb903df689150b8032c077022c8b0f838f60b80 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 20 Dec 2021 09:50:55 +0000
Subject: [PATCH] add BaseDatabase class for unified datasets

---
 convlab2/util/unified_datasets_util.py       | 32 +++++++++++++-------
 data/unified_datasets/README.md              |  4 ++-
 data/unified_datasets/check.py               |  2 --
 data/unified_datasets/multiwoz21/database.py |  2 +-
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index 341d21e0..0683e780 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -16,23 +16,29 @@ class BaseDatabase(ABC):
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
 
 
-def load_dataset(dataset_name:str) -> Tuple[List, Dict]:
+def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
     """load unified datasets from `data/unified_datasets/$dataset_name`
 
     Args:
         dataset_name (str): unique dataset name in `data/unified_datasets`
 
     Returns:
-        dialogues (list): each element is a dialog in unified format
+        dataset (dict): keys are data splits and the values are lists of dialogues
         ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
     archive = ZipFile(os.path.join(data_dir, 'data.zip'))
-    with archive.open('data/dialogues.json') as f:
-        dialogues = json.loads(f.read())
     with archive.open('data/ontology.json') as f:
         ontology = json.loads(f.read())
-    return dialogues, ontology
+    with archive.open('data/dialogues.json') as f:
+        dialogues = json.loads(f.read())
+    dataset = {}
+    for dialogue in dialogues:
+        if dialogue['data_split'] not in dataset:
+            dataset[dialogue['data_split']] = [dialogue]
+        else:
+            dataset[dialogue['data_split']].append(dialogue)
+    return dataset, ontology
 
 def load_database(dataset_name:str):
     """load database from `data/unified_datasets/$dataset_name`
@@ -43,17 +49,21 @@ def load_database(dataset_name:str):
     Returns:
         database: an instance of BaseDatabase
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    cwd = os.getcwd()
-    os.chdir(data_dir)
-    Database = importlib.import_module('database').Database
-    os.chdir(cwd)
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
+    module_spec = importlib.util.spec_from_file_location('database', data_dir)
+    module = importlib.util.module_from_spec(module_spec)
+    module_spec.loader.exec_module(module)
+    Database = module.Database
+    assert issubclass(Database, BaseDatabase)
     database = Database()
     assert isinstance(database, BaseDatabase)
     return database
 
 if __name__ == "__main__":
-    dialogues, ontology = load_dataset('multiwoz21')
+    # dataset, ontology = load_dataset('multiwoz21')
+    # print(dataset.keys())
+    # print(len(dataset['train']))    
+    from convlab2.util.unified_datasets_util import BaseDatabase
     database = load_database('multiwoz21')
     res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
     print(res[0], len(res))
diff --git a/data/unified_datasets/README.md b/data/unified_datasets/README.md
index df077aea..7400504a 100644
--- a/data/unified_datasets/README.md
+++ b/data/unified_datasets/README.md
@@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase
 ```python
 from convlab2 import load_dataset, load_database
 
-dialogues, ontology = load_dataset('multiwoz21')
+dataset, ontology = load_dataset('multiwoz21')
 database = load_database('multiwoz21')
 ```
 
+`dataset` is a dict where the keys are data splits and the values are lists of dialogues. `database` is an instance of the `Database` class, which provides a `query` function. The formats of dialogues, the ontology, and the `Database` class are defined below.
+
 Each dataset contains at least these files:
 
 - `README.md`: dataset description and the **main changes** from original data to processed data. Should include the instruction on how to get the original data and transform them into the unified format.
diff --git a/data/unified_datasets/check.py b/data/unified_datasets/check.py
index 47e75e60..4809c857 100644
--- a/data/unified_datasets/check.py
+++ b/data/unified_datasets/check.py
@@ -315,10 +315,8 @@ if __name__ == '__main__':
             if args.preprocess:
                 print('pre-processing')
 
-                os.chdir(name)
                 preprocess = importlib.import_module(f'{name}.preprocess')
                 preprocess.preprocess()
-                os.chdir('..')
 
             data_file = f'{name}/data.zip'
             if not os.path.exists(data_file):
diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py
index dcb3d702..43ea5896 100644
--- a/data/unified_datasets/multiwoz21/database.py
+++ b/data/unified_datasets/multiwoz21/database.py
@@ -35,7 +35,7 @@ class Database(BaseDatabase):
             'leaveAt': 'leave at'
         }
 
-    def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
+    def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
         # query the db
         if domain == 'taxi':
-- 
GitLab