diff --git a/convlab/dst/rule/multiwoz/usr_dst.py b/convlab/dst/rule/multiwoz/usr_dst.py index 962582b0c3b9886c7cb581efa28bd3acd8915bb3..1de65476b6f0a6a0f419d733af867717c4c1729a 100755 --- a/convlab/dst/rule/multiwoz/usr_dst.py +++ b/convlab/dst/rule/multiwoz/usr_dst.py @@ -6,12 +6,18 @@ from convlab.dst.rule.multiwoz.dst_util import normalize_value from convlab.dst.rule.multiwoz import RuleDST from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple +import importlib from pprint import pprint from copy import deepcopy from convlab.util import load_ontology +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da + SLOT2SEMI = { "arriveby": "arriveBy", "leaveat": "leaveAt", diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index e32e2fe9f483476e1658540ff73502bce53986cc..115f39e64958da727d51ea91431bb8eaa6ef5a5d 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -7,7 +7,14 @@ from copy import deepcopy from convlab.evaluator.evaluator import Evaluator from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple from convlab.util.multiwoz.dbquery import Database -from data.unified_datasets.multiwoz21.preprocess import reverse_da +import os +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da requestable = \ {'attraction': ['post', 'phone', 'addr', 'fee', 'area', 'type'], diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 0919f0bd18b83120a4c572a72c05b3054a479600..432f8eb214c6384cd8cec439abfb8d4dc4ec1e7c 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -16,8 +16,14 @@ import logging from convlab.policy.policy import Policy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da - +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da +normalize_domain_slot_value = module.normalize_domain_slot_value def unified_format(acts): new_acts = {'categorical': []} diff --git a/convlab/policy/tus/multiwoz/TUS.py b/convlab/policy/tus/multiwoz/TUS.py index b9f7c5cc2c769261e0a0388a818919e7f29cdf3f..0fef5f0f09a65caffdecd24214c6da9cd94c2c98 100644 --- a/convlab/policy/tus/multiwoz/TUS.py +++ b/convlab/policy/tus/multiwoz/TUS.py @@ -15,8 +15,15 @@ from convlab.policy.policy import Policy from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA from convlab.util.custom_util import model_downloader -from data.unified_datasets.multiwoz21.preprocess import normalize_domain_slot_value, reverse_da from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple +import importlib + +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da = module.reverse_da +normalize_domain_slot_value = module.normalize_domain_slot_value DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 72970ba1d46b293666477550ac33c5497b7c6d6a..a5d1b382262757e508f7e0e44b93543d07d846b5 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -4,7 +4,6 @@ import sys import numpy as np import logging -from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy from convlab.policy.vec import Vector from convlab.util.custom_util import flatten_acts @@ -27,8 +26,8 @@ class VectorBase(Vector): self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: - #self.db = load_database(dataset_name) - self.db = Database() + self.db = load_database(dataset_name) + # self.db = Database() self.db_domains = self.db.domains except Exception as e: self.db = None diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index 7afced06dd35253ea998feaca97fa3b9872f903d..4c30014dbc936978d356df40ff39a9f67c0eec0a 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -1,8 +1,14 @@ from copy import deepcopy from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA -from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map +import os +import importlib +module_spec = importlib.util.spec_from_file_location('preprocess', \ + os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../../../data/unified_datasets/multiwoz21/preprocess.py'))) +module = importlib.util.module_from_spec(module_spec) +module_spec.loader.exec_module(module) +reverse_da_slot_name_map = module.reverse_da_slot_name_map def delexicalize_da(da, requestable): delexicalized_da = []