Skip to content
Snippets Groups Projects
Commit 3addcaa1 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Update convlab/policy/ppo/train.py, convlab/policy/vector/dataset.py,...

parent 46939fb3
Branches
No related tags found
No related merge requests found
...@@ -199,7 +199,7 @@ if __name__ == '__main__': ...@@ -199,7 +199,7 @@ if __name__ == '__main__':
logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
init_logging(os.path.dirname(os.path.abspath(__file__)), mode) init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
args = [('model', 'seed', seed)] if seed else list() args = [('model', 'seed', seed)] if seed is not None else list()
environment_config = load_config_file(path) environment_config = load_config_file(path)
save_config(vars(parser.parse_args()), environment_config, config_save_path) save_config(vars(parser.parse_args()), environment_config, config_save_path)
...@@ -228,6 +228,7 @@ if __name__ == '__main__': ...@@ -228,6 +228,7 @@ if __name__ == '__main__':
env, sess = env_config(conf, policy_sys) env, sess = env_config(conf, policy_sys)
policy_sys.current_time = current_time policy_sys.current_time = current_time
policy_sys.log_dir = config_save_path.replace('configs', 'logs') policy_sys.log_dir = config_save_path.replace('configs', 'logs')
policy_sys.save_dir = save_path policy_sys.save_dir = save_path
......
...@@ -18,6 +18,26 @@ class ActDataset(data.Dataset): ...@@ -18,6 +18,26 @@ class ActDataset(data.Dataset):
return self.num_total return self.num_total
class ActDatasetKG(data.Dataset):
def __init__(self, action_batch, a_masks, current_domain_mask_batch, non_current_domain_mask_batch):
self.action_batch = action_batch
self.a_masks = a_masks
self.current_domain_mask_batch = current_domain_mask_batch
self.non_current_domain_mask_batch = non_current_domain_mask_batch
self.num_total = len(action_batch)
def __getitem__(self, index):
action = self.action_batch[index]
action_mask = self.a_masks[index]
current_domain_mask = self.current_domain_mask_batch[index]
non_current_domain_mask = self.non_current_domain_mask_batch[index]
return action, action_mask, current_domain_mask, non_current_domain_mask, index
def __len__(self):
return self.num_total
class ActStateDataset(data.Dataset): class ActStateDataset(data.Dataset):
def __init__(self, s_s, a_s, next_s): def __init__(self, s_s, a_s, next_s):
self.s_s = s_s self.s_s = s_s
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
import os import os
import sys import sys
import numpy as np import numpy as np
import logging
from copy import deepcopy from copy import deepcopy
from convlab.policy.vec import Vector from convlab.policy.vec import Vector
from convlab.util.custom_util import flatten_acts from convlab.util.custom_util import flatten_acts, timeout
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
from convlab.util import load_ontology, load_database, load_dataset from convlab.util import load_ontology, load_database, load_dataset
...@@ -22,10 +23,12 @@ class VectorBase(Vector): ...@@ -22,10 +23,12 @@ class VectorBase(Vector):
super().__init__() super().__init__()
logging.info(f"Vectorizer: Data set used is {dataset_name}")
self.set_seed(seed) self.set_seed(seed)
self.ontology = load_ontology(dataset_name) self.ontology = load_ontology(dataset_name)
try: try:
# execute to make sure that the database exists or is downloaded otherwise # execute to make sure that the database exists or is downloaded otherwise
if dataset_name == "multiwoz21":
load_database(dataset_name) load_database(dataset_name)
# the following two lines are needed for pickling correctly during multi-processing # the following two lines are needed for pickling correctly during multi-processing
exec(f'from data.unified_datasets.{dataset_name}.database import Database') exec(f'from data.unified_datasets.{dataset_name}.database import Database')
...@@ -33,7 +36,7 @@ class VectorBase(Vector): ...@@ -33,7 +36,7 @@ class VectorBase(Vector):
self.db_domains = self.db.domains self.db_domains = self.db.domains
except Exception as e: except Exception as e:
self.db = None self.db = None
self.db_domains = None self.db_domains = []
print(f"VectorBase: {e}") print(f"VectorBase: {e}")
self.dataset_name = dataset_name self.dataset_name = dataset_name
...@@ -272,6 +275,8 @@ class VectorBase(Vector): ...@@ -272,6 +275,8 @@ class VectorBase(Vector):
2. If there is an entity available, can not say NoOffer or NoBook 2. If there is an entity available, can not say NoOffer or NoBook
''' '''
mask_list = np.zeros(self.da_dim) mask_list = np.zeros(self.da_dim)
if number_entities_dict is None:
return mask_list
for i in range(self.da_dim): for i in range(self.da_dim):
action = self.vec2act[i] action = self.vec2act[i]
domain, intent, slot, value = action.split('-') domain, intent, slot, value = action.split('-')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment