diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 42963ad770cdb87dbe3fa1605e3b305b361f0e10..7c807ae72368d32922df5f7ec87ffb366176506c 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -7,6 +7,8 @@ import re
 import importlib
 from abc import ABC, abstractmethod
 from pprint import pprint
+from convlab.util.file_util import cached_path
+import shutil
 
 
 class BaseDatabase(ABC):
@@ -18,6 +20,23 @@ class BaseDatabase(ABC):
     def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
 
+def load_from_hf_datasets(dataset_name, filename, data_dir):
+    """
+    Download the file from the Hugging Face Hub if it doesn't already exist in the data directory.
+
+    :param dataset_name: the name of the dataset
+    :param filename: the name of the file to download
+    :param data_dir: the directory the file will be downloaded to
+    :return: the local path of the downloaded file
+    """
+    data_path = os.path.join(data_dir, filename)
+    if not os.path.exists(data_path):
+        if not os.path.exists(data_dir):
+            os.makedirs(data_dir, exist_ok=True)
+        data_url = f'https://huggingface.co/datasets/ConvLab/{dataset_name}/resolve/main/{filename}'
+        cache_path = cached_path(data_url)
+        shutil.move(cache_path, data_path)
+    return data_path
 
 def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
@@ -30,12 +49,15 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         dataset (dict): keys are data splits and the values are lists of dialogues
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+
+    archive = ZipFile(data_path)
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
     if dial_ids_order is not None:
-        dial_ids = json.load(open(os.path.join(data_dir, 'shuffled_dial_ids.json')))[dial_ids_order]
+        data_path = load_from_hf_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
+        dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
             dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
     else:
@@ -56,7 +78,9 @@ def load_ontology(dataset_name:str) -> Dict:
         ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+
+    archive = ZipFile(data_path)
     with archive.open('data/ontology.json') as f:
         ontology = json.loads(f.read())
     return ontology
@@ -70,8 +94,9 @@ def load_database(dataset_name:str):
     Returns:
         database: an instance of BaseDatabase
     """
-    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}/database.py'))
-    module_spec = importlib.util.spec_from_file_location('database', data_dir)
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    data_path = load_from_hf_datasets(dataset_name, 'database.py', data_dir)
+    module_spec = importlib.util.spec_from_file_location('database', data_path)
     module = importlib.util.module_from_spec(module_spec)
     module_spec.loader.exec_module(module)
     Database = module.Database
diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py
index d6fcfa9816362c9ca764e76a5d58bc8f9d3492cf..138572cdbedd6b97f93fb0ca57007c7bd84a62b0 100644
--- a/data/unified_datasets/multiwoz21/preprocess.py
+++ b/data/unified_datasets/multiwoz21/preprocess.py
@@ -8,7 +8,7 @@ from tqdm import tqdm
 from collections import Counter
 from pprint import pprint
 from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
-from data.unified_datasets.multiwoz21.booking_remapper import BookingActRemapper
+from .booking_remapper import BookingActRemapper
 
 ontology = {
     "domains": {  # descriptions are adapted from multiwoz22, but is_categorical may be different
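
A minimal usage sketch of the loaders after this change, assuming the dataset is hosted under the ConvLab organization on the Hugging Face Hub (as the URL template in load_from_hf_datasets implies); the printed split names are illustrative:

    from convlab.util.unified_datasets_util import (
        load_dataset, load_ontology, load_database)

    # The first call downloads data.zip from
    # https://huggingface.co/datasets/ConvLab/multiwoz21/resolve/main/data.zip
    # into data/unified_datasets/multiwoz21/; later calls reuse the local copy.
    dataset = load_dataset('multiwoz21')
    print(list(dataset.keys()))  # e.g. ['train', 'validation', 'test']

    # ontology.json is read from the same cached data.zip archive.
    ontology = load_ontology('multiwoz21')

    # database.py is fetched the same way, then imported via importlib.
    database = load_database('multiwoz21')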