Commit d47f79e5 authored by zqwerty

support load dataset from hf dataset

parent f4695da9
@@ -7,6 +7,8 @@ import re
 import importlib
 from abc import ABC, abstractmethod
 from pprint import pprint
+from convlab.util.file_util import cached_path
+import shutil
 
 
 class BaseDatabase(ABC):
@@ -18,6 +20,23 @@ class BaseDatabase(ABC):
     def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
 
+
+def load_from_hf_datasets(dataset_name, filename, data_dir):
+    """Download the file from the Hugging Face Hub if it doesn't exist in the data directory.
+
+    :param dataset_name: the name of the dataset
+    :param filename: the name of the file to download
+    :param data_dir: the directory the file will be downloaded to
+    :return: the local data path
+    """
+    data_path = os.path.join(data_dir, filename)
+    if not os.path.exists(data_path):
+        if not os.path.exists(data_dir):
+            os.makedirs(data_dir, exist_ok=True)
+        data_url = f'https://huggingface.co/datasets/ConvLab/{dataset_name}/resolve/main/{filename}'
+        cache_path = cached_path(data_url)
+        shutil.move(cache_path, data_path)
+    return data_path
+
 
 def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
@@ -30,12 +49,15 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         dataset (dict): keys are data splits and the values are lists of dialogues
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+    archive = ZipFile(data_path)
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
     if dial_ids_order is not None:
-        dial_ids = json.load(open(os.path.join(data_dir, 'shuffled_dial_ids.json')))[dial_ids_order]
+        data_path = load_from_hf_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
+        dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
             dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
     else:
@@ -56,7 +78,9 @@ def load_ontology(dataset_name:str) -> Dict:
         ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+    archive = ZipFile(data_path)
     with archive.open('data/ontology.json') as f:
         ontology = json.loads(f.read())
     return ontology
@@ -70,8 +94,9 @@ def load_database(dataset_name:str):
     Returns:
         database: an instance of BaseDatabase
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
-    module_spec = importlib.util.spec_from_file_location('database', data_dir)
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_path = load_from_hf_datasets(dataset_name, 'database.py', data_dir)
+    module_spec = importlib.util.spec_from_file_location('database', data_path)
     module = importlib.util.module_from_spec(module_spec)
     module_spec.loader.exec_module(module)
     Database = module.Database
...
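For context, a minimal usage sketch of the loaders changed above. The import path and the dataset name multiwoz21 are assumptions for illustration; the function names and the download-and-cache behavior come from the diff itself (files are fetched from https://huggingface.co/datasets/ConvLab/{dataset_name}/resolve/main/ on first use and cached under data/unified_datasets/{dataset_name}/):

    # Hypothetical usage; the module path is assumed, not shown in this commit.
    from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database

    dataset = load_dataset('multiwoz21')                    # first call downloads and caches data.zip
    subset = load_dataset('multiwoz21', dial_ids_order=0)   # also fetches shuffled_dial_ids.json
    ontology = load_ontology('multiwoz21')                  # reuses the cached data.zip
    db = load_database('multiwoz21')                        # downloads database.py, returns a Database() instance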
@@ -8,7 +8,7 @@ from tqdm import tqdm
 from collections import Counter
 from pprint import pprint
 from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
-from data.unified_datasets.multiwoz21.booking_remapper import BookingActRemapper
+from .booking_remapper import BookingActRemapper
 
 ontology = {
     "domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different
...
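As background, the load_database change relies on the standard importlib pattern for executing a module from an arbitrary file path, which is what lets the freshly downloaded database.py be imported at runtime. A minimal standalone sketch (the file path here is hypothetical):

    import importlib.util

    # Load the downloaded database.py as a module named 'database'.
    spec = importlib.util.spec_from_file_location('database', '/path/to/database.py')  # hypothetical path
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)   # executes the file, populating the module namespace
    Database = module.Database        # the dataset's BaseDatabase subclass
    db = Database()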