Skip to content
Snippets Groups Projects
Commit eeb903df authored by zqwerty's avatar zqwerty
Browse files

add BaseDatabase class for unified datasets

parent 4d4d0ac8
No related branches found
No related tags found
No related merge requests found
......@@ -16,23 +16,29 @@ class BaseDatabase(ABC):
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
def load_dataset(dataset_name:str) -> Tuple[List, Dict]:
def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
"""load unified datasets from `data/unified_datasets/$dataset_name`
Args:
dataset_name (str): unique dataset name in `data/unified_datasets`
Returns:
dialogues (list): each element is a dialog in unified format
dataset (dict): keys are data splits and the values are lists of dialogues
ontology (dict): dataset ontology
"""
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
archive = ZipFile(os.path.join(data_dir, 'data.zip'))
with archive.open('data/dialogues.json') as f:
dialogues = json.loads(f.read())
with archive.open('data/ontology.json') as f:
ontology = json.loads(f.read())
return dialogues, ontology
with archive.open('data/dialogues.json') as f:
dialogues = json.loads(f.read())
dataset = {}
for dialogue in dialogues:
if dialogue['data_split'] not in dataset:
dataset[dialogue['data_split']] = [dialogue]
else:
dataset[dialogue['data_split']].append(dialogue)
return dataset, ontology
def load_database(dataset_name:str):
"""load database from `data/unified_datasets/$dataset_name`
......@@ -43,17 +49,21 @@ def load_database(dataset_name:str):
Returns:
database: an instance of BaseDatabase
"""
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
cwd = os.getcwd()
os.chdir(data_dir)
Database = importlib.import_module('database').Database
os.chdir(cwd)
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
module_spec = importlib.util.spec_from_file_location('database', data_dir)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
Database = module.Database
assert issubclass(Database, BaseDatabase)
database = Database()
assert isinstance(database, BaseDatabase)
return database
if __name__ == "__main__":
dialogues, ontology = load_dataset('multiwoz21')
# dataset, ontology = load_dataset('multiwoz21')
# print(dataset.keys())
# print(len(dataset['train']))
from convlab2.util.unified_datasets_util import BaseDatabase
database = load_database('multiwoz21')
res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
print(res[0], len(res))
......@@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase
```python
from convlab2 import load_dataset, load_database
dialogues, ontology = load_dataset('multiwoz21')
dataset, ontology = load_dataset('multiwoz21')
database = load_database('multiwoz21')
```
`dataset` is a dict where the keys are data splits and the values are lists of dialogues. `database` is an instance of `Database` class that has a `query` function. The format of dialogue, ontology, and Database are defined below.
Each dataset contains at least these files:
- `README.md`: dataset description and the **main changes** from original data to processed data. Should include the instruction on how to get the original data and transform them into the unified format.
......
......@@ -315,10 +315,8 @@ if __name__ == '__main__':
if args.preprocess:
print('pre-processing')
os.chdir(name)
preprocess = importlib.import_module(f'{name}.preprocess')
preprocess.preprocess()
os.chdir('..')
data_file = f'{name}/data.zip'
if not os.path.exists(data_file):
......
......@@ -35,7 +35,7 @@ class Database(BaseDatabase):
'leaveAt': 'leave at'
}
def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list:
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
# query the db
if domain == 'taxi':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment