Skip to content
Snippets Groups Projects
Commit eeb903df authored by zqwerty's avatar zqwerty
Browse files

add BaseDatabase class for unified datasets

parent 4d4d0ac8
Branches
No related tags found
No related merge requests found
...@@ -16,23 +16,29 @@ class BaseDatabase(ABC): ...@@ -16,23 +16,29 @@ class BaseDatabase(ABC):
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
def load_dataset(dataset_name:str) -> Tuple[List, Dict]: def load_dataset(dataset_name:str) -> Tuple[Dict, Dict]:
"""load unified datasets from `data/unified_datasets/$dataset_name` """load unified datasets from `data/unified_datasets/$dataset_name`
Args: Args:
dataset_name (str): unique dataset name in `data/unified_datasets` dataset_name (str): unique dataset name in `data/unified_datasets`
Returns: Returns:
dialogues (list): each element is a dialog in unified format dataset (dict): keys are data splits and the values are lists of dialogues
ontology (dict): dataset ontology ontology (dict): dataset ontology
""" """
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
archive = ZipFile(os.path.join(data_dir, 'data.zip')) archive = ZipFile(os.path.join(data_dir, 'data.zip'))
with archive.open('data/dialogues.json') as f:
dialogues = json.loads(f.read())
with archive.open('data/ontology.json') as f: with archive.open('data/ontology.json') as f:
ontology = json.loads(f.read()) ontology = json.loads(f.read())
return dialogues, ontology with archive.open('data/dialogues.json') as f:
dialogues = json.loads(f.read())
dataset = {}
for dialogue in dialogues:
if dialogue['data_split'] not in dataset:
dataset[dialogue['data_split']] = [dialogue]
else:
dataset[dialogue['data_split']].append(dialogue)
return dataset, ontology
def load_database(dataset_name:str): def load_database(dataset_name:str):
"""load database from `data/unified_datasets/$dataset_name` """load database from `data/unified_datasets/$dataset_name`
...@@ -43,17 +49,21 @@ def load_database(dataset_name:str): ...@@ -43,17 +49,21 @@ def load_database(dataset_name:str):
Returns: Returns:
database: an instance of BaseDatabase database: an instance of BaseDatabase
""" """
data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
cwd = os.getcwd() module_spec = importlib.util.spec_from_file_location('database', data_dir)
os.chdir(data_dir) module = importlib.util.module_from_spec(module_spec)
Database = importlib.import_module('database').Database module_spec.loader.exec_module(module)
os.chdir(cwd) Database = module.Database
assert issubclass(Database, BaseDatabase)
database = Database() database = Database()
assert isinstance(database, BaseDatabase) assert isinstance(database, BaseDatabase)
return database return database
if __name__ == "__main__": if __name__ == "__main__":
dialogues, ontology = load_dataset('multiwoz21') # dataset, ontology = load_dataset('multiwoz21')
# print(dataset.keys())
# print(len(dataset['train']))
from convlab2.util.unified_datasets_util import BaseDatabase
database = load_database('multiwoz21') database = load_database('multiwoz21')
res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) res = database.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3)
print(res[0], len(res)) print(res[0], len(res))
...@@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase ...@@ -6,10 +6,12 @@ We transform different datasets into a unified format under `data/unified_datase
```python ```python
from convlab2 import load_dataset, load_database from convlab2 import load_dataset, load_database
dialogues, ontology = load_dataset('multiwoz21') dataset, ontology = load_dataset('multiwoz21')
database = load_database('multiwoz21') database = load_database('multiwoz21')
``` ```
`dataset` is a dict whose keys are data splits and whose values are lists of dialogues. `database` is an instance of the `Database` class, which provides a `query` function. The formats of the dialogues, ontology, and `Database` are defined below.
Each dataset contains at least these files: Each dataset contains at least these files:
- `README.md`: dataset description and the **main changes** from original data to processed data. It should include instructions on how to obtain the original data and transform it into the unified format. - `README.md`: dataset description and the **main changes** from original data to processed data. It should include instructions on how to obtain the original data and transform it into the unified format.
......
...@@ -315,10 +315,8 @@ if __name__ == '__main__': ...@@ -315,10 +315,8 @@ if __name__ == '__main__':
if args.preprocess: if args.preprocess:
print('pre-processing') print('pre-processing')
os.chdir(name)
preprocess = importlib.import_module(f'{name}.preprocess') preprocess = importlib.import_module(f'{name}.preprocess')
preprocess.preprocess() preprocess.preprocess()
os.chdir('..')
data_file = f'{name}/data.zip' data_file = f'{name}/data.zip'
if not os.path.exists(data_file): if not os.path.exists(data_file):
......
...@@ -35,7 +35,7 @@ class Database(BaseDatabase): ...@@ -35,7 +35,7 @@ class Database(BaseDatabase):
'leaveAt': 'leave at' 'leaveAt': 'leave at'
} }
def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60): def query(self, domain: str, state: dict, topk: int, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60) -> list:
"""return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
# query the db # query the db
if domain == 'taxi': if domain == 'taxi':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment