diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py index f376bde76b59db47409d6ad7c5a55425a2d52e4e..703a55005b8c07578b85765a626d9871deebf26e 100755 --- a/convlab/policy/ppo/train.py +++ b/convlab/policy/ppo/train.py @@ -199,7 +199,7 @@ if __name__ == '__main__': logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ init_logging(os.path.dirname(os.path.abspath(__file__)), mode) - args = [('model', 'seed', seed)] if seed else list() + args = [('model', 'seed', seed)] if seed is not None else list() environment_config = load_config_file(path) save_config(vars(parser.parse_args()), environment_config, config_save_path) @@ -228,6 +228,7 @@ if __name__ == '__main__': env, sess = env_config(conf, policy_sys) + policy_sys.current_time = current_time policy_sys.log_dir = config_save_path.replace('configs', 'logs') policy_sys.save_dir = save_path diff --git a/convlab/policy/vector/dataset.py b/convlab/policy/vector/dataset.py index 0aa1b7ad879f2d814cece3a98bb457b71ad99033..5b233e6659abc18da69a3efe6bb2d52185aa30fd 100755 --- a/convlab/policy/vector/dataset.py +++ b/convlab/policy/vector/dataset.py @@ -18,6 +18,26 @@ class ActDataset(data.Dataset): return self.num_total +class ActDatasetKG(data.Dataset): + def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch): + self.action_batch = action_batch + self.a_masks = a_masks + self.current_domain_mask_batch = current_domain_mask_batch + self.non_current_domain_mask_batch = non_current_domain_mask_batch + self.num_total = len(action_batch) + + def __getitem__(self, index): + action = self.action_batch[index] + action_mask = self.a_masks[index] + current_domain_mask = self.current_domain_mask_batch[index] + non_current_domain_mask = self.non_current_domain_mask_batch[index] + + return action, action_mask, current_domain_mask, non_current_domain_mask, index + + def __len__(self): + return self.num_total + + class ActStateDataset(data.Dataset): def __init__(self, s_s, a_s, next_s): self.s_s = s_s @@ -32,4 +52,4 @@ class ActStateDataset(data.Dataset): return s, a, next_s def __len__(self): - return self.num_total \ No newline at end of file + return self.num_total diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 8b7d8ff0ddafed41efc91b249003ae55c525bc93..8f72144ce37a970fe4855a19d5bc8002fc2b4034 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -2,10 +2,11 @@ import os import sys import numpy as np +import logging from copy import deepcopy from convlab.policy.vec import Vector -from convlab.util.custom_util import flatten_acts +from convlab.util.custom_util import flatten_acts, timeout from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab.util import load_ontology, load_database, load_dataset @@ -22,18 +23,20 @@ class VectorBase(Vector): super().__init__() + logging.info(f"Vectorizer: Data set used is {dataset_name}") self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: # execute to make sure that the database exists or is downloaded otherwise - load_database(dataset_name) + if dataset_name == "multiwoz21": + load_database(dataset_name) # the following two lines are needed for pickling correctly during multi-processing exec(f'from data.unified_datasets.{dataset_name}.database import Database') self.db = eval('Database()') self.db_domains = self.db.domains except Exception as e: self.db = None - self.db_domains = None + self.db_domains = [] print(f"VectorBase: {e}") self.dataset_name = dataset_name @@ -272,6 +275,8 @@ class VectorBase(Vector): 2. If there is an entity available, can not say NoOffer or NoBook ''' mask_list = np.zeros(self.da_dim) + if number_entities_dict is None: + return mask_list for i in range(self.da_dim): action = self.vec2act[i] domain, intent, slot, value = action.split('-')