Skip to content
Snippets Groups Projects
Commit 4d4d0ac8 authored by zqwerty's avatar zqwerty
Browse files

add BaseDatabase class for unified dataset

parent 087fbfa3
No related branches found
No related tags found
Loading
from typing import Dict, List, Tuple
from zipfile import ZipFile from zipfile import ZipFile
import json import json
import os import os
import importlib import importlib
from abc import ABC, abstractmethod
def load_dataset(dataset_name):
class BaseDatabase(ABC):
"""Base class of unified database. Should override the query function."""
def __init__(self):
"""extract data.zip and load the database."""
@abstractmethod
def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
def load_dataset(dataset_name:str) -> Tuple[List, Dict]:
"""load unified datasets from `data/unified_datasets/$dataset_name`
Args:
dataset_name (str): unique dataset name in `data/unified_datasets`
Returns:
dialogues (list): each element is a dialog in unified format
ontology (dict): dataset ontology
"""
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
archive = ZipFile(os.path.join(data_dir, 'data.zip')) archive = ZipFile(os.path.join(data_dir, 'data.zip'))
with archive.open('data/dialogues.json') as f: with archive.open('data/dialogues.json') as f:
...@@ -12,13 +34,22 @@ def load_dataset(dataset_name): ...@@ -12,13 +34,22 @@ def load_dataset(dataset_name):
ontology = json.loads(f.read()) ontology = json.loads(f.read())
return dialogues, ontology return dialogues, ontology
def load_database(dataset_name): def load_database(dataset_name:str):
"""load database from `data/unified_datasets/$dataset_name`
Args:
dataset_name (str): unique dataset name in `data/unified_datasets`
Returns:
database: an instance of BaseDatabase
"""
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
cwd = os.getcwd() cwd = os.getcwd()
os.chdir(data_dir) os.chdir(data_dir)
Database = importlib.import_module('database').Database Database = importlib.import_module('database').Database
os.chdir(cwd) os.chdir(cwd)
database = Database() database = Database()
assert isinstance(database, BaseDatabase)
return database return database
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -31,7 +31,9 @@ if __name__ == '__main__': ...@@ -31,7 +31,9 @@ if __name__ == '__main__':
Datasets that require database interaction should also include the following file: Datasets that require database interaction should also include the following file:
- `database.py`: load the database and define the query function: - `database.py`: load the database and define the query function:
```python ```python
class Database: from convlab2.util.unified_datasets_util import BaseDatabase
class Database(BaseDatabase):
def __init__(self): def __init__(self):
"""extract data.zip and load the database.""" """extract data.zip and load the database."""
......
...@@ -5,9 +5,10 @@ from fuzzywuzzy import fuzz ...@@ -5,9 +5,10 @@ from fuzzywuzzy import fuzz
from itertools import chain from itertools import chain
from zipfile import ZipFile from zipfile import ZipFile
from copy import deepcopy from copy import deepcopy
from convlab2.util.unified_datasets_util import BaseDatabase
class Database: class Database(BaseDatabase):
def __init__(self): def __init__(self):
"""extract data.zip and load the database.""" """extract data.zip and load the database."""
archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip')) archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip'))
...@@ -102,6 +103,8 @@ class Database: ...@@ -102,6 +103,8 @@ class Database:
if __name__ == '__main__': if __name__ == '__main__':
db = Database() db = Database()
assert issubclass(Database, BaseDatabase)
assert isinstance(db, BaseDatabase)
res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
print(res, len(res)) print(res, len(res))
# print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']])) # print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment