Skip to content
Snippets Groups Projects
Commit 7ae31c8b authored by zqwerty's avatar zqwerty
Browse files

add load_dataset interface

parent 8962c9e0
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from convlab2.policy import Policy ...@@ -6,6 +6,7 @@ from convlab2.policy import Policy
from convlab2.nlg import NLG from convlab2.nlg import NLG
from convlab2.dialog_agent import Agent, PipelineAgent from convlab2.dialog_agent import Agent, PipelineAgent
from convlab2.dialog_agent import Session, BiSession, DealornotSession from convlab2.dialog_agent import Session, BiSession, DealornotSession
from convlab2.util.unified_datasets_util import load_dataset, load_database
from os.path import abspath, dirname from os.path import abspath, dirname
......
from zipfile import ZipFile
import json
import os
import importlib
def load_dataset(dataset_name):
    """Load the dialogues and ontology of a unified-format dataset.

    The data is read from ``data.zip`` located in
    ``../../../data/unified_datasets/<dataset_name>``, resolved against the
    path of this file.

    Args:
        dataset_name: Name of the dataset directory, e.g. ``'multiwoz21'``.

    Returns:
        A ``(dialogues, ontology)`` tuple with the parsed contents of
        ``data/dialogues.json`` and ``data/ontology.json`` from the archive.

    Raises:
        FileNotFoundError: If ``data.zip`` does not exist in the dataset
            directory.
        KeyError: If one of the two JSON members is missing from the archive.
    """
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
    # Use the archive as a context manager so its file handle is released
    # even when a member is missing or contains invalid JSON.
    with ZipFile(os.path.join(data_dir, 'data.zip')) as archive:
        with archive.open('data/dialogues.json') as f:
            dialogues = json.load(f)
        with archive.open('data/ontology.json') as f:
            ontology = json.load(f)
    return dialogues, ontology
def load_database(dataset_name):
    """Instantiate the ``Database`` class shipped with a unified dataset.

    The class is defined in ``database.py`` inside
    ``../../../data/unified_datasets/<dataset_name>``, resolved against the
    path of this file.

    Args:
        dataset_name: Name of the dataset directory, e.g. ``'multiwoz21'``.

    Returns:
        A new instance of the dataset's ``Database`` class, constructed
        without arguments.

    Raises:
        ModuleNotFoundError: If the dataset directory has no ``database.py``.
        AttributeError: If ``database.py`` does not define ``Database``.
    """
    # Function-scope imports: the file only imports the top-level
    # ``importlib`` package, which does not guarantee ``importlib.util``.
    import importlib.util
    import sys

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
    module_path = os.path.join(data_dir, 'database.py')
    if not os.path.isfile(module_path):
        raise ModuleNotFoundError(f'No database module found at {module_path}', name='database')

    # Register each dataset's module under its own name. Importing the bare
    # name 'database' would be cached in sys.modules, so a second dataset
    # would silently get the first dataset's Database class.
    module_name = f'_unified_datasets_database_{dataset_name}'
    module = sys.modules.get(module_name)
    if module is None:
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        cwd = os.getcwd()
        # Keep executing the module from inside the dataset directory, as the
        # original chdir-based import did, but always restore the caller's
        # working directory, even when the module raises.
        os.chdir(data_dir)
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind in the cache.
            sys.modules.pop(module_name, None)
            raise
        finally:
            os.chdir(cwd)
    return module.Database()
if __name__ == "__main__":
    # Manual smoke test: load the multiwoz21 dataset and its database, then
    # run one sample query against the "train" domain and print the first
    # result together with the number of results returned.
    dialogues, ontology = load_dataset('multiwoz21')
    database = load_database('multiwoz21')
    # Slot/value pairs handed to the query; their exact matching semantics
    # are defined by the dataset's Database class, not in this file.
    train_constraints = [
        ['departure', 'cambridge'],
        ['destination', 'peterborough'],
        ['day', 'tuesday'],
        ['arrive by', '11:15'],
    ]
    res = database.query("train", train_constraints, topk=3)
    print(res[0], len(res))
# Unified data format # Unified data format
## Overview ## Overview
We transform different datasets into a unified format under the `data/unified_datasets` directory. We transform different datasets into a unified format under the `data/unified_datasets` directory. To import a unified dataset:
```python
from convlab2 import load_dataset, load_database
dialogues, ontology = load_dataset('multiwoz21')
database = load_database('multiwoz21')
```
Each dataset contains at least these files: Each dataset contains at least these files:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment