diff --git a/convlab/dst/rule/multiwoz/usr_dst.py b/convlab/dst/rule/multiwoz/usr_dst.py index 1de65476b6f0a6a0f419d733af867717c4c1729a..16d6a5c5e5891caee265dbfe4753d4363f6bcb4b 100755 --- a/convlab/dst/rule/multiwoz/usr_dst.py +++ b/convlab/dst/rule/multiwoz/usr_dst.py @@ -7,16 +7,11 @@ from convlab.dst.rule.multiwoz import RuleDST from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple -import importlib from pprint import pprint from copy import deepcopy -from convlab.util import load_ontology +from convlab.util import relative_import_module_from_unified_datasets -module_spec = importlib.util.spec_from_file_location('preprocess', \ - os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) -module = importlib.util.module_from_spec(module_spec) -module_spec.loader.exec_module(module) -reverse_da = module.reverse_da +reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') SLOT2SEMI = { "arriveby": "arriveBy", diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index 115f39e64958da727d51ea91431bb8eaa6ef5a5d..b7331479e2ea2ee1a1f5c4bce7aef82546dcbd2e 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -8,13 +8,9 @@ from convlab.evaluator.evaluator import Evaluator from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple from convlab.util.multiwoz.dbquery import Database import os -import importlib +from convlab.util import relative_import_module_from_unified_datasets -module_spec = importlib.util.spec_from_file_location('preprocess', \ - os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/multiwoz21/preprocess.py'))) 
-module = importlib.util.module_from_spec(module_spec) -module_spec.loader.exec_module(module) -reverse_da = module.reverse_da +reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') requestable = \ {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'], diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 432f8eb214c6384cd8cec439abfb8d4dc4ec1e7c..0ffc5b38e318826e3de8370484d19845801db4c1 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -16,14 +16,9 @@ import logging from convlab.policy.policy import Policy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA -import importlib - -module_spec = importlib.util.spec_from_file_location('preprocess', \ - os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) -module = importlib.util.module_from_spec(module_spec) -module_spec.loader.exec_module(module) -reverse_da = module.reverse_da -normalize_domain_slot_value = module.normalize_domain_slot_value +from convlab.util import relative_import_module_from_unified_datasets + +reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) def unified_format(acts): new_acts = {'categorical': []} diff --git a/convlab/policy/tus/multiwoz/TUS.py b/convlab/policy/tus/multiwoz/TUS.py index 0fef5f0f09a65caffdecd24214c6da9cd94c2c98..725098d9162d6900746aa7398c14a8aafe49d786 100644 --- a/convlab/policy/tus/multiwoz/TUS.py +++ b/convlab/policy/tus/multiwoz/TUS.py @@ -16,14 +16,9 @@ from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA from 
convlab.util.custom_util import model_downloader from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple -import importlib - -module_spec = importlib.util.spec_from_file_location('preprocess', \ - os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) -module = importlib.util.module_from_spec(module_spec) -module_spec.loader.exec_module(module) -reverse_da = module.reverse_da -normalize_domain_slot_value = module.normalize_domain_slot_value +from convlab.util import relative_import_module_from_unified_datasets + +reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/convlab/util/__init__.py b/convlab/util/__init__.py index 6d90f3198efcfdadc18ea3a68e03071bb839a453..6a84b7db276389d9bbcd6ba097a0b7bb00440a48 100755 --- a/convlab/util/__init__.py +++ b/convlab/util/__init__.py @@ -1,2 +1,3 @@ from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \ - load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data \ No newline at end of file + load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \ + download_unified_datasets, relative_import_module_from_unified_datasets \ No newline at end of file diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index 4c30014dbc936978d356df40ff39a9f67c0eec0a..9eab25e26d82d7f20d92343a2ac64bb807239074 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -1,14 +1,9 @@ from copy import deepcopy from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -import os -import importlib +from convlab.util import 
relative_import_module_from_unified_datasets -module_spec = importlib.util.spec_from_file_location('preprocess', \ - os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../data/unified_datasets/multiwoz21/preprocess.py'))) -module = importlib.util.module_from_spec(module_spec) -module_spec.loader.exec_module(module) -reverse_da_slot_name_map = module.reverse_da_slot_name_map +reverse_da_slot_name_map = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map') def delexicalize_da(da, requestable): delexicalized_da = [] diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py index 7c807ae72368d32922df5f7ec87ffb366176506c..1e3b0c20bd959ea3098b07b813ed98189aac840f 100644 --- a/convlab/util/unified_datasets_util.py +++ b/convlab/util/unified_datasets_util.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from pprint import pprint from convlab.util.file_util import cached_path import shutil +import importlib class BaseDatabase(ABC): @@ -20,13 +21,13 @@ class BaseDatabase(ABC): def query(self, domain:str, state:dict, topk:int, **kwargs)->list: """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" -def load_from_hf_datasets(dataset_name, filename, data_dir): +def download_unified_datasets(dataset_name, filename, data_dir): """ - It downloads the file from the Hugging Face if it doesn't exist in the data directory + It downloads the file of unified datasets from HuggingFace's datasets if it doesn't exist in the data directory :param dataset_name: The name of the dataset :param filename: the name of the file you want to download - :param data_dir: the directory where the data will be downloaded to + :param data_dir: the directory where the file will be downloaded to :return: The data path """ data_path = os.path.join(data_dir, filename) @@ -38,6 +39,32 @@ def load_from_hf_datasets(dataset_name, filename, 
data_dir): shutil.move(cache_path, data_path) return data_path +def relative_import_module_from_unified_datasets(dataset_name, filename, names2import): + """ + It downloads a file from the unified datasets repository, imports it as a module, and returns the + variable(s) you want from that module + + :param dataset_name: the name of the dataset, e.g. 'multiwoz21' + :param filename: the name of the file to download, e.g. 'preprocess.py' + :param names2import: a string or a list of strings. If it's a string, it's the name of the variable + to import. If it's a list of strings, it's the names of the variables to import + :return: the variable(s) that are being imported from the module. + """ + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) + assert filename.endswith('.py') + assert isinstance(names2import, str) or (isinstance(names2import, list) and len(names2import) > 0) + data_path = download_unified_datasets(dataset_name, filename, data_dir) + module_spec = importlib.util.spec_from_file_location(filename[:-3], data_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + if isinstance(names2import, str): + return getattr(module, names2import) + else: + variables = [] + for name in names2import: + variables.append(getattr(module, name)) + return variables + def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict: """load unified dataset from `data/unified_datasets/$dataset_name` @@ -49,14 +76,14 @@ dataset (dict): keys are data splits and the values are lists of dialogues """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) - data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir) + data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir) archive = ZipFile(data_path) with 
archive.open('data/dialogues.json') as f: dialogues = json.loads(f.read()) dataset = {} if dial_ids_order is not None: - data_path = load_from_hf_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir) + data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir) dial_ids = json.load(open(data_path))[dial_ids_order] for data_split in dial_ids: dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]] @@ -78,7 +105,7 @@ ontology (dict): dataset ontology """ data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) - data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir) + data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir) archive = ZipFile(data_path) with archive.open('data/ontology.json') as f: @@ -95,11 +122,6 @@ database: an instance of BaseDatabase """ - data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}')) - data_path = load_from_hf_datasets(dataset_name, 'database.py', data_dir) - module_spec = importlib.util.spec_from_file_location('database', data_path) - module = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(module) - Database = module.Database + Database = relative_import_module_from_unified_datasets(dataset_name, 'database.py', 'Database') assert issubclass(Database, BaseDatabase) database = Database() assert isinstance(database, BaseDatabase)