# critic_json.py
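"""Train an offline critic for a latent-action MultiWOZ dialogue system.

Loads a pretrained system model (and, when required, a response auto-encoder),
builds replay buffers from the MultiWOZ train/valid/test splits, trains the
critic via OfflineCritic, and writes test-set performance before and after
training to the experiment folder under sys_config_log_model/.

Usage:
    python critic_json.py --infile ../data/augpt/test-predictions.json
"""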
import time
import os
import sys
import random
sys.path.append('../')
import json
import torch as th
import pdb
from tqdm import trange
from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
from latent_dialog.corpora import NormMultiWozCorpus
from latent_dialog.models_task import *
from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
from latent_dialog.main import OfflineCritic
from latent_dialog.evaluators import MultiWozEvaluator
from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE
from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic
from argparse import ArgumentParser


def main(seed, pretrained_folder, pretrained_model_id, response_path):
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[START]', start_time, '='*30)

    # RL configuration
    env = 'gpu'
    exp_dir = os.path.join('sys_config_log_model', "/".join(response_path.split("/")[-2:-1]).replace(".json", ""), "critic-" + start_time)
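    # RL-finetuned checkpoints live in a subfolder of the SL experiment and are saved as
    # '<id>.model', while SL checkpoints are saved as '<id>-model'; pick the filename
    # separator and the location of config.json accordingly.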
if "rl" in pretrained_folder:
join_fmt = "."
config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json")
else:
join_fmt = "-"
config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json")
# create exp folder
if not os.path.exists(exp_dir):
os.makedirs(exp_dir)
critic_config = Pack(
config_path = config_path,
model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic
vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False
vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model",
actor_path = None,
critic_config_path = os.path.join(exp_dir, 'critic_config.json'),
critic_model_path = os.path.join(exp_dir, 'critic_model'),
saved_path = exp_dir,
# ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'),
# reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'),
record_path = exp_dir,
record_freq = 500,
sv_train_freq= 0, # TODO pay attention to main.py, cuz it is also controlled there
use_gpu = env == 'gpu',
nepoch = 1,
nepisode = 1000,
word_plas=True,
raw_response=True,
response_path=response_path,
fix_episode=True,
train_with_pseudotraj=False,
train_with_full_data=False,
reward_type="default", #default, turnPenalty, or infoGain
infoGain_threshhold = 0.2, # only if reward_type is infoGain
add_match_to_reward=False,
soft_success=False,
goal_to_critic=True,
add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
critic_kl_loss=False,
critic_kl_alpha=0.1,
critic_dropout=True, # use with regularize_critic=False
critic_dropout_rate=0.3, # only if dropout is true
critic_dropout_agg="min", #avg or min
critic_sample=False,
critic_transformer=False,
critic_actf="sigmoid", #relu or sigmoid or tanh or none
embed_z_for_critic=True, #for categorical action only, when word_plas=False
# critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default
critic_loss="mse", #mse or huber
critic_rl_lr = 0.01,
decay_critic_lr=False,
train_vae=False,
train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
train_vae_nepisode=0, # nepisode for the rest of VAE training
train_vae_nepisode_init=10000, # nepisode for first VAE training
weighted_vae_nll=True,
rl_lr = 0.01,
max_words = 50,
temperature=1.0,
momentum = 0.0,
nesterov = False,
gamma = 0.99,
tau=0.005,
lmbda=0.5,
beta=0.001,
batch_size=16,
rl_clip = 1.0,
n_z=10,
random_seed = seed,
policy_dropout=False,
dropout_on_eval=False,
fail_info_penalty=False
)
prepare_dirs_loggers(critic_config)
# list config keys that are being compared for tensorboard naming
tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"]
tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys])
# load previous supervised learning configuration and corpus
config = Pack(json.load(open(critic_config.config_path)))
config['dropout'] = 0.0
config['use_gpu'] = critic_config.use_gpu
config['policy_dropout'] = critic_config.policy_dropout
config['dropout_on_eval'] = critic_config.dropout_on_eval
# assert config.train_path == critic_config.train_path
# set random seed
if critic_config.random_seed is None:
try:
critic_config.random_seed = config.seed
except:
critic_config.random_seed = config.random_seed
set_seed(critic_config.random_seed)
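    # Build the MultiWOZ corpus; if the data paths stored in the SL config are absolute
    # paths from the original training machine, strip the home-directory prefix and retry.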
    try:
        corpus = NormMultiWozCorpus(config)
    except FileNotFoundError:
        config['train_path'] = config.train_path.replace("/home/lubis", "")
        config['valid_path'] = config.valid_path.replace("/home/lubis", "")
        config['test_path'] = config.test_path.replace("/home/lubis", "")
        corpus = NormMultiWozCorpus(config)

    critic_config['train_path'] = config['train_path']
    critic_config['valid_path'] = config['valid_path']
    critic_config['test_path'] = config['test_path']
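    # Replay-buffer (.dill) filenames encode the reward settings (reward type, fixed
    # episodes, soft success, match bonus) so buffers built with different options
    # do not overwrite each other.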
    if critic_config.reward_type == "default":
        critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill")
        critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill")
        critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill")
    else:
        critic_config['train_memory_path'] = config['train_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")
        critic_config['valid_memory_path'] = config['valid_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")
        critic_config['test_memory_path'] = config['test_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")

    if critic_config.fix_episode:
        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill")
        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill")
        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill")

    if critic_config.soft_success:
        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-soft.dill")
        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-soft.dill")
        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-soft.dill")

    if critic_config.add_match_to_reward:
        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-wMatch.dill")
        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-wMatch.dill")
        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-wMatch.dill")

    critic_config['y_size'] = config['y_size']

    # save configuration
    with open(critic_config.critic_config_path, 'w') as f:
        json.dump(critic_config, f, indent=4)
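    # Instantiate the system-model variant that matches the pretrained checkpoint,
    # inferred from the experiment folder name (rl / actz / mt, Gaussian vs. categorical latents).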
if "rl" in pretrained_folder:
if "gauss" in pretrained_folder:
sys_model = SysPerfectBD2Gauss(corpus, config)
else:
sys_model = SysPerfectBD2Cat(corpus, config)
else:
if "actz" in pretrained_folder:
if "gauss" in pretrained_folder:
sys_model = SysActZGauss(corpus, config)
else:
sys_model = SysActZCat(corpus, config)
elif "mt" in pretrained_folder:
if "gauss" in pretrained_folder:
sys_model = SysMTGauss(corpus, config)
else:
sys_model = SysMTCat(corpus, config)
else:
if "gauss" in pretrained_folder:
sys_model = SysPerfectBD2Gauss(corpus, config)
else:
sys_model = SysPerfectBD2Cat(corpus, config)
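    # A pretrained response auto-encoder is only needed when the critic consumes raw
    # word-level responses but acts in the latent space (raw_response=True, word_plas=False).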
    vae_config = Pack(json.load(open(critic_config.vae_config_path)))
    if critic_config.raw_response and not critic_config.word_plas:
        if "gauss" in critic_config.vae_model_path:
            vae_model = SysAEGauss(corpus, vae_config)
        else:
            vae_model = SysAECat(corpus, vae_config)
        vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage)
        vae_model.load_state_dict(vae_model_dict)
    else:
        vae_model = None

    if config.use_gpu:
        sys_model.cuda()
        if vae_model is not None:
            vae_model.cuda()

    mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
    sys_model.load_state_dict(mt_model_dict)
    sys_model.eval()
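    # The pretrained system model stays in eval mode; as noted for model_path above,
    # it is used to initialize the critic's encoder.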
    evaluator = MultiWozEvaluator('SysWoz', config)

    if critic_config.word_plas:
        agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
    else:
        agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model)
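    # OfflineCritic wraps the agent, corpus, and configs and drives critic training;
    # task_run_critic is the evaluation routine and task_generate is passed as the
    # VAE generation routine (vae_gen).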
    main = OfflineCritic(agent, corpus, config, critic_config, task_run_critic, name=tensorboard_name, vae_gen=task_generate)

    # save sys model
    # th.save(sys_model.state_dict(), critic_config.rl_model_path)
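    # Replay buffers are cached on disk: reuse the saved .dill file when it exists,
    # otherwise extract experiences from the corpus once and save them for later runs.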
    # initialize train buffer
    if os.path.isfile(critic_config.train_memory_path):
        print("Loading replay buffer for training from {}".format(critic_config.train_memory_path))
        agent.train_buffer.load(critic_config.train_memory_path)
        print(len(agent.train_buffer))
        if "train_with_full_data" in critic_config and critic_config.train_with_full_data:
            print(f"adding buffer {critic_config.valid_memory_path} to training buffer")
            agent.train_buffer.load_add(critic_config.valid_memory_path)
            print(len(agent.train_buffer))
            print(f"adding buffer {critic_config.test_memory_path} to training buffer")
            agent.train_buffer.load_add(critic_config.test_memory_path)
            print(len(agent.train_buffer))
    else:
        print("Extracting experiences from training data")
        main.extract(main.train_data, main.agent.train_buffer)
        print("Saving experiences to {}".format(critic_config.train_memory_path))
        agent.train_buffer.save(critic_config.train_memory_path)

    # initialize valid buffer
    if os.path.isfile(critic_config.valid_memory_path):
        print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path))
        agent.valid_buffer.load(critic_config.valid_memory_path)
    else:
        print("Extracting experiences from valid data")
        main.extract(main.val_data, main.agent.valid_buffer)
        print("Saving experiences to {}".format(critic_config.valid_memory_path))
        agent.valid_buffer.save(critic_config.valid_memory_path)

    # initialize test buffer
    if os.path.isfile(critic_config.test_memory_path):
        print("Loading replay buffer for test from {}".format(critic_config.test_memory_path))
        agent.test_buffer.load(critic_config.test_memory_path)
    else:
        print("Extracting experiences from test data")
        main.extract(main.test_data, main.agent.test_buffer)
        print("Saving experiences to {}".format(critic_config.test_memory_path))
        agent.test_buffer.save(critic_config.test_memory_path)

    if critic_config.use_gpu:
        agent.critic.cuda()
        agent.critic_target.cuda()

    # check system performance
    train_dial, val_dial, test_dial = corpus.get_corpus()
    train_data = BeliefDbDataLoaders('Train', train_dial, vae_config)
    val_data = BeliefDbDataLoaders('Val', val_dial, vae_config)
    test_data = BeliefDbDataLoaders('Test', test_dial, vae_config)
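    # Evaluate the critic on the test set before and after training; results are written
    # next to the experiment artifacts in exp_dir.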
    with open(exp_dir + "/test_performance_start.txt", "w") as f:
        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)

    # train the critic
    main.run()

    with open(exp_dir + "/test_performance_end.txt", "w") as f:
        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '='*30)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--infile", type=str, default="../data/augpt/test-predictions.json")
    args = parser.parse_args()

    # pick the encoder checkpoint matching the model that produced the responses in --infile
    if "hdsa" in args.infile or "HDSA" in args.infile:
        # MWOZ 2.0
        folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48"
        id_ = "reward_best"
    elif "augpt" in args.infile:
        # MWOZ 2.1
        folder = "2021-11-25-11-52-47-mt_gauss"
        id_ = "29"
    else:
        # guard: folder and id_ would otherwise be undefined below
        raise ValueError("Unrecognized --infile; path should contain 'hdsa'/'HDSA' or 'augpt'")
    main(None, folder, id_, args.infile)