Skip to content
Snippets Groups Projects
Commit bcf3e287 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Merge branch 'fixes_merge' into 'github_master'

Update convlab/policy/ppo/train.py, convlab/policy/vector/dataset.py,...

See merge request dsml/convlab/ConvLab3!45
parents 46939fb3 3addcaa1
Branches
No related tags found
No related merge requests found
......@@ -199,7 +199,7 @@ if __name__ == '__main__':
logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
args = [('model', 'seed', seed)] if seed else list()
args = [('model', 'seed', seed)] if seed is not None else list()
environment_config = load_config_file(path)
save_config(vars(parser.parse_args()), environment_config, config_save_path)
......@@ -228,6 +228,7 @@ if __name__ == '__main__':
env, sess = env_config(conf, policy_sys)
policy_sys.current_time = current_time
policy_sys.log_dir = config_save_path.replace('configs', 'logs')
policy_sys.save_dir = save_path
......
......@@ -18,6 +18,26 @@ class ActDataset(data.Dataset):
return self.num_total
class ActDatasetKG(data.Dataset):
def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch):
self.action_batch = action_batch
self.a_masks = a_masks
self.current_domain_mask_batch = current_domain_mask_batch
self.non_current_domain_mask_batch = non_current_domain_mask_batch
self.num_total = len(action_batch)
def __getitem__(self, index):
action = self.action_batch[index]
action_mask = self.a_masks[index]
current_domain_mask = self.current_domain_mask_batch[index]
non_current_domain_mask = self.non_current_domain_mask_batch[index]
return action, action_mask, current_domain_mask, non_current_domain_mask, index
def __len__(self):
return self.num_total
class ActStateDataset(data.Dataset):
def __init__(self, s_s, a_s, next_s):
self.s_s = s_s
......
......@@ -2,10 +2,11 @@
import os
import sys
import numpy as np
import logging
from copy import deepcopy
from convlab.policy.vec import Vector
from convlab.util.custom_util import flatten_acts
from convlab.util.custom_util import flatten_acts, timeout
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
from convlab.util import load_ontology, load_database, load_dataset
......@@ -22,10 +23,12 @@ class VectorBase(Vector):
super().__init__()
logging.info(f"Vectorizer: Data set used is {dataset_name}")
self.set_seed(seed)
self.ontology = load_ontology(dataset_name)
try:
# execute to make sure that the database exists or is downloaded otherwise
if dataset_name == "multiwoz21":
load_database(dataset_name)
# the following two lines are needed for pickling correctly during multi-processing
exec(f'from data.unified_datasets.{dataset_name}.database import Database')
......@@ -33,7 +36,7 @@ class VectorBase(Vector):
self.db_domains = self.db.domains
except Exception as e:
self.db = None
self.db_domains = None
self.db_domains = []
print(f"VectorBase: {e}")
self.dataset_name = dataset_name
......@@ -272,6 +275,8 @@ class VectorBase(Vector):
2. If there is an entity available, can not say NoOffer or NoBook
'''
mask_list = np.zeros(self.da_dim)
if number_entities_dict is None:
return mask_list
for i in range(self.da_dim):
action = self.vec2act[i]
domain, intent, slot, value = action.split('-')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment