From 6f53257a0350ba55e800778b961bb2981ffbf05f Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Mon, 27 Jun 2022 11:15:57 +0800
Subject: [PATCH] add func relative_import_module_from_unified_datasets:
 import module from unified datasets, support loading module from hf datasets

---
 convlab/dst/rule/multiwoz/usr_dst.py          |  9 +---
 convlab/evaluator/multiwoz_eval.py            |  8 +---
 .../rule/multiwoz/policy_agenda_multiwoz.py   | 11 ++---
 convlab/policy/tus/multiwoz/TUS.py            | 11 ++---
 convlab/util/__init__.py                      |  3 +-
 convlab/util/multiwoz/lexicalize.py           |  9 +---
 convlab/util/unified_datasets_util.py         | 43 +++++++++++++++----
 7 files changed, 49 insertions(+), 45 deletions(-)

diff --git a/convlab/dst/rule/multiwoz/usr_dst.py b/convlab/dst/rule/multiwoz/usr_dst.py
index 1de65476..16d6a5c5 100755
--- a/convlab/dst/rule/multiwoz/usr_dst.py
+++ b/convlab/dst/rule/multiwoz/usr_dst.py
@@ -7,16 +7,11 @@ from convlab.dst.rule.multiwoz import RuleDST
 from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
 from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
-import importlib
 from pprint import pprint
 from copy import deepcopy
-from convlab.util import load_ontology
+from convlab.util import relative_import_module_from_unified_datasets
 
-module_spec = importlib.util.spec_from_file_location('preprocess', \
-    os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py')))
-module = importlib.util.module_from_spec(module_spec)
-module_spec.loader.exec_module(module)
-reverse_da = module.reverse_da
+reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da')
 
 SLOT2SEMI = {
     "arriveby": "arriveBy",
diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py
index 115f39e6..b7331479 100755
--- a/convlab/evaluator/multiwoz_eval.py
+++ b/convlab/evaluator/multiwoz_eval.py
@@ -8,13 +8,9 @@ from convlab.evaluator.evaluator import Evaluator
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
 from convlab.util.multiwoz.dbquery import Database
 import os
-import importlib
+from convlab.util import relative_import_module_from_unified_datasets
 
-module_spec = importlib.util.spec_from_file_location('preprocess', \
-    os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/multiwoz21/preprocess.py')))
-module = importlib.util.module_from_spec(module_spec)
-module_spec.loader.exec_module(module)
-reverse_da = module.reverse_da
+reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da')
 
 requestable = \
 {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'],
diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 432f8eb2..0ffc5b38 100755
--- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -16,14 +16,9 @@ import logging
 from convlab.policy.policy import Policy
 from convlab.task.multiwoz.goal_generator import GoalGenerator
 from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA
-import importlib
-
-module_spec = importlib.util.spec_from_file_location('preprocess', \
-    os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py')))
-module = importlib.util.module_from_spec(module_spec)
-module_spec.loader.exec_module(module)
-reverse_da = module.reverse_da
-normalize_domain_slot_value = module.normalize_domain_slot_value
+from convlab.util import relative_import_module_from_unified_datasets
+
+reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
 
 def unified_format(acts):
     new_acts = {'categorical': []}
diff --git a/convlab/policy/tus/multiwoz/TUS.py b/convlab/policy/tus/multiwoz/TUS.py
index 0fef5f0f..725098d9 100644
--- a/convlab/policy/tus/multiwoz/TUS.py
+++ b/convlab/policy/tus/multiwoz/TUS.py
@@ -16,14 +16,9 @@ from convlab.task.multiwoz.goal_generator import GoalGenerator
 from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
 from convlab.util.custom_util import model_downloader
 from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple
-import importlib
-
-module_spec = importlib.util.spec_from_file_location('preprocess', \
-    os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py')))
-module = importlib.util.module_from_spec(module_spec)
-module_spec.loader.exec_module(module)
-reverse_da = module.reverse_da
-normalize_domain_slot_value = module.normalize_domain_slot_value
+from convlab.util import relative_import_module_from_unified_datasets
+
+reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
diff --git a/convlab/util/__init__.py b/convlab/util/__init__.py
index 6d90f319..6a84b7db 100755
--- a/convlab/util/__init__.py
+++ b/convlab/util/__init__.py
@@ -1,2 +1,3 @@
 from convlab.util.unified_datasets_util import load_dataset, load_ontology, load_database, \
-    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
\ No newline at end of file
+    load_unified_data, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, \
+    download_unified_datasets, relative_import_module_from_unified_datasets
\ No newline at end of file
diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py
index 4c30014d..9eab25e2 100755
--- a/convlab/util/multiwoz/lexicalize.py
+++ b/convlab/util/multiwoz/lexicalize.py
@@ -1,14 +1,9 @@
 from copy import deepcopy
 from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
-import os
-import importlib
+from convlab.util import relative_import_module_from_unified_datasets
 
-module_spec = importlib.util.spec_from_file_location('preprocess', \
-    os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../data/unified_datasets/multiwoz21/preprocess.py')))
-module = importlib.util.module_from_spec(module_spec)
-module_spec.loader.exec_module(module)
-reverse_da_slot_name_map = module.reverse_da_slot_name_map
+reverse_da_slot_name_map = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map')
 
 
 def delexicalize_da(da, requestable):
     delexicalized_da = []
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 7c807ae7..1e3b0c20 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
 from pprint import pprint
 from convlab.util.file_util import cached_path
 import shutil
+import importlib
 
 
 class BaseDatabase(ABC):
@@ -20,13 +21,13 @@ class BaseDatabase(ABC):
     def query(self, domain:str, state:dict, topk:int, **kwargs)->list:
         """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state."""
 
-def load_from_hf_datasets(dataset_name, filename, data_dir):
+def download_unified_datasets(dataset_name, filename, data_dir):
     """
-    It downloads the file from the Hugging Face if it doesn't exist in the data directory
+    It downloads the file of unified datasets from HuggingFace's datasets if it doesn't exist in the data directory
 
     :param dataset_name: The name of the dataset
     :param filename: the name of the file you want to download
-    :param data_dir: the directory where the data will be downloaded to
+    :param data_dir: the directory where the file will be downloaded to
     :return: The data path
     """
     data_path = os.path.join(data_dir, filename)
@@ -38,6 +39,32 @@ def load_from_hf_datasets(dataset_name, filename, data_dir):
         shutil.move(cache_path, data_path)
     return data_path
 
 
+def relative_import_module_from_unified_datasets(dataset_name, filename, names2import):
+    """
+    It downloads a file from the unified datasets repository, imports it as a module, and returns the
+    variable(s) you want from that module
+
+    :param dataset_name: the name of the dataset, e.g. 'multiwoz21'
+    :param filename: the name of the file to download, e.g. 'preprocess.py'
+    :param names2import: a string or a list of strings. If it's a string, it's the name of the variable
+    to import. If it's a list of strings, it's the names of the variables to import
+    :return: the variable(s) that are being imported from the module.
+    """
+    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
+    assert filename.endswith('.py')
+    assert isinstance(names2import, str) or (isinstance(names2import, list) and len(names2import) > 0)
+    data_path = download_unified_datasets(dataset_name, filename, data_dir)
+    module_spec = importlib.util.spec_from_file_location(filename[:-3], data_path)
+    module = importlib.util.module_from_spec(module_spec)
+    module_spec.loader.exec_module(module)
+    if isinstance(names2import, str):
+        return eval(f'module.{names2import}')
+    else:
+        variables = []
+        for name in names2import:
+            variables.append(eval(f'module.{name}'))
+        return variables
+
 def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
     """load unified dataset from `data/unified_datasets/$dataset_name`
 
@@ -49,14 +76,14 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
         dataset (dict): keys are data splits and the values are lists of dialogues
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+    data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir)
 
     archive = ZipFile(data_path)
     with archive.open('data/dialogues.json') as f:
         dialogues = json.loads(f.read())
     dataset = {}
     if dial_ids_order is not None:
-        data_path = load_from_hf_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
+        data_path = download_unified_datasets(dataset_name, 'shuffled_dial_ids.json', data_dir)
         dial_ids = json.load(open(data_path))[dial_ids_order]
         for data_split in dial_ids:
             dataset[data_split] = [dialogues[i] for i in dial_ids[data_split]]
@@ -78,7 +105,7 @@ def load_ontology(dataset_name:str) -> Dict:
         ontology (dict): dataset ontology
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    data_path = load_from_hf_datasets(dataset_name, 'data.zip', data_dir)
+    data_path = download_unified_datasets(dataset_name, 'data.zip', data_dir)
 
     archive = ZipFile(data_path)
     with archive.open('data/ontology.json') as f:
@@ -95,11 +122,11 @@ def load_database(dataset_name:str):
         database: an instance of BaseDatabase
     """
     data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/{dataset_name}'))
-    data_path = load_from_hf_datasets(dataset_name, 'database.py', data_dir)
+    data_path = download_unified_datasets(dataset_name, 'database.py', data_dir)
     module_spec = importlib.util.spec_from_file_location('database', data_path)
     module = importlib.util.module_from_spec(module_spec)
     module_spec.loader.exec_module(module)
-    Database = module.Database
+    Database = relative_import_module_from_unified_datasets(dataset_name, 'database.py', 'Database')
     assert issubclass(Database, BaseDatabase)
     database = Database()
     assert isinstance(database, BaseDatabase)
-- 
GitLab