From 1bb7073ea291e4bdbc40713cc557abce4875b9d4 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Mon, 27 Jun 2022 10:00:24 +0800 Subject: [PATCH] change 'from data.unified_datasets.xxx import xxx' to using importlib and relative path, do not regard ConvLab-3/data as a package --- convlab/dst/rule/multiwoz/usr_dst.py | 8 +++++++- convlab/evaluator/multiwoz_eval.py | 9 ++++++++- convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py | 10 ++++++++-- convlab/policy/tus/multiwoz/TUS.py | 9 ++++++++- convlab/policy/vector/vector_base.py | 5 ++--- convlab/util/multiwoz/lexicalize.py | 8 +++++++- 6 files changed, 40 insertions(+), 9 deletions(-) diff --git a/convlab/dst/rule/multiwoz/usr_dst.py b/convlab/dst/rule/multiwoz/usr_dst.py index 962582b0..1de65476 100755 --- a/convlab/dst/rule/multiwoz/usr_dst.py +++ b/convlab/dst/rule/multiwoz/usr_dst.py @@ -6,12 +6,18 @@ from convlab.dst.rule.multiwoz.dst_util import normalize_value from convlab.dst.rule.multiwoz import RuleDST from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple +import importlib from pprint import pprint from copy import deepcopy from convlab.util import load_ontology +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da + SLOT2SEMI = { "arriveby": "arriveBy", "leaveat": "leaveAt", diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index e32e2fe9..115f39e6 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -7,7 +7,14 @@ from copy import deepcopy from convlab.evaluator.evaluator import Evaluator from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple from convlab.util.multiwoz.dbquery import Database -from data.unified_datasets.multiwoz21.preprocess import reverse_da +import os +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da requestable = \ {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'], diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 0919f0bd..432f8eb2 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -16,8 +16,14 @@ import logging from convlab.policy.policy import Policy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da - +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da +normalize_domain_slot_value = module.normalize_domain_slot_value def unified_format(acts): new_acts = {'categorical': []} diff --git a/convlab/policy/tus/multiwoz/TUS.py b/convlab/policy/tus/multiwoz/TUS.py index b9f7c5cc..0fef5f0f 100644 --- a/convlab/policy/tus/multiwoz/TUS.py +++ b/convlab/policy/tus/multiwoz/TUS.py @@ -15,8 +15,15 @@ from convlab.policy.policy import Policy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA from convlab.util.custom_util import model_downloader -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da +normalize_domain_slot_value = module.normalize_domain_slot_value DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 72970ba1..a5d1b382 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -4,7 +4,6 @@ import sys import numpy as np import logging -from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy from convlab.policy.vec import Vector from convlab.util.custom_util import flatten_acts @@ -27,8 +26,8 @@ class VectorBase(Vector): self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: - #self.db = load_database(dataset_name) - self.db = Database() + self.db = load_database(dataset_name) + # self.db = Database() self.db_domains = self.db.domains except Exception as e: self.db = None diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index 7afced06..4c30014d 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -1,8 +1,14 @@ from copy import deepcopy from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map +import os +import importlib +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da_slot_name_map = module.reverse_da_slot_name_map def delexicalize_da(da, requestable): delexicalized_da = [] -- GitLab