diff --git a/data.zip b/data.zip new file mode 100644 index 0000000000000000000000000000000000000000..a0d264ee2b4380d22755f4b05ff749d56b56687e Binary files /dev/null and b/data.zip differ diff --git a/experiments_woz/__init__.py b/experiments_woz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b38317c87521b9439cf7fa52ab450c5f7b77a513 --- /dev/null +++ b/experiments_woz/__init__.py @@ -0,0 +1,2 @@ +# @Time : 10/18/18 1:49 PM +# @Author : Tiancheng Zhao \ No newline at end of file diff --git a/experiments_woz/__pycache__/__init__.cpython-36.pyc b/experiments_woz/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d79281e874bb19e4c5f22fe4b733796b45130d Binary files /dev/null and b/experiments_woz/__pycache__/__init__.cpython-36.pyc differ diff --git a/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc b/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d25cfe50e02ef84e0673d40529caddbf430334 Binary files /dev/null and b/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc differ diff --git a/experiments_woz/critic.py b/experiments_woz/critic.py new file mode 100644 index 0000000000000000000000000000000000000000..0438bea2aa59cd13a35ef79f8e9e874e58bcc8c0 --- /dev/null +++ b/experiments_woz/critic.py @@ -0,0 +1,262 @@ +import time +import os +import sys +import random +sys.path.append('../') +import json +import torch as th +import pdb +from tqdm import trange +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +from latent_dialog.corpora import NormMultiWozCorpus +from latent_dialog.models_task import * +from latent_dialog.agent_task import LatentCriticAgent, CriticAgent +from latent_dialog.main import OfflineCritic +from latent_dialog.evaluators import MultiWozEvaluator +from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic + + +def main(seed, pretrained_folder, pretrained_model_id): + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[START]', start_time, '='*30) + # RL configuration + env = 'gpu' + + exp_dir = os.path.join('sys_config_log_model', pretrained_folder, "critic-"+start_time) + if "rl" in pretrained_folder: + join_fmt = ".model" + config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json") + else: + join_fmt = "-model" + config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json") + + # create exp folder + if not os.path.exists(exp_dir): + os.mkdir(exp_dir) + + critic_config = Pack( + config_path = config_path, + model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}'.format(pretrained_model_id, join_fmt)), + critic_config_path = os.path.join(exp_dir, 'critic_config.json'), + critic_model_path = os.path.join(exp_dir, 'critic_model'), + saved_path = exp_dir, + record_path = exp_dir, + record_freq = 500, + sv_train_freq= 0, # TODO pay attention to main.py, cuz it is also controlled there + use_gpu = env == 'gpu', + nepoch = 1, + nepisode = 10000, + word_plas=True, + fix_episode=True, + train_with_full_data=True, + reward_type="default", #default, turnPenalty, or infoGain + infoGain_threshhold = 0.2, # only if reward_type is infoGain + goal_to_critic=True, + add_goal="early", #early or late. 
only when goal_to_critic=True and fix_episode=True + critic_kl_loss=False, + critic_kl_alpha=0.1, + critic_dropout=True, # use with regularize_critic=False + critic_dropout_rate=0.3, # only if dropout is true + critic_dropout_agg="min", #avg or min + critic_sample=False, + critic_transformer=False, + critic_actf="sigmoid", #relu or sigmoid or tanh or none + embed_z_for_critic=True, #for categorical action only, when word_plas=False + critic_loss="mse", #mse or huber + critic_rl_lr = 0.01, + decay_critic_lr=False, + train_vae=False, + train_vae_freq=10001, # if larger than nepisode. only train once at the beginning + train_vae_nepisode=0, # nepisode for the rest of VAE training + train_vae_nepisode_init=10000, # nepisode for first VAE training + weighted_vae_nll=True, + rl_lr = 0.01, + max_words = 50, + temperature=1.0, + momentum = 0.0, + nesterov = False, + gamma = 0.99, + tau=0.005, + lmbda=0.5, + beta=0.001, + batch_size=16, + rl_clip = 1.0, + n_z=10, + random_seed = seed, + policy_dropout=False, + dropout_on_eval=False, + fail_info_penalty=False + ) + + if "plas" in pretrained_folder: + critic_config['actor_path'] = critic_config.model_path.replace(".model", ".actor") + critic_config['actor_config'] = critic_config.model_path.replace("reward_best.model", "rl_config.json") + else: + critic_config['actor_path'] = None + critic_config['actor_config'] = None + + prepare_dirs_loggers(critic_config) + + # list config keys that are being compared for tensorboard naming + tb_keys = ["critic_rl_lr", "reward_type", "word_plas", "train_with_full_data"] + tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys]) + + # load previous supervised learning configuration and corpus + config = Pack(json.load(open(critic_config.config_path))) + config['dropout'] = 0.0 + config['use_gpu'] = critic_config.use_gpu + config['policy_dropout'] = critic_config.policy_dropout + config['dropout_on_eval'] = critic_config.dropout_on_eval + # assert config.train_path == critic_config.train_path + + # set random seed + if critic_config.random_seed is None: + try: + critic_config.random_seed = config.seed + except: + critic_config.random_seed = config.random_seed + set_seed(critic_config.random_seed) + + + try: + corpus = NormMultiWozCorpus(config) + except FileNotFoundError: + config['train_path'] = config.train_path.replace("/home/lubis", "") + config['valid_path'] = config.valid_path.replace("/home/lubis", "") + config['test_path'] = config.test_path.replace("/home/lubis", "") + corpus = NormMultiWozCorpus(config) + except PermissionError: + config['train_path'] = config.train_path.replace("/root", "..") + config['valid_path'] = config.valid_path.replace("/root", "..") + config['test_path'] = config.test_path.replace("/root", "..") + corpus = NormMultiWozCorpus(config) + + critic_config['train_path'] = config['train_path'] + critic_config['valid_path'] = config['valid_path'] + critic_config['test_path'] = config['test_path'] + + critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill") + critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill") + critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill") + + if critic_config.fix_episode: + critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill") + critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill") + 
critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill") + + + critic_config['y_size'] = config['y_size'] + + + # save configuration + with open(critic_config.critic_config_path, 'w') as f: + json.dump(critic_config, f, indent=4) + + if "rl" in pretrained_folder: + if "gauss" in pretrained_folder: + if "plas" in pretrained_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + else: + if "actz" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysActZGauss(corpus, config) + else: + sys_model = SysActZCat(corpus, config) + elif "mt" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysMTCat(corpus, config) + else: + if "gauss" in pretrained_folder: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + + if config.use_gpu: + sys_model.cuda() + + mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage) + sys_model.load_state_dict(mt_model_dict) + + sys_model.eval() + evaluator = MultiWozEvaluator('SysWoz', config) + + if critic_config.word_plas: + agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + else: + agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + + main = OfflineCritic(agent, corpus, config, critic_config, task_generate_critic, name=tensorboard_name, vae_gen=task_generate) + # save sys model + # th.save(sys_model.state_dict(), critic_config.rl_model_path) + + # initialize train buffer + if os.path.isfile(critic_config.train_memory_path): + # if False: + print("Loading replay buffer for training from {}".format(critic_config.train_memory_path)) + agent.train_buffer.load(critic_config.train_memory_path) + print(len(agent.train_buffer)) + if "train_with_full_data" in critic_config and critic_config.train_with_full_data: + print(f"adding buffer {critic_config.valid_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.valid_memory_path) + print(len(agent.train_buffer)) + print(f"adding buffer {critic_config.test_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.test_memory_path) + print(len(agent.train_buffer)) + + else: + print("Extracting experiences from training data") + main.extract(main.train_data, main.agent.train_buffer) + print("Saving experiences to {}".format(critic_config.train_memory_path)) + agent.train_buffer.save(critic_config.train_memory_path) + + # initialize valid buffer + if os.path.isfile(critic_config.valid_memory_path): + print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path)) + agent.valid_buffer.load(critic_config.valid_memory_path) + else: + print("Extracting experiences from valid data") + main.extract(main.val_data, main.agent.valid_buffer) + print("Saving experiences to {}".format(critic_config.valid_memory_path)) + agent.valid_buffer.save(critic_config.valid_memory_path) + + # initialize test buffer + if os.path.isfile(critic_config.test_memory_path): + print("Loading replay buffer for test from {}".format(critic_config.test_memory_path)) + agent.test_buffer.load(critic_config.test_memory_path) + else: + print("Extracting experiences from test data") + main.extract(main.test_data, main.agent.test_buffer) + print("Saving experiences to 
{}".format(critic_config.test_memory_path)) + agent.test_buffer.save(critic_config.test_memory_path) + + if critic_config.use_gpu: + agent.critic.cuda() + agent.critic_target.cuda() + + # with open(exp_dir + "/test_performance_start.txt", "w") as f: + # task_run_critic(main.test_data, agent, None, evaluator=evaluator, outfile=f) + + #train critic + main.run() + + # with open(exp_dir + "/test_performance_end.txt", "w") as f: + # task_run_critic(main.test_data, agent, None, evaluator=evaluator, outfile=f) + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '='*30) + + +if __name__ == '__main__' : + # fixed target actor, i.e. model that the critic wants to evaluate + # for non-LAVA based model, use critic_json.py + folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48" + id_ = "reward_best" + main(None, folder, id_) + + diff --git a/experiments_woz/critic_json.py b/experiments_woz/critic_json.py new file mode 100644 index 0000000000000000000000000000000000000000..ca94e953065649b6cb302a8c21cf5f3561237714 --- /dev/null +++ b/experiments_woz/critic_json.py @@ -0,0 +1,305 @@ +import time +import os +import sys +import random +sys.path.append('../') +import json +import torch as th +import pdb +from tqdm import trange +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +from latent_dialog.corpora import NormMultiWozCorpus +from latent_dialog.models_task import * +from latent_dialog.agent_task import LatentCriticAgent, CriticAgent +from latent_dialog.main import OfflineCritic +from latent_dialog.evaluators import MultiWozEvaluator +from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE +from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic +from argparse import ArgumentParser + + +def main(seed, pretrained_folder, pretrained_model_id, response_path): + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[START]', start_time, '='*30) + # RL configuration + env = 'gpu' + + exp_dir = os.path.join('sys_config_log_model', "/".join(response_path.split("/")[-2:-1]).replace(".json", ""), "critic-"+start_time) + + if "rl" in pretrained_folder: + join_fmt = "." 
+ config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json") + else: + join_fmt = "-" + config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json") + + # create exp folder + if not os.path.exists(exp_dir): + os.makedirs(exp_dir) + + critic_config = Pack( + config_path = config_path, + model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic + vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False + vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model", + actor_path = None, + critic_config_path = os.path.join(exp_dir, 'critic_config.json'), + critic_model_path = os.path.join(exp_dir, 'critic_model'), + saved_path = exp_dir, + # ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'), + # reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'), + record_path = exp_dir, + record_freq = 500, + sv_train_freq= 0, # TODO pay attention to main.py, cuz it is also controlled there + use_gpu = env == 'gpu', + nepoch = 1, + nepisode = 1000, + word_plas=True, + raw_response=True, + response_path=response_path, + fix_episode=True, + train_with_pseudotraj=False, + train_with_full_data=False, + reward_type="default", #default, turnPenalty, or infoGain + infoGain_threshhold = 0.2, # only if reward_type is infoGain + add_match_to_reward=False, + soft_success=False, + goal_to_critic=True, + add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True + critic_kl_loss=False, + critic_kl_alpha=0.1, + critic_dropout=True, # use with regularize_critic=False + critic_dropout_rate=0.3, # only if dropout is true + critic_dropout_agg="min", #avg or min + critic_sample=False, + critic_transformer=False, + critic_actf="sigmoid", #relu or sigmoid or tanh or none + embed_z_for_critic=True, #for categorical action only, when word_plas=False + # critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default + critic_loss="mse", #mse or huber + critic_rl_lr = 0.01, + decay_critic_lr=False, + train_vae=False, + train_vae_freq=10001, # if larger than nepisode. 
only train once at the beginning + train_vae_nepisode=0, # nepisode for the rest of VAE training + train_vae_nepisode_init=10000, # nepisode for first VAE training + weighted_vae_nll=True, + rl_lr = 0.01, + max_words = 50, + temperature=1.0, + momentum = 0.0, + nesterov = False, + gamma = 0.99, + tau=0.005, + lmbda=0.5, + beta=0.001, + batch_size=16, + rl_clip = 1.0, + n_z=10, + random_seed = seed, + policy_dropout=False, + dropout_on_eval=False, + fail_info_penalty=False + ) + + prepare_dirs_loggers(critic_config) + + # list config keys that are being compared for tensorboard naming + tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"] + tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys]) + + # load previous supervised learning configuration and corpus + config = Pack(json.load(open(critic_config.config_path))) + config['dropout'] = 0.0 + config['use_gpu'] = critic_config.use_gpu + config['policy_dropout'] = critic_config.policy_dropout + config['dropout_on_eval'] = critic_config.dropout_on_eval + # assert config.train_path == critic_config.train_path + + # set random seed + if critic_config.random_seed is None: + try: + critic_config.random_seed = config.seed + except: + critic_config.random_seed = config.random_seed + set_seed(critic_config.random_seed) + + + try: + corpus = NormMultiWozCorpus(config) + except FileNotFoundError: + config['train_path'] = config.train_path.replace("/home/lubis", "") + config['valid_path'] = config.valid_path.replace("/home/lubis", "") + config['test_path'] = config.test_path.replace("/home/lubis", "") + corpus = NormMultiWozCorpus(config) + + critic_config['train_path'] = config['train_path'] + critic_config['valid_path'] = config['valid_path'] + critic_config['test_path'] = config['test_path'] + + if critic_config.reward_type == "default": + critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill") + critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill") + critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill") + else: + critic_config['train_memory_path'] = config['train_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill") + critic_config['valid_memory_path'] = config['valid_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill") + critic_config['test_memory_path'] = config['test_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill") + + if critic_config.fix_episode: + critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill") + critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill") + critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill") + + if critic_config.soft_success: + critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-soft.dill") + critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-soft.dill") + critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-soft.dill") + + if critic_config.add_match_to_reward: + critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-wMatch.dill") + critic_config['valid_memory_path'] = 
critic_config['valid_memory_path'].replace(".dill", "-wMatch.dill") + critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-wMatch.dill") + + + critic_config['y_size'] = config['y_size'] + + + # save configuration + with open(critic_config.critic_config_path, 'w') as f: + json.dump(critic_config, f, indent=4) + + if "rl" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + else: + if "actz" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysActZGauss(corpus, config) + else: + sys_model = SysActZCat(corpus, config) + elif "mt" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysMTCat(corpus, config) + else: + if "gauss" in pretrained_folder: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + + vae_config = Pack(json.load(open(critic_config.vae_config_path))) + if critic_config.raw_response and not critic_config.word_plas: + if "gauss" in critic_config.vae_model_path: + vae_model = SysAEGauss(corpus, vae_config) + else: + vae_model = SysAECat(corpus, vae_config) + vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage) + vae_model.load_state_dict(vae_model_dict) + else: + vae_model = None + + if config.use_gpu: + sys_model.cuda() + if vae_model is not None: + vae_model.cuda() + + + mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage) + sys_model.load_state_dict(mt_model_dict) + + sys_model.eval() + evaluator = MultiWozEvaluator('SysWoz', config) + + if critic_config.word_plas: + agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + else: + agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model) + + main = OfflineCritic(agent, corpus, config, critic_config, task_run_critic, name=tensorboard_name, vae_gen=task_generate) + # save sys model + # th.save(sys_model.state_dict(), critic_config.rl_model_path) + + # initialize train buffer + if os.path.isfile(critic_config.train_memory_path): + print("Loading replay buffer for training from {}".format(critic_config.train_memory_path)) + agent.train_buffer.load(critic_config.train_memory_path) + print(len(agent.train_buffer)) + if "train_with_full_data" in critic_config and critic_config.train_with_full_data: + print(f"adding buffer {critic_config.valid_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.valid_memory_path) + print(len(agent.train_buffer)) + print(f"adding buffer {critic_config.test_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.test_memory_path) + print(len(agent.train_buffer)) + + else: + print("Extracting experiences from training data") + main.extract(main.train_data, main.agent.train_buffer) + print("Saving experiences to {}".format(critic_config.train_memory_path)) + agent.train_buffer.save(critic_config.train_memory_path) + + # initialize valid buffer + if os.path.isfile(critic_config.valid_memory_path): + print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path)) + agent.valid_buffer.load(critic_config.valid_memory_path) + else: + print("Extracting experiences from valid data") + main.extract(main.val_data, main.agent.valid_buffer) + print("Saving experiences to 
{}".format(critic_config.valid_memory_path)) + agent.valid_buffer.save(critic_config.valid_memory_path) + + # initialize test buffer + if os.path.isfile(critic_config.test_memory_path): + print("Loading replay buffer for test from {}".format(critic_config.test_memory_path)) + agent.test_buffer.load(critic_config.test_memory_path) + else: + print("Extracting experiences from test data") + main.extract(main.test_data, main.agent.test_buffer) + print("Saving experiences to {}".format(critic_config.test_memory_path)) + agent.test_buffer.save(critic_config.test_memory_path) + + if critic_config.use_gpu: + agent.critic.cuda() + agent.critic_target.cuda() + + #check system performance + train_dial, val_dial, test_dial = corpus.get_corpus() + train_data = BeliefDbDataLoaders('Train', train_dial, vae_config) + val_data = BeliefDbDataLoaders('Val', val_dial, vae_config) + test_data = BeliefDbDataLoaders('Test', test_dial, vae_config) + + with open(exp_dir + "/test_performance_start.txt", "w") as f: + task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f) + #train critic + main.run() + + with open(exp_dir + "/test_performance_end.txt", "w") as f: + task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f) + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '='*30) + + +if __name__ == '__main__' : + parser = ArgumentParser() + + parser.add_argument("--infile", type=str, default="../data/augpt/test-predictions.json") + args = parser.parse_args() + + # pick corresponding encoder + if "hdsa" in args.infile or "HDSA" in args.infile: + # MWOZ 2.0 + folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48" + id_ = "reward_best" + elif "augpt" in args.infile: + #MWOZ 2.1 + folder = "2021-11-25-11-52-47-mt_gauss" + id_ = "29" + + main(None, folder, id_, args.infile) + + diff --git a/experiments_woz/critic_mwoz.py b/experiments_woz/critic_mwoz.py new file mode 100644 index 0000000000000000000000000000000000000000..87bda4e117d93747e914a1d316fb5559b22ccd94 --- /dev/null +++ b/experiments_woz/critic_mwoz.py @@ -0,0 +1,272 @@ +import time +import os +import sys +import random +sys.path.append('../') +import json +import torch as th +import pdb +from tqdm import trange +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +from latent_dialog.corpora import NormMultiWozCorpus +from latent_dialog.models_task import * +from latent_dialog.agent_task import LatentCriticAgent, CriticAgent +from latent_dialog.main import OfflineCritic +from latent_dialog.evaluators import MultiWozEvaluator +from experiments_woz.dialog_utils import task_generate_critic, task_run_critic_on_behavior_policy + + +def main(seed, pretrained_folder, pretrained_model_id): + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[START]', start_time, '='*30) + # RL configuration + env = 'gpu' + + exp_dir = os.path.join('sys_config_log_model', "/".join(["mwoz", "critic-"+start_time])) + if "rl" in pretrained_folder: + join_fmt = "." 
+ config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json") + else: + join_fmt = "-" + config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json") + + # create exp folder + if not os.path.exists(exp_dir): + os.makedirs(exp_dir) + + critic_config = Pack( + config_path = config_path, + model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic + vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False + vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model", + actor_path = None, + + critic_config_path = os.path.join(exp_dir, 'critic_config.json'), + critic_model_path = os.path.join(exp_dir, 'critic_model'), + saved_path = exp_dir, + + # ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'), + # reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'), + record_path = exp_dir, + record_freq = 500, + sv_train_freq= 0, # TODO pay attention to main.py, cuz it is also controlled there + use_gpu = env == 'gpu', + nepoch = 10, + nepisode = 10000, + word_plas=True, + raw_response=False, + corpus_response=True, + # response_path=response_path, + fix_episode=True, + train_with_pseudotraj=False, + train_with_full_data=True, + reward_type="default", #default, turnPenalty, or infoGain + infoGain_threshhold = 0.2, # only if reward_type is infoGain + add_match_to_reward=False, + soft_success=False, + goal_to_critic=True, + add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True + critic_kl_loss=False, + critic_kl_alpha=0.1, + critic_dropout=True, # use with regularize_critic=False + critic_dropout_rate=0.3, # only if dropout is true + critic_dropout_agg="min", #avg or min + critic_sample=False, + critic_transformer=False, + critic_actf="sigmoid", #relu or sigmoid or tanh or none + embed_z_for_critic=True, #for categorical action only, when word_plas=False + # critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default + critic_loss="mse", #mse or huber + critic_rl_lr = 0.01, + decay_critic_lr=False, + train_vae=False, + train_vae_freq=10001, # if larger than nepisode. 
only train once at the beginning + train_vae_nepisode=0, # nepisode for the rest of VAE training + train_vae_nepisode_init=10000, # nepisode for first VAE training + weighted_vae_nll=True, + rl_lr = 0.01, + max_words = 50, + temperature=1.0, + momentum = 0.0, + nesterov = False, + gamma = 0.99, + tau=0.005, + lmbda=0.5, + beta=0.001, + batch_size=16, + rl_clip = 1.0, + n_z=10, + random_seed = seed, + policy_dropout=False, + dropout_on_eval=False, + fail_info_penalty=False + ) + + prepare_dirs_loggers(critic_config) + + # list config keys that are being compared for tensorboard naming + tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"] + tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys]) + + # load previous supervised learning configuration and corpus + config = Pack(json.load(open(critic_config.config_path))) + config['dropout'] = 0.0 + config['use_gpu'] = critic_config.use_gpu + config['policy_dropout'] = critic_config.policy_dropout + config['dropout_on_eval'] = critic_config.dropout_on_eval + # assert config.train_path == critic_config.train_path + + # set random seed + if critic_config.random_seed is None: + try: + critic_config.random_seed = config.seed + except: + critic_config.random_seed = config.random_seed + set_seed(critic_config.random_seed) + + + try: + corpus = NormMultiWozCorpus(config) + except FileNotFoundError: + config['train_path'] = config.train_path.replace("/home/lubis", "") + config['valid_path'] = config.valid_path.replace("/home/lubis", "") + config['test_path'] = config.test_path.replace("/home/lubis", "") + corpus = NormMultiWozCorpus(config) + + critic_config['train_path'] = config['train_path'] + critic_config['valid_path'] = config['valid_path'] + critic_config['test_path'] = config['test_path'] + + critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill") + critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill") + critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill") + + if critic_config.fix_episode: + critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill") + critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill") + critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill") + + critic_config['y_size'] = config['y_size'] + + # save configuration + with open(critic_config.critic_config_path, 'w') as f: + json.dump(critic_config, f, indent=4) + + if "rl" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + else: + if "actz" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysActZGauss(corpus, config) + else: + sys_model = SysActZCat(corpus, config) + elif "mt" in pretrained_folder: + if "gauss" in pretrained_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysMTCat(corpus, config) + else: + if "gauss" in pretrained_folder: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + + if critic_config.corpus_response and not critic_config.word_plas: + vae_config = Pack(json.load(open(critic_config.vae_config_path))) + if "gauss" in critic_config.vae_model_path: + vae_model = SysAEGauss(corpus, vae_config) + else: + vae_model = SysAECat(corpus, 
vae_config) + vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage) + vae_model.load_state_dict(vae_model_dict) + else: + vae_model = None + + if config.use_gpu: + sys_model.cuda() + if vae_model is not None: + vae_model.cuda() + + + mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage) + sys_model.load_state_dict(mt_model_dict) + + sys_model.eval() + evaluator = MultiWozEvaluator('SysWoz', config) + + if critic_config.word_plas: + agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + else: + agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model) + + main = OfflineCritic(agent, corpus, config, critic_config, task_run_critic_on_behavior_policy, name=tensorboard_name, vae_gen=None) + # save sys model + # th.save(sys_model.state_dict(), critic_config.rl_model_path) + + # initialize train buffer + if os.path.isfile(critic_config.train_memory_path): + print("Loading replay buffer for training from {}".format(critic_config.train_memory_path)) + agent.train_buffer.load(critic_config.train_memory_path) + if "train_with_full_data" in critic_config and critic_config.train_with_full_data: + print(f"adding buffer {critic_config.valid_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.valid_memory_path) + print(len(agent.train_buffer)) + print(f"adding buffer {critic_config.test_memory_path} to training buffer") + agent.train_buffer.load_add(critic_config.test_memory_path) + print(len(agent.train_buffer)) + + else: + print("Extracting experiences from training data") + main.extract(main.train_data, main.agent.train_buffer) + print("Saving experiences to {}".format(critic_config.train_memory_path)) + agent.train_buffer.save(critic_config.train_memory_path) + + # initialize valid buffer + if os.path.isfile(critic_config.valid_memory_path): + print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path)) + agent.valid_buffer.load(critic_config.valid_memory_path) + else: + print("Extracting experiences from valid data") + main.extract(main.val_data, main.agent.valid_buffer) + print("Saving experiences to {}".format(critic_config.valid_memory_path)) + agent.valid_buffer.save(critic_config.valid_memory_path) + + # initialize test buffer + if os.path.isfile(critic_config.test_memory_path): + print("Loading replay buffer for test from {}".format(critic_config.test_memory_path)) + agent.test_buffer.load(critic_config.test_memory_path) + else: + print("Extracting experiences from test data") + main.extract(main.test_data, main.agent.test_buffer) + print("Saving experiences to {}".format(critic_config.test_memory_path)) + agent.test_buffer.save(critic_config.test_memory_path) + + if critic_config.use_gpu: + agent.critic.cuda() + agent.critic_target.cuda() + + with open(exp_dir + "/test_performance_start.txt", "w") as f: + task_run_critic_on_behavior_policy(main.test_data, agent, None, outfile=f) + + #train critic + main.run_behavior_policy() + + with open(exp_dir + "/test_performance_end.txt", "w") as f: + task_run_critic_on_behavior_policy(main.test_data, agent, None, outfile=f) + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '='*30) + + +if __name__ == '__main__': + + #pretrained encoder for critic + folder = "2021-11-25-11-52-47-mt_gauss" + id_ = "29" + + main(None, folder, id_) + + diff --git a/experiments_woz/dialog_utils.py 
b/experiments_woz/dialog_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af63c856d9f67c86bdd5175d9ed50f0936b5ba2e --- /dev/null +++ b/experiments_woz/dialog_utils.py @@ -0,0 +1,906 @@ +import numpy as np +import scipy.stats as st +import torch as th +from latent_dialog.enc2dec.decoders import GEN, DecoderRNN +from latent_dialog.main import get_sent +from latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type +from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS +from collections import defaultdict +import pdb +import warnings + +def mean_confidence_interval(data, confidence=0.95): + a = 1.0 * np.array(data) + n = len(a) + m, se = np.mean(a), st.sem(a) + h = se * st.t.ppf((1 + confidence) / 2., n-1) + return m, m-h, m+h + +def task_generate(model, data, config, evaluator, num_batch, dest_f=None, verbose=True, aux_mt=False): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + if aux_mt: + #TODO aux rl forward? + outputs, labels = model.forward_aux(batch, mode=GEN, gen_type=config.gen_type) + else: + outputs, labels = model.forward(batch, mode=GEN, gen_type=config.gen_type) + + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + # TODO attn + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + if verbose and (num_batch is None or batch_cnt < 2): + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('-' * 40) + + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + write(task_report_new) + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1 + +def task_generate_critic(model, data, config, critic_config, evaluator, num_batch, 
dest_f=None, verbose=True, critic=None, actor=None, outfile=None): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + Q = [] + generated_dialogs = defaultdict(list) + cur_key = "" + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + + batch_size = len(batch['bs']) + + # with intermediate latent step + if not critic_config.word_plas: + if critic_config.actor_path is not None: + raise NotImplementedError + else: + if "gauss" in critic_config.model_path: + sample_z, _, _ = model.get_z_via_rg(batch) + corpus_z, _, _ = model.get_z_via_vae(batch['outputs']) + else: + sample_z, _, _ = model.get_z_via_rg(batch, hard=True) + _, pred_labels = model.decode_z(sample_z, batch_size, batch, critic_config.max_words, critic_config.temperature) + if type(pred_labels[0]) == int: + pred_labels = [pred_labels] + pred_labels_gpu = model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in pred_labels]), LONG) + # NOTE the above pass result in lower success rate for categorical models + else: + # with forward pass + # outputs, labels = model.forward(batch, mode=GEN, gen_type=config.gen_type) + # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + # pred_labels_gpu = model.np2var(pred_labels, LONG) + + # with forward rl pass + if "actor_path" in critic_config and critic_config.actor_path is not None: + sample_z, _, _ = actor(batch) + _, pred_labels = model.decode_z(sample_z, batch_size, batch, critic_config.max_words, critic_config.temperature) + else: + + logprobs, pred_labels, _, _ = model.forward_rl(batch, max_words=critic_config.max_words, temp=critic_config.temperature) + if type(pred_labels[0]) == int: + pred_labels = [pred_labels] + pred_labels_gpu = model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in pred_labels]), LONG) + + + true_labels = batch['outputs'] + true_labels_gpu = model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in true_labels.tolist()]), LONG) + + + if critic is not None: + with th.no_grad(): + if critic_config.word_plas: + # since we are only taking the first Q, we do not need forward_target + Q.append(critic(batch, pred_labels_gpu)) + Q_pred = critic.forward_target(batch, pred_labels_gpu.unsqueeze(1), true_labels_gpu) + else: + if critic.is_gauss: + Q.append(critic(batch, sample_z)[0]) + Q_pred = critic.forward_target(batch, sample_z, corpus_z) + else: + if critic_config.embed_z_for_critic: + cat_action = model.z_embedding(sample_z.view(1, -1, model.y_size * model.k_size)).squeeze(0) + else: + cat_action = sample_z.view(-1, model.y_size * model.k_size) + + Q.append(critic(batch, cat_action)[0]) + + + # move from GPU to CPU + # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + # pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + # true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # 
(batch_size, ) + if ctx_len is None: + pdb.set_trace() + keys = batch['keys'] + if cur_key == "": + cur_key = keys[0] + + for b_id in range(len(pred_labels)): + if keys[b_id] != cur_key: + tmp_gen = defaultdict(list) + tmp_gen[cur_key] = generated_dialogs[cur_key] + task_report_new, _, _= evaluator.evaluateModel(tmp_gen, mode=data.name, new_version=True) + print(task_report_new) + cur_key = keys[b_id] + + + # TODO attn + if critic_config.word_plas: + pred_str = get_sent(model.vocab, de_tknize, pred_labels_gpu, b_id) + else: + pred_str = get_sent(model.vocab, de_tknize, np.asarray(pred_labels[b_id]).reshape(1,-1), 0) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + # if verbose and (num_batch is None or batch_cnt < 2): + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('Q: {}'.format(Q_pred[b_id])) + write('-' * 40) + + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + write(task_report_new) + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + if critic is not None: + + data_q = np.asarray(th.cat(Q).cpu()).squeeze() + mean = np.mean(data_q) + std = np.std(data_q) + lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q)) + if outfile is not None: + outfile.write(f"{mean}, {std}, {lower}, {upper}") + print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}") + + # average_critic_return = th.mean(th.cat(Q)).item() + # variance_critic_return = th.var(th.cat(Q)).item() + # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}") + else: + mean = 0.0 + + # write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1, mean + +def task_run_critic(data, agent, num_batch, evaluator=None, dest_f=None, verbose=True, f=None, outfile=None): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + agent.cvae.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(agent.cvae.config, shuffle=num_batch is not None, verbose=False, fix_batch=agent.cvae.config.fix_batch) + if evaluator is not None: + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + Q = [] + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + + batch_size = len(batch['bs']) + key = batch['keys'][0] + + if key in agent.raw_responses.keys(): + pred_labels_gpu = agent.cvae.np2var(np.asarray([agent.cvae.pad_to(agent.critic.args.max_words, a, do_pad=True) for a in agent.raw_responses[key]]), LONG) + else: + print(key, "skipped") + continue + + true_labels = batch['outputs'] + + with th.no_grad(): + if 
agent.critic.args.word_plas: + # since we are only taking the first Q, we do not need forward_target + Q.append(agent.critic(batch, pred_labels_gpu)) + else: + z_t, _, _ = agent.vae.get_z_via_vae(pred_labels_gpu) + Q.append(agent.critic(batch, z_t)) + + # get context + if evaluator is not None: + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + if ctx_len is None: + pdb.set_trace() + keys = batch['keys'] + + for b_id in range(len(pred_labels_gpu.cpu())): + if agent.args.word_plas: + pred_str = get_sent(agent.cvae.vocab, de_tknize, pred_labels_gpu, b_id) + else: + pred_str = get_sent(agent.cvae.vocab, de_tknize, np.asarray(pred_labels_gpu[b_id]).reshape(1,-1), 0) + true_str = get_sent(agent.cvae.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(agent.cvae.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + # if verbose and (num_batch is None or batch_cnt < 2): + # write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + # write('True: {}'.format(true_str, )) + # write('Pred: {}'.format(pred_str, )) + # write('-' * 40) + + + if evaluator is not None: + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + outfile.write(task_report_new) + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + print(resp_report) + outfile.write(resp_report) + + data_q = np.asarray(th.cat(Q).cpu()).squeeze() + mean = np.mean(data_q) + std = np.std(data_q) + lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q)) + if outfile is not None: + outfile.write(f"\ncritic mean, std, lower_bound, upper_bound\n{mean}, {std}, {lower}, {upper}") + + print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}") + + # average_critic_return = th.mean(th.cat(Q)).item() + # variance_critic_return = th.var(th.cat(Q)).item() + # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}") + + # write(task_report) + # write(resp_report) + # write('Generation Done') + return mean + +def task_run_critic_on_behavior_policy(data, agent, num_batch, dest_f=None, verbose=True, outfile=None): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + agent.cvae.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(agent.cvae.config, shuffle=num_batch is not None, verbose=False, fix_batch=agent.cvae.config.fix_batch) + # evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + Q = [] + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + + batch_size = len(batch['bs']) + key = batch['keys'][0] + true_labels = agent.cvae.np2var(batch['outputs'], LONG) + with th.no_grad(): + if not agent.critic.args.word_plas: + z_t, _, _ = 
agent.vae.get_z_via_vae(true_labels) + Q.append(agent.critic(batch, z_t)) + else: + Q.append(agent.critic(batch, true_labels)) + + data_q = np.asarray(th.cat(Q).cpu()).squeeze() + mean = np.mean(data_q) + std = np.std(data_q) + lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q)) + if outfile is not None: + outfile.write(f"{mean}, {std}, {lower}, {upper}") + + print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}") + + + # average_critic_return = th.mean(th.cat(Q)).item() + # variance_critic_return = th.var(th.cat(Q)).item() + # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}") + + # write(task_report) + # write(resp_report) + # write('Generation Done') + return mean + + +def task_generate_augpt(model, data, config, evaluator, num_batch, dest_f=None, verbose=True): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + # evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + generated_dialogs = defaultdict(list) + responses = [] + true_responses = [] + beliefs = [] + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + # TODO attn + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + responses.append(pred_str) + true_responses.append(true_str) + beliefs.append(batch['raw_bs'][b_id]) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, np.array(ctx[b_id][t_id]).reshape(1, -1), 0, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + bs_str = get_sent(model.vocab, de_tknize, np.array(batch['bs'][b_id]).reshape(1, -1), 0, stop_eos=False) + db_str = get_sent(model.vocab, de_tknize, np.array(batch['db'][b_id]).reshape(1, -1), 0, stop_eos=False) + prev_metadata = 'BS: {}, DB: {}'.format(bs_str, db_str) + + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + if verbose and (num_batch is None or batch_cnt < 2): + write(prev_metadata) + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('-' * 40) + + corpus_success, corpus_match, corpus_domain_result, corpus_task_report = evaluator.evaluate(beliefs, true_responses) + write('====Corpus Result====') + 
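+    # Editor's note (sketch): the "Corpus Result" block re-scores the ground-truth
+    # responses with the same evaluator call used just above, i.e.
+    #   _, _, _, corpus_task_report = evaluator.evaluate(beliefs, true_responses)
+    # giving an oracle reference point for the generated responses scored below.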
write(corpus_task_report) + + success, match, domain_result, task_report = evaluator.evaluate(beliefs, responses) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + write('====System result per domain====') + for d, (m, s) in domain_result.items(): + write("{}: match {:2.2f}% success {:2.2f}%".format(d, m*100, s*100)) + + write('====System result====') + write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1 + +def task_generate_plas(actor, cvae, data, config, evaluator, num_batch, dest_f=None, verbose=True, critic=None): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + actor.eval() + cvae.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + Q = [] + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + try: + # _, z = actor(batch) + action, _, _ = actor(batch) + except: + _, action, _, _ = actor(batch) + # z = actor(batch) + if critic is not None: + with th.no_grad(): + if actor.is_gauss: + Q.append(critic(batch, action)[0]) + else: + if actor.embed_z_for_critic: + cat_action = actor.z_embedding(action.view(1, -1, actor.y_size * actor.k_size)).squeeze(0) + else: + cat_action = action.view(-1, actor.y_size * actor.k_size) + Q.append(critic(batch, cat_action)[0]) + + + try: + logprobs, outputs = cvae.decode_z(action, len(batch['context_lens']), batch, config.max_dec_len) + except: + pdb.set_trace() + logprobs, outputs = cvae.decode_z(action, len(batch['context_lens']), batch, config.max_dec_len) + if type(outputs[0]) == int: + outputs = [outputs] + pred_labels = np.asarray([cvae.pad_to(config.max_dec_len, a, do_pad=True) for a in outputs]) + true_labels = batch['outputs'] + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + # TODO attn + pred_str = get_sent(cvae.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(cvae.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(cvae.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + if verbose and (num_batch is None or batch_cnt < 2): + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('-' * 40) + + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + write(task_report_new) + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + if critic is not None: + average_critic_return = th.mean(th.cat(Q)).item() + variance_critic_return = th.var(th.cat(Q)).item() + 
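+        # Hedged aside: the t-based 95% interval reported in task_run_critic could be
+        # computed here as well via the helper defined at the top of this file, e.g.
+        #   m, lo, hi = mean_confidence_interval(np.asarray(th.cat(Q).cpu()).squeeze())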
write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}") + else: + average_critic_return = 0.0 + write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1, average_critic_return + +def task_generate_wSampling(model, data, config, evaluator, num_batch, dest_f=None, verbose=True, n_z=10): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + def is_masked_action(bs_label, db_label, response): + """ + check if the generated response should be masked based on belief state and db result + a) inform when there is no db match + b) inform no option when there is a match + c) out of domain action + d) no offer with a particular slot? + e) inform/request time on domains other than train and restaurant + """ + for domain, bs_idx, db_idx in zip(DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX): + if bs_label[bs_idx] == 0: # if domain is inactive + if any([p in response for p in REQ_TOKENS[domain]]): # but a token from that domain is present + print(">> inactive domain {} is mentioned".format(domain)) + return True + else: # domain is active + if any([p in response for p in ["sorry", "no", "not" "cannot"]]): # system inform no offer + if db_idx < 0: # domain has no db + print(">> inform no offer on domain {} without DB".format(domain)) + return True + elif db_label[db_idx] != 1: # there are matches + print(">> inform no offer when there are matches on domain {}".format(domain)) + return True + # if "[value_" in response: + # print(">> inform no offer mentioning criteria") + # return True # always only inform "no match for your criteria" w/o mentioning them explicitly + elif any([p in response for p in REQ_TOKENS[domain]]) or "i have [value_count]" in response or "there are [value_count]" in response: # if requestable token is present + # TODO also check for i have [value_count] match, not only the requestable tokens + if db_idx >= 0 and int(db_label[db_idx]) == 1: # and domain has a DB to be queried and there are no matches + print(">> inform match when there are no DB match on domain {}".format(domain)) + return True + + return False + + model.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + + outputs = [] + for i in range(n_z): + output, labels = model(batch, mode=GEN, gen_type=config.gen_type) + pred = [t.cpu().data.numpy() for t in output[DecoderRNN.KEY_SEQUENCE]] + pred = np.array(pred, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + outputs.append(pred) #(n_z, batch_size, max_dec_len) + + # move from GPU to CPU + labels = labels.cpu() + # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + # pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(true_labels.shape[0]): + true_str = get_sent(model.vocab, 
de_tknize, true_labels, b_id) + # TODO attn + for i in range(n_z): + flag = False + pred_labels = np.array(outputs[i], dtype=int) + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + + if not is_masked_action(batch['bs'][b_id], batch['db'][b_id], pred_str): + flag = True + break + else: + if i == 0: + print("----------") + print(true_str) + print(pred_str) + if not flag: + warnings.warn("No sampled actions passed the masking test, taking the last sampled action") + + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + evaluator.add_example(true_str, pred_str) + + if verbose and (num_batch is None or batch_cnt < 2): + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('-' * 40) + + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + write(task_report_new) + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1 + +def task_generate_actz(model, data, config, evaluator, num_batch, dest_f=None, enc="utt", verbose=True): + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch) + evaluator.initialize() + print('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + generated_dialogs = defaultdict(list) + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + if enc=="aux": + outputs, labels = model.forward_aez(batch, mode=GEN, gen_type=config.gen_type) + else: + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + # TODO attn + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + generated_dialogs[keys[b_id]].append(pred_str) + 
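# collect the hypothesis under its dialogue key and register the (reference, hypothesis) pair for BLEU/F1 +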
evaluator.add_example(true_str, pred_str) + + if verbose and (num_batch is None or batch_cnt < 2): + write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,)) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('-' * 40) + + task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True) + write(task_report_new) + + task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name) + resp_report, bleu, prec, rec, f1 = evaluator.get_report() + write(task_report) + write(resp_report) + write('Generation Done') + return success, match, bleu, f1 + +def dump_latent(model, data, config): + latent_results = defaultdict(list) + model.eval() + batch_cnt = 0 + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True) + + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None: + break + + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + sample_y = outputs['sample_z'].cpu().data.numpy().reshape(-1, config.y_size, config.k_size) + log_qy = outputs['log_qy'].cpu().data.numpy().reshape(-1, config.y_size, config.k_size) + + if config.dec_use_attn: + attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]] + attns = np.array(attns).squeeze(2).swapaxes(0, 1) + else: + attns = None + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + ctx_str.append(temp_str) + prev_ctx = 'Source context: {}'.format(ctx_str) + + latent_results[keys[b_id]].append({'context': prev_ctx, 'gt_resp': true_str, + 'pred_resp': pred_str, 'domain': batch['goals_list'], + 'sample_y': sample_y[b_id], 'log_qy': log_qy[b_id], + 'attns': attns[b_id] if attns is not None else None}) + latent_results = dict(latent_results) + return latent_results + +def dump_latent_gauss(model, data, config): + latent_results = defaultdict(list) + model.eval() + batch_cnt = 0 + de_tknize = lambda x: ' '.join(x) + data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True) + + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None: + break + + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + sample_y = outputs['sample_z'].cpu().data.numpy().reshape(-1, config.y_size) + #TODO qmu is not stored in outputs + q_mu = outputs['q_mu'].cpu().data.numpy().reshape(-1, config.y_size) + q_logvar = outputs['q_logvar'].cpu().data.numpy().reshape(-1, config.y_size) + + if config.dec_use_attn: + attns = [t.cpu().data.numpy() for t in 
outputs[DecoderRNN.KEY_ATTN_SCORE]] + attns = np.array(attns).squeeze(2).swapaxes(0, 1) + else: + attns = None + + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + keys = batch['keys'] + + for b_id in range(pred_labels.shape[0]): + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + ctx_str.append(temp_str) + prev_ctx = 'Source context: {}'.format(ctx_str) + + latent_results[keys[b_id]].append({'context': prev_ctx, 'gt_resp': true_str, + 'pred_resp': pred_str, 'domain': batch['goals_list'], + 'sample_y': sample_y[b_id], 'q_mu': q_mu[b_id], 'q_logvar' : q_logvar[b_id], + 'attns': attns[b_id] if attns is not None else None}) + latent_results = dict(latent_results) + return latent_results + + + + + + + + + + + + diff --git a/experiments_woz/mt_gauss.py b/experiments_woz/mt_gauss.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4853a2bcee54310c8732e5e337a183e6b0ad15 --- /dev/null +++ b/experiments_woz/mt_gauss.py @@ -0,0 +1,255 @@ +import time +import os +import json +import sys +from pathlib import Path +sys.path.append((Path(__file__).parent / '../').resolve().as_posix()) +import torch as th +import pdb +import logging +import random +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +import latent_dialog.corpora as corpora +from latent_dialog.data_loaders import BeliefDbDataLoadersAE, BeliefDbDataLoaders +from latent_dialog.evaluators import MultiWozEvaluator +from latent_dialog.models_task import SysMTGauss +from latent_dialog.main import train, validate, mt_train +import latent_dialog.domain as domain +from experiments_woz.dialog_utils import task_generate + +# def main(seed): + +def main(seed, pretrained_folder, pretrained_model_id): + domain_name = 'object_division' + domain_info = domain.get_domain(domain_name) + + if pretrained_folder is not None: + ae_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json') + ae_config = Pack(json.load(open(ae_config_path))) + ae_model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}-model'.format(pretrained_model_id)) + train_path=ae_config.train_path + valid_path=ae_config.valid_path + test_path=ae_config.test_path + + else: + ae_model_path = None + ae_config_path = None + train_path='../data/data_2.1/train_dials.json' + valid_path='../data/data_2.1/val_dials.json' + test_path='../data/data_2.1/test_dials.json' + + if seed is None: + seed = ae_config.seed + + base_path = Path(__file__).parent + config = Pack( + seed = seed, + ae_model_path=ae_model_path, + ae_config_path=ae_config_path, + train_path=train_path, + valid_path=valid_path, + test_path=test_path, + dact_path=(base_path / '../data/norm-multi-woz/damd_dialog_acts.json').resolve().as_posix(), + ae_zero_pad=True, + max_vocab_size=1000, + last_n_model=1, + max_utt_len=50, + max_dec_len=50, + backward_size=2, + batch_size=128, + use_gpu=True, + op='adam', + init_lr=0.0005, + l2_norm=1e-05, + momentum=0.0, + grad_clip=5.0, + dropout=0.5, + max_epoch=50, + # aux_train_freq=10, + # aux_max_epoch=1, # epoch per training, i.e every aux_train_freq, train aux_max_epoch time + shared_train=True, + embed_size=256, + num_layers=1, + utt_rnn_cell='gru', + utt_cell_size=300, + 
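# NOTE: when a pretrained AE is given, the decoder-side fields below are overwritten from ae_config further down +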
bi_utt_cell=True, + enc_use_attn=True, + dec_use_attn=False, + dec_rnn_cell='lstm', + dec_cell_size=300, + dec_attn_mode='cat', + y_size=200, + beta = 0.01, + # aux_pi_beta = 0.01, # default to 1.0 if not set + simple_posterior=True, + contextual_posterior=False, + use_mi =False, + use_pr =True, + use_diversity = False, + # + beam_size=20, + fix_batch = True, + fix_train_batch=False, + avg_type='word', + # avg_type='slot', + # slot_weight=10, + use_aux_kl=False, + selective_fine_tune=False, + print_step=300, + ckpt_step=1416, + improve_threshold=0.996, + patient_increase=2.0, + save_model=True, + early_stop=False, + gen_type='greedy', + preview_batch_num=50, + k=domain_info.input_length(), + init_range=0.1, + pretrain_folder='2021-11-25-16-47-37-mt_gauss', + forward_only=False + ) + set_seed(config.seed) + + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + stats_path = (base_path / 'sys_config_log_model').resolve().as_posix() + if config.forward_only: + saved_path = os.path.join(stats_path, config.pretrain_folder) + config = Pack(json.load(open(os.path.join(saved_path, 'config.json')))) + config['forward_only'] = True + else: + saved_path = os.path.join(stats_path, start_time+'-'+os.path.basename(__file__).split('.')[0]) + if not os.path.exists(saved_path): + os.makedirs(saved_path) + config.saved_path = saved_path + + if ae_config_path is not None: + config["max_dec_len"] = ae_config.max_dec_len + config["dec_use_attn"] = ae_config.dec_use_attn + config["dec_rnn_cell"] = ae_config.dec_rnn_cell + config["dec_cell_size"] = ae_config.dec_cell_size + config["dec_attn_mode"] = ae_config.dec_attn_mode + config["embed_size"] = ae_config.embed_size + config["y_size"] = ae_config.y_size + # Denoising setup + if "noise_type" in ae_config: + noise_type = ae_config.noise_type + else: + noise_type = None + if noise_type: + config["noise_type"] = ae_config.noise_type + config["noise_p"] = ae_config.noise_p + config["remove_tokens"] = ae_config.remove_tokens + config["no_special"] = ae_config.no_special + else: + noise_type = None + vocab_dict = None + + + prepare_dirs_loggers(config) + logger = logging.getLogger() + logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30)) + config.saved_path = saved_path + + # save configuration + with open(os.path.join(saved_path, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) # sort_keys=True + + # data for AE training + aux_corpus = corpora.NormMultiWozCorpusAE(config) + vocab_dict = aux_corpus.vocab_dict if noise_type else None + aux_train_dial, aux_val_dial, aux_test_dial = aux_corpus.get_corpus() + + aux_train_data = BeliefDbDataLoadersAE('Train', aux_train_dial, config, noise=noise_type, ind_voc=vocab_dict, logger=logger) + aux_val_data = BeliefDbDataLoadersAE('Val', aux_val_dial, config) + aux_test_data = BeliefDbDataLoadersAE('Test', aux_test_dial, config) + + # data for RG training + corpus = corpora.NormMultiWozCorpus(config) + train_dial, val_dial, test_dial = corpus.get_corpus() + + train_data = BeliefDbDataLoaders('Train', train_dial, config) + val_data = BeliefDbDataLoaders('Val', val_dial, config) + test_data = BeliefDbDataLoaders('Test', test_dial, config) + + evaluator = MultiWozEvaluator('SysWoz', config) + + model = SysMTGauss(corpus, config) + model_dict = model.state_dict() + + # load params from saved ae_model + if pretrained_folder is not None: + ae_model_dict = th.load(config.ae_model_path, map_location=lambda storage, location: storage) + tmp_k = [k for k in model_dict.keys() if k not in 
ae_model_dict.keys()] + aux_model_dict = {} + for k in tmp_k: + aux_model_dict[k] = ae_model_dict[k.replace("aux", "utt")] #utt_encoder param in ae model is now moved to aux_params in a new dict + model_dict.update(ae_model_dict) # all params except aux_encoder + model_dict.update(aux_model_dict) # aux encoder + model.load_state_dict(model_dict) + + if config.use_gpu: + model.cuda() + + # only train utt_encoder + if config.ae_config_path is not None and config.selective_fine_tune: + for name, param in model.named_parameters(): + if "utt_encoder" not in name: + # if "decoder" in name or "z_embedding" in name: + param.requires_grad = False + + best_epoch = None + if not config.forward_only: + try: + best_epoch = mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=task_generate) + except KeyboardInterrupt: + print('Training stopped by keyboard.') + if best_epoch is None: + model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path) if 'model' in p and 'rl' not in p and 'aux' not in p]) + best_epoch = model_ids[-1] + + print("$$$ Load {}-model".format(best_epoch)) + # config.batch_size = 32 + model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch)))) + + + logger.info("Forward Only Evaluation") + + validate(model, val_data, config) + validate(model, test_data, config) + + with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, val_data, config, evaluator, num_batch=None, dest_f=f) + + with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f) + + with open(os.path.join(saved_path, '{}_{}_valid_AE_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, aux_val_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True) + + with open(os.path.join(saved_path, '{}_{}_test_AE_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, aux_test_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True) + + + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '=' * 30) + +if __name__ == "__main__": + ### Train from scratch ### + + for i in [1, 2, 3, 4]: + main(i, None, None) + + + ### Load pre-trained AE ### + # multiwoz 2.0 + # pretrained = {"2020-02-28-16-49-48-sl_gauss_ae":100} + + # multiwoz 2.1 + # pretrained = {"2020-10-15-17-11-59-sl_gauss_ae":92} + + # for p in pretrained.keys(): + # folder = p + # id_ = pretrained[p] + # main(None, folder, id_) + diff --git a/experiments_woz/plas_gauss.py b/experiments_woz/plas_gauss.py new file mode 100644 index 0000000000000000000000000000000000000000..65e8961270d78b64bc63fe6a614846d3c9714411 --- /dev/null +++ b/experiments_woz/plas_gauss.py @@ -0,0 +1,217 @@ +import time +import os +import sys +import random +sys.path.append('../') +import json +import torch as th +import pdb +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +from latent_dialog.corpora import NormMultiWozCorpus +from latent_dialog.models_task import SysMTGauss, SysActZGauss +from latent_dialog.agent_task import LatentPLASAgent, PLASAgent +from latent_dialog.main import OfflinePLAS +from latent_dialog.evaluators import MultiWozEvaluator +from experiments_woz.dialog_utils import task_generate_plas, task_generate + + +def main(seed, pretrained_folder, pretrained_model_id): + 
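# Offline PLAS: load a pretrained latent-action dialogue model, wrap it in an + # actor-critic agent, and train from replay buffers extracted from the fixed corpus +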
start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[START]', start_time, '='*30) + # RL configuration + env = 'gpu' + + exp_dir = os.path.join('sys_config_log_model', pretrained_folder, 'plas_rl-'+start_time) + # create exp folder + if not os.path.exists(exp_dir): + os.mkdir(exp_dir) + + rl_config = Pack( + sv_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json'), + sv_model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}-model'.format(pretrained_model_id)), + rl_config_path = os.path.join(exp_dir, 'rl_config.json'), + rl_model_path = os.path.join(exp_dir, 'rl_model'), + saved_path = exp_dir, + ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'), + reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'), + record_path = exp_dir, + record_freq = 100, + sv_train_freq= 0, # TODO pay attention to main.py, cuz it is also controlled there + use_gpu = env == 'gpu', + nepoch = 1, + nepisode = 10000, + tune_pi_only=False, + is_stochastic=False, + word_plas=False, + z_loss=True, #set to False if word_plas == True + actor_rl_lr = 0.005, + importance_threshold=10, + fix_episode=True, + validate_with_critic=True, + goal_to_critic=True, + add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True + critic_kl_loss=True, + critic_kl_alpha=0.1, + critic_dropout=True, # use with regularize_critic=False + critic_dropout_rate=0.3, # only if dropout is true + critic_dropout_agg="min", #avg or min + critic_sample=False, + critic_transformer=False, + critic_actf="none", #relu or sigmoid or tanh or none + critic_loss="mse", #mse or huber + critic_rl_lr = 0.01, + train_vae=True, + train_vae_freq=10001, # if larger than nepisode. only train once at the beginning + train_vae_nepisode=0, # nepisode for the rest of VAE training + train_vae_nepisode_init=10000, # nepisode for first VAE training + weighted_vae_nll=True, + vae_beta=0.0, #during pre-training 0.01 or 0.1, no need to change distributions as we only tune the encoder + rl_lr = 0.01, + max_words = 50, + temperature=1.0, + momentum = 0.0, + nesterov = False, + gamma = 0.99, + tau=0.005, + lmbda=0.5, + beta=0.001, + batch_size=16, + rl_clip = 1.0, + n_z=10, + random_seed = seed, + policy_dropout=False, + dropout_on_eval=False, + fail_info_penalty=False + ) + + prepare_dirs_loggers(rl_config) + + # list config keys that are being compared for tensorboard naming + tb_keys = ["validate_with_critic", "actor_rl_lr", "critic_rl_lr", "critic_actf"] + tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-" + "-".join([f"{k}={rl_config[k]}" for k in tb_keys]) + # tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-" + "recurrent_critic_forward_target" + + # load previous supervised learning configuration and corpus + sv_config = Pack(json.load(open(rl_config.sv_config_path))) + sv_config['dropout'] = 0.0 + sv_config['use_gpu'] = rl_config.use_gpu + sv_config['policy_dropout'] = rl_config.policy_dropout + sv_config['dropout_on_eval'] = rl_config.dropout_on_eval + # assert sv_config.train_path == rl_config.train_path + + # set random seed + if rl_config.random_seed is None: + rl_config.random_seed = sv_config.seed + set_seed(rl_config.random_seed) + + + try: + corpus = NormMultiWozCorpus(sv_config) + except FileNotFoundError: + sv_config['train_path'] = sv_config.train_path.replace("/home/lubis", "") + sv_config['valid_path'] = sv_config.valid_path.replace("/home/lubis", "") + sv_config['test_path'] = 
sv_config.test_path.replace("/home/lubis", "") + corpus = NormMultiWozCorpus(sv_config) + + rl_config['train_path'] = sv_config['train_path'] + rl_config['valid_path'] = sv_config['valid_path'] + rl_config['test_path'] = sv_config['test_path'] + + rl_config['train_memory_path'] = sv_config['train_path'].replace(".json", ".dill") + rl_config['valid_memory_path'] = sv_config['valid_path'].replace(".json", ".dill") + rl_config['test_memory_path'] = sv_config['test_path'].replace(".json", ".dill") + + if rl_config.fix_episode: + rl_config['train_memory_path'] = rl_config['train_memory_path'].replace(".dill", "-ep.dill") + rl_config['valid_memory_path'] = rl_config['valid_memory_path'].replace(".dill", "-ep.dill") + rl_config['test_memory_path'] = rl_config['test_memory_path'].replace(".dill", "-ep.dill") + + rl_config['y_size'] = sv_config['y_size'] + + + # save configuration + with open(rl_config.rl_config_path, 'w') as f: + json.dump(rl_config, f, indent=4) + + if "mt_" in pretrained_folder: + sys_model = SysMTGauss(corpus, sv_config) + else: + sys_model = SysActZGauss(corpus, sv_config) + if sv_config.use_gpu: + sys_model.cuda() + + mt_model_dict = th.load(rl_config.sv_model_path, map_location=lambda storage, location: storage) + sys_model.load_state_dict(mt_model_dict) + + sys_model.eval() + evaluator = MultiWozEvaluator('SysWoz', sv_config) + + if rl_config.word_plas: + agent = PLASAgent(sys_model, corpus, rl_config, name='System', tune_pi_only=rl_config.tune_pi_only) + else: + agent = LatentPLASAgent(sys_model, corpus, rl_config, evaluator, name='System', tune_pi_only=rl_config.tune_pi_only) + + plas = OfflinePLAS(agent, corpus, sv_config, rl_config, task_generate_plas, name=tensorboard_name, vae_gen=task_generate) + # save sys model + # th.save(sys_model.state_dict(), rl_config.rl_model_path) + + # initialize train buffer + if os.path.isfile(rl_config.train_memory_path): + # if False: + print("Loading replay buffer for training from {}".format(rl_config.train_memory_path)) + plas.agent.train_buffer.load(rl_config.train_memory_path) + else: + print("Extracting experiences from training data") + plas.extract(plas.train_data, plas.agent.train_buffer) + print("Saving experiences to {}".format(rl_config.train_memory_path)) + plas.agent.train_buffer.save(rl_config.train_memory_path) + + # initialize valid buffer + if os.path.isfile(rl_config.valid_memory_path): + print("Loading replay buffer for validation from {}".format(rl_config.valid_memory_path)) + plas.agent.valid_buffer.load(rl_config.valid_memory_path) + else: + print("Extracting experiences from valid data") + plas.extract(plas.val_data, plas.agent.valid_buffer) + print("Saving experiences to {}".format(rl_config.valid_memory_path)) + plas.agent.valid_buffer.save(rl_config.valid_memory_path) + + # initialize test buffer + if os.path.isfile(rl_config.test_memory_path): + print("Loading replay buffer for test from {}".format(rl_config.test_memory_path)) + plas.agent.test_buffer.load(rl_config.test_memory_path) + else: + print("Extracting experiences from test data") + plas.extract(plas.test_data, plas.agent.test_buffer) + print("Saving experiences to {}".format(rl_config.test_memory_path)) + plas.agent.test_buffer.save(rl_config.test_memory_path) + + if sv_config.use_gpu: + agent.actor.cuda() + agent.actor_target.cuda() + agent.critic.cuda() + agent.critic_target.cuda() + + # start RL + plas.run() + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '='*30) + + +if __name__ == 
'__main__' : + pretrained = {} + f_in = "sys_config_log_model/tmp.lst" + + with open(f_in, "r") as f: + lines = f.readlines() + for l in lines: + if ";" not in l and "cat" not in l and "rl" not in l: + pretrained["/".join(l.split("/")[:-1])] = int(l.split("/")[-1].split("_")[1]) + + for p in pretrained.keys(): + folder = p + id_ = pretrained[p] + main(None, folder, id_) + + diff --git a/experiments_woz/run_critic.py b/experiments_woz/run_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..b113b93713459737036f10003617781b98ae4086 --- /dev/null +++ b/experiments_woz/run_critic.py @@ -0,0 +1,135 @@ +import time +import os +import sys +import random +sys.path.append('../') +import json +import torch as th +import pdb +from tqdm import trange +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +from latent_dialog.corpora import NormMultiWozCorpus +from latent_dialog.models_task import * +from latent_dialog.agent_task import LatentCriticAgent, CriticAgent +from latent_dialog.main import OfflineCritic +from latent_dialog.evaluators import MultiWozEvaluator +from experiments_woz.dialog_utils import task_generate_critic, task_generate + + +def main(pretrained_critic_folder): + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[START]', start_time, '='*30) + # RL configuration + env = 'gpu' + + exp_dir = os.path.join('sys_config_log_model', "/".join(pretrained_critic_folder.split("/")[:-1])) + print(exp_dir) + critic_config_path = exp_dir + "/critic_config.json" + + critic_config = Pack(json.load(open(critic_config_path))) + config = Pack(json.load(open(critic_config.config_path))) + config['dropout'] = 0.0 + config['use_gpu'] = critic_config.use_gpu + config['policy_dropout'] = critic_config.policy_dropout + config['dropout_on_eval'] = critic_config.dropout_on_eval + + # set random seed + if critic_config.random_seed is None: + try: + critic_config.random_seed = config.seed + except: + critic_config.random_seed = config.random_seed + set_seed(critic_config.random_seed) + + + try: + corpus = NormMultiWozCorpus(config) + except FileNotFoundError: + config['train_path'] = config.train_path.replace("/home/lubis", "") + config['valid_path'] = config.valid_path.replace("/home/lubis", "") + config['test_path'] = config.test_path.replace("/home/lubis", "") + corpus = NormMultiWozCorpus(config) + except PermissionError: + pdb.set_trace() + config['train_path'] = config.train_path.replace("/root", "..") + config['valid_path'] = config.valid_path.replace("/root", "..") + config['test_path'] = config.test_path.replace("/root", "..") + corpus = NormMultiWozCorpus(config) + + if "augpt" in pretrained_critic_folder or "mwoz" in pretrained_critic_folder or "HDSA" in pretrained_critic_folder: + pretrained_critic_folder = critic_config.model_path + + if "rl" in pretrained_critic_folder: + if "gauss" in pretrained_critic_folder: + if "plas" in pretrained_critic_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + else: + if "actz" in pretrained_critic_folder: + if "gauss" in pretrained_critic_folder: + sys_model = SysActZGauss(corpus, config) + else: + sys_model = SysActZCat(corpus, config) + elif "mt" in pretrained_critic_folder: + if "gauss" in pretrained_critic_folder: + sys_model = SysMTGauss(corpus, config) + else: + sys_model = SysMTCat(corpus, config) + else: + if "gauss" in pretrained_critic_folder: + sys_model = 
SysPerfectBD2Gauss(corpus, config) + else: + sys_model = SysPerfectBD2Cat(corpus, config) + + if config.use_gpu: + sys_model.cuda() + + model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage) + sys_model.load_state_dict(model_dict) + + sys_model.eval() + evaluator = MultiWozEvaluator('SysWoz', config) + + if critic_config.word_plas or critic_config.raw_response: + agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + else: + agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System') + + + agent_critic_dict = th.load(exp_dir + "/critic_model", map_location=lambda storage, location: storage) + agent.critic.load_state_dict(agent_critic_dict) + if "actor_path" in critic_config and critic_config.actor_path is not None: + agent_actor_dict = th.load(critic_config.actor_path, map_location=lambda storage, location: storage) + agent.actor.load_state_dict(agent_actor_dict) + + main = OfflineCritic(agent, corpus, config, critic_config, task_generate_critic, name="", vae_gen=task_generate, forward_only=True) + + if critic_config.use_gpu: + agent.critic.cuda() + agent.critic_target.cuda() + + with open(exp_dir + "/conf_interval.txt", "w") as f: + t_success, t_match, t_bleu, t_f1, t_Q = main.generate_func(main.agent.cvae, main.test_data, main.sv_config, critic_config, main.evaluator, None, verbose=False, critic=agent.critic, actor = agent.actor, outfile=f) + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '='*30) + + +if __name__ == '__main__' : + pretrained = [] + # list of test-critic.tsv files of critics to run + f_in = "sys_config_log_model/critic_report.lst" + + with open(f_in, "r") as f: + lines = f.readlines() + for l in lines: + if ";" not in l: + pretrained.append(l) + + for p in pretrained: + main(p) + + diff --git a/experiments_woz/sl_gauss_ae.py b/experiments_woz/sl_gauss_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..041d0488e2ea4876d78175750e3847c8df715c0a --- /dev/null +++ b/experiments_woz/sl_gauss_ae.py @@ -0,0 +1,165 @@ +import time +import os +import sys +sys.path.append('../') +import json +import pdb +import torch as th +import logging +import random +from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed +import latent_dialog.corpora as corpora +from latent_dialog.data_loaders import BeliefDbDataLoadersAE +from latent_dialog.evaluators import MultiWozEvaluator +from latent_dialog.models_task import SysAEGauss +from latent_dialog.main import train, validate +import latent_dialog.domain as domain +from dialog_utils import task_generate +import pickle as pkl + +# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" +# os.environ["CUDA_VISIBLE_DEVICES"]="1" + +def main(seed): + domain_name = 'object_division' + domain_info = domain.get_domain(domain_name) + config = Pack( + seed=seed, + # train_path='../data/norm-multi-woz/train_dials.json', + # valid_path='../data/norm-multi-woz/val_dials.json', + # test_path='../data/norm-multi-woz/test_dials.json', + train_path='../data/data_2.1/train_dials.json', + valid_path='../data/data_2.1/val_dials.json', + test_path='../data/data_2.1/test_dials.json', + dact_path='../data/norm-multi-woz/damd_dialog_acts.json', + max_vocab_size=1000, + last_n_model=1, + max_utt_len=50, + max_dec_len=50, + backward_size=2, + batch_size=128, + use_gpu=True, + op='adam', + init_lr=0.001, + l2_norm=1e-05, + momentum=0.0, + grad_clip=5.0, + dropout=0.5, + max_epoch=100, + 
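# network architecture (embedding, utterance encoder, decoder) +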
embed_size=256, + num_layers=1, + utt_rnn_cell='gru', + utt_cell_size=300, + bi_utt_cell=True, + enc_use_attn=True, + dec_use_attn=False, + dec_rnn_cell='lstm', + # must be same as ctx_cell_size due to the passed initial state + dec_cell_size=300, + # must be same as ctx_cell_size due to the passed initial state + dec_attn_mode='cat', + y_size=200, + beta = 0.01, + simple_posterior=True, + contextual_posterior=False, + use_pr = True, + use_metadata=False, + ae_zero_padding=True, + beam_size=20, + fix_batch = True, + fix_train_batch=False, + avg_type='word', + print_step=300, + ckpt_step=1416, + improve_threshold=0.996, + patient_increase=2.0, + save_model=True, + early_stop=False, + gen_type='greedy', + preview_batch_num=50, + k=domain_info.input_length(), + init_range=0.1, + pretrain_folder='2021-11-09-13-55-11-sl_gauss_ae', + ## denoising + # noise_type=None, + noise_type="tokens_removal", + noise_p=0.1, # Probability to add noise to a response + no_special=False, # True - do not allow tokens switching with special tokens + remove_tokens=False, # True -remove token completely, False - change it with UNK + forward_only=False + ) + set_seed(config.seed) + start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + stats_path = 'sys_config_log_model' + if config.forward_only: + saved_path = os.path.join(stats_path, config.pretrain_folder) + config = Pack(json.load(open(os.path.join(saved_path, 'config.json')))) + config['forward_only'] = True + else: + saved_path = os.path.join(stats_path, start_time+'-'+os.path.basename(__file__).split('.')[0]) + if not os.path.exists(saved_path): + os.makedirs(saved_path) + config.saved_path = saved_path + + prepare_dirs_loggers(config) + logger = logging.getLogger() + logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30)) + config.saved_path = saved_path + + # save configuration + with open(os.path.join(saved_path, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) # sort_keys=True + + corpus = corpora.NormMultiWozCorpusAE(config) + train_dial, val_dial, test_dial = corpus.get_corpus() + + # Denoising setup + noise_type = config.noise_type + vocab_dict = corpus.vocab_dict if noise_type else None + + train_data = BeliefDbDataLoadersAE('Train', train_dial, config, noise=noise_type, ind_voc=vocab_dict, logger=logger) + val_data = BeliefDbDataLoadersAE('Val', val_dial, config) + test_data = BeliefDbDataLoadersAE('Test', test_dial, config) + + config['n_iter'] = config.ckpt_step * config.max_epoch + + evaluator = MultiWozEvaluator('SysWoz', config) + + model = SysAEGauss(corpus, config) + + if config.use_gpu: + model.cuda() + + best_epoch = None + if not config.forward_only: + try: + # best_epoch = train(model, train_data, val_data, test_data, config, evaluator, gen=task_generate) + best_epoch = train(model, train_data, val_data, test_data, config, evaluator, gen=None) + except KeyboardInterrupt: + print('Training stopped by keyboard.') + if best_epoch is None: + model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path) if '-model' in p]) + best_epoch = model_ids[-1] + + print("$$$ Load {}-model".format(best_epoch)) + config.batch_size = 32 + model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch)))) + + logger.info("Forward Only Evaluation") + + validate(model, val_data, config) + validate(model, test_data, config) + + with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, val_data, config, evaluator, 
num_batch=None, dest_f=f) + + with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f: + task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f) + + end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) + print('[END]', end_time, '=' * 30) + +if __name__ == "__main__": + for i in [5, 23, 42, 72, 112]: + main(i) + diff --git a/experiments_woz/sys_config_log_model/pretrained.zip b/experiments_woz/sys_config_log_model/pretrained.zip new file mode 100644 index 0000000000000000000000000000000000000000..939ab34e46f57bf4a39dc3cfd1e17e2e8afd2bf4 Binary files /dev/null and b/experiments_woz/sys_config_log_model/pretrained.zip differ diff --git a/latent_dialog/__init__.py b/latent_dialog/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cabc5d9c605965db3e69614595ab9bb891814c9 --- /dev/null +++ b/latent_dialog/__init__.py @@ -0,0 +1,2 @@ +# @Time : 10/18/18 1:55 PM +# @Author : Tiancheng Zhao \ No newline at end of file diff --git a/latent_dialog/__pycache__/__init__.cpython-36.pyc b/latent_dialog/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a40514e24fef0191401e7d2520b656d034a58a24 Binary files /dev/null and b/latent_dialog/__pycache__/__init__.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/agent_task.cpython-36.pyc b/latent_dialog/__pycache__/agent_task.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b870eafe6a1870ae7775a1af813060a3121846a Binary files /dev/null and b/latent_dialog/__pycache__/agent_task.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc b/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1a90c26006d81f3c806cffc69fbd96eab587f4 Binary files /dev/null and b/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc b/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f4f94c4300597092fb0fce488727881b92e95d6 Binary files /dev/null and b/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/base_models.cpython-36.pyc b/latent_dialog/__pycache__/base_models.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d43761c04791f5a244f6ba42f32fbf5f5c238e Binary files /dev/null and b/latent_dialog/__pycache__/base_models.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/corpora.cpython-36.pyc b/latent_dialog/__pycache__/corpora.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2083c5776fd671a1ecf5808f3db31f07d6665c01 Binary files /dev/null and b/latent_dialog/__pycache__/corpora.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/criterions.cpython-36.pyc b/latent_dialog/__pycache__/criterions.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4847cef4e410bb5de379ff1d1f3333bc1a3dcf71 Binary files /dev/null and b/latent_dialog/__pycache__/criterions.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/data_loaders.cpython-36.pyc b/latent_dialog/__pycache__/data_loaders.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e19daa6211efafd5ec322229d4baba7b4149cd01 Binary files /dev/null and 
b/latent_dialog/__pycache__/data_loaders.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/domain.cpython-36.pyc b/latent_dialog/__pycache__/domain.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fea5bacdc823e0b903ed2012e92f451da85ce27 Binary files /dev/null and b/latent_dialog/__pycache__/domain.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/evaluators.cpython-36.pyc b/latent_dialog/__pycache__/evaluators.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8221ef99bfc4b830075dede8395a49cba2f79b82 Binary files /dev/null and b/latent_dialog/__pycache__/evaluators.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/main.cpython-36.pyc b/latent_dialog/__pycache__/main.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3822d03171b95140a4382e57850dcebb93347388 Binary files /dev/null and b/latent_dialog/__pycache__/main.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/models_task.cpython-36.pyc b/latent_dialog/__pycache__/models_task.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c81368e21c74f01de607eea0f1b779ea20211a7 Binary files /dev/null and b/latent_dialog/__pycache__/models_task.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/nn_lib.cpython-36.pyc b/latent_dialog/__pycache__/nn_lib.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6460889d0df5dd58a4f10b237aa4ad776e492a3 Binary files /dev/null and b/latent_dialog/__pycache__/nn_lib.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc b/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1697d553dc333135918fbfe07c8cdfb9bbb9f519 Binary files /dev/null and b/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/record.cpython-36.pyc b/latent_dialog/__pycache__/record.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2745dbe54ae7796ce791d1a9fcabe06cc488ac9f Binary files /dev/null and b/latent_dialog/__pycache__/record.cpython-36.pyc differ diff --git a/latent_dialog/__pycache__/utils.cpython-36.pyc b/latent_dialog/__pycache__/utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3db9a1ff54198ce831a4511c878887141c3d5ba6 Binary files /dev/null and b/latent_dialog/__pycache__/utils.cpython-36.pyc differ diff --git a/latent_dialog/agent_task.py b/latent_dialog/agent_task.py new file mode 100644 index 0000000000000000000000000000000000000000..414214799c982b6e032fe465836394e6a0fd3ca6 --- /dev/null +++ b/latent_dialog/agent_task.py @@ -0,0 +1,1746 @@ +import torch as th +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import numpy as np +from latent_dialog.utils import LONG, FLOAT, Pack, get_detokenize, np2var +from latent_dialog.main import get_sent, LossManager +from latent_dialog.data_loaders import BeliefDbDataLoaders +from latent_dialog.base_models import frange_cycle_linear +from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS +from latent_dialog.enc2dec.decoders import TEACH_FORCE +from latent_dialog.criterions import NormKLLoss, CatKLLoss +from latent_dialog.offlinerl_utils import * +from latent_dialog.augpt_utils import augpt_normalize, replace_augpt_tokens, replace_hdsa_tokens +from 
collections import deque, namedtuple, defaultdict +import random +import pdb +import logging +import dill +import warnings +import copy +import json +import time + +logger = logging.getLogger() + +class OfflineRlAgent(object): + def __init__(self, model, corpus, args, name, tune_pi_only): + self.model = model + self.corpus = corpus + self.args = args + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + print("Do we only tune the policy: {}".format(tune_pi_only)) + self.opt = optim.SGD( + [p for n, p in self.model.named_parameters() if 'c2z' in n or not tune_pi_only], + lr=self.args.rl_lr, + momentum=self.args.momentum, + nesterov=(self.args.nesterov and self.args.momentum > 0)) + self.all_rewards = [] + self.all_grads = [] + self.model.train() + + def print_dialog(self, dialog, reward, stats): + for t_id, turn in enumerate(dialog): + if t_id % 2 == 0: + print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>']))) + else: + print("Sys: {}".format(' '.join(turn))) + report = ['{}: {}'.format(k, v) for k, v in stats.items()] + print("Reward {}. {}".format(reward, report)) + + def run(self, batch, evaluator, max_words=None, temp=0.1): + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + logprobs, outs = self.model.forward_rl(batch, max_words, temp) + if batch_size == 1: + logprobs = [logprobs] + outs = [outs] + + key = batch['keys'][0] + sys_turns = [] + # construct the dialog history for printing + for turn_id, turn in enumerate(batch['contexts']): + user_input = self.corpus.id2sent(turn[-1]) + self.dlg_history.append(user_input) + sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + sys_turns.append(' '.join(sys_output)) + + for log_prob in logprobs: + self.logprobs.extend(log_prob) + generated_dialog = {key: sys_turns} + return evaluator.evaluateModel(generated_dialog, mode="offline_rl") + + def update(self, reward, stats): + self.all_rewards.append(reward) + # standardize the reward + r = (reward - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards)) + # compute accumulated discounted reward + g = self.model.np2var(np.array([r]), FLOAT).view(1, 1) + rewards = [] + for _ in self.logprobs: + rewards.insert(0, g) + g = g * self.args.gamma + + loss = 0 + # estimate the loss using one MonteCarlo rollout + for lp, r in zip(self.logprobs, rewards): + loss -= lp * r + print(loss) + self.opt.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip) + self.opt.step() + +class OfflineLatentRlAgent(OfflineRlAgent): + def run(self, batch, evaluator, max_words=None, temp=0.1): + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + logprobs, outs, logprob_z, sample_z = self.model.forward_rl(batch, max_words, temp) + if batch_size == 1: + outs = [outs] + key = batch['keys'][0] + sys_turns = [] + # construct the dialog history for printing + for turn_id, turn in enumerate(batch['contexts']): + user_input = self.corpus.id2sent(turn[-1]) + # print("Usr: {}".format(' '.join([t for t in user_input if t != '<pad>']))) + self.dlg_history.append(user_input) + sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + # print("Sys: {}".format(' '.join([t for t in sys_output if t != '<pad>']))) + sys_turns.append(' '.join(sys_output)) + + for b_id in range(batch_size): + self.logprobs.append(logprob_z[b_id]) + # compute reward here + generated_dialog = {key: sys_turns} + return 
evaluator.evaluateModel(generated_dialog, mode="offline_rl") + +class PLASAgent(object): + def __init__(self, cvae, corpus, args, name, tune_pi_only): + self.is_stochastic = args.is_stochastic + self.fix_episode = args.fix_episode + + if "gauss" in args.sv_config_path: + self.is_gauss = True + else: + self.is_gauss = False + self.z_embedding = copy.deepcopy(cvae.z_embedding) + self.embed_z_for_critic = args.embed_z_for_critic + + self.actor = Actor(cvae, corpus, args) + self.actor_target = copy.deepcopy(self.actor) + self.opt = optim.SGD(self.actor.parameters(), lr = args.rl_lr) + + self.cvae = cvae #plas cvae model + for n, p in self.cvae.named_parameters(): + p.requires_grad = False + + if "z_loss" not in args: + args['z_loss'] = False + else: + self.z_loss = args.z_loss + if self.is_gauss: + self.z_lossf = NormKLLoss(unit_average=True) + else: + self.z_lossf = nn.CrossEntropyLoss() + + if args.critic_loss == "mse": + self.q_lossf = nn.MSELoss() + elif args.critic_loss == "huber": + self.q_lossf = nn.HuberLoss() + self.regf = nn.MSELoss() + + self.args = args + self.discount = args.gamma + self.tau = args.tau + self.lmbda = args.lmbda + self.beta = args.beta + + self.critic_dropout = args.critic_dropout + if args.critic_dropout: + if self.fix_episode: + self.critic = SingleHierarchicalRecurrentCritic(self.cvae, corpus, self.cvae.config, args) + else: + self.critic = SingleRecurrentCritic(self.cvae, corpus, self.cvae.config, args) + + else: + self.critic = RecurrentCritic(self.cvae, corpus, self.cvae.config, args) + self.critic_target = copy.deepcopy(self.critic) + self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.rl_lr) + + self.q_lossf = nn.MSELoss() + self.regf = nn.MSELoss() + + self.train_buffer = ReplayBuffer(args) + self.valid_buffer = ReplayBuffer(args) + self.test_buffer = ReplayBuffer(args) + self.corpus = corpus + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + self.n_z = args.n_z + + def print_dialog(self, dialog, reward, stats): + for t_id, turn in enumerate(dialog): + if t_id % 2 == 0: + print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>']))) + else: + print("Sys: {}".format(' '.join(turn))) + # report = ['{}: {}'.format(k, v) for k, v in stats.items()] + # print("Reward {}. 
{}".format(reward, report)) + print("Reward {}\n{}".format(reward, stats)) + + def run(self, batch, evaluator): + """ + run one dialogue and compute success rate with pseudo trajectory + """ + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + de_tknize = get_detokenize() + + # ret_dict contains the z and the response + z = self.actor(batch) + _, outputs, _, sample_y = self.cvae.decode_z(z, batch_size, batch, self.args.max_words, self.args.temperature) + pred_labels = np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in outputs]) + + key = batch['keys'][0] + + sys_turns = [] + # construct the dialog history for printing and calculating reward + for turn_id, turn in enumerate(batch['contexts']): + # user_input = self.corpus.id2sent(turn[-1]) + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + self.dlg_history.append(user_input) + sys_output = get_sent(self.cvae.vocab, de_tknize, pred_labels, turn_id) + # sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + sys_turns.append(sys_output) + + # return the reward here + generated_dialog = {key: sys_turns} + task_report, success, match = evaluator.evaluateModel(generated_dialog, mode="offline_rl", verbose=False) + + return sample_y,task_report, success, match + + def train(self, verbose=False, max_words=None, temp=0.1, debug=False, n=0): + de_tknize = get_detokenize() + self.logprobs = [] + + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + ctx_lens = state['context_lens'] + batch_size = len(state['context_lens']) + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + + # predict a_t + joint_log_pz_t, z_t = self.actor(state) + logprobs_t, a_prime_t = self.cvae.decode_z(z_t, batch_size, state, max_words, temp) + a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, use_gpu=self.args.use_gpu) + + for log_prob in joint_log_pz_t: # use likelihood of latent instead of words + self.logprobs.append(log_prob) + + with th.no_grad(): + # predict a_t+1 + _, z_t1 = self.actor_target(next_state, hard=False) + logprobs_t1, a_prime_t1 = self.cvae.decode_z(z_t1, batch_size, next_state, max_words, temp) + a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu) + + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, a_prime_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. 
- self.lmbda) * th.max(target_Q1, target_Q2) + else: + Qs = [self.critic_target(next_state, a_prime_t1) for _ in range(5)] + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, out_utts) + else: + current_Q1 = self.critic(state, out_utts) + + + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * target_Q + q1_loss = self.q_lossf(current_Q1, y) + q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0 + critic_loss = q1_loss + q2_loss + + self.critic_optimizer.zero_grad() + critic_loss.backward(retain_graph=True) + nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip) + self.critic_optimizer.step() + + + # Update through DPG + actor_loss = 0 + if self.critic_dropout: + q_values = self.critic(state, a_prime_t) + else: + q_values = self.critic.q1(state, a_prime_t) + + if debug: + + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + print("True_Sys: {}".format(true_output)) + sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id) + print("Pred_Sys: {}".format(sys_output)) + print(q_values[turn_id], self.logprobs[turn_id], Return[turn_id]) + + for lp, q in zip(self.logprobs, q_values): + actor_loss -= lp * q + actor_loss = th.mean(actor_loss) + + + self.opt.zero_grad() + actor_loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip) + self.opt.step() + + + # Update Target Networks + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + loss_dict = {"critic_loss": critic_loss.item(), + "q1_loss": q1_loss.item(), + "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, + "actor_loss": actor_loss.item()} + + report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()]) + + + if verbose: + logger.info(report) + + return report, loss_dict + + def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False): + de_tknize = get_detokenize() + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, out_utts) + else: + current_Q1 = self.critic(state, out_utts) + + if sl: + # use return from data + y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1) + else: + with th.no_grad(): + # predict a_t+1 + _, z_t1 = self.actor_target(next_state, hard=False) + logprobs_t1, a_prime_t1 = self.cvae.decode_z(z_t1, batch_size, next_state, max_words, temp) + a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu) + + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, 
a_prime_t1) + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + target_Q = self.critic(state, out_utts) + + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done,LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * target_Q + + q1_loss = F.mse_loss(current_Q1, y) + q2_loss = F.mse_loss(current_Q2, y) if not self.critic_dropout else 0.0 + critic_loss = q1_loss + q2_loss + + if debug: + + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + print("True_Sys: {}".format(true_output)) + print(current_Q1[turn_id], y[turn_id]) + + + self.critic_optimizer.zero_grad() + critic_loss.backward() + nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip) + self.critic_optimizer.step() + + # Update Target Networks + if sl: + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(param.data) + else: + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + loss_dict = {"critic_loss": critic_loss.item(), + "q1_loss": q1_loss.item(), + "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0} + + report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()]) + + if verbose: + logger.info(report) + + return report, loss_dict + +class LatentPLASAgent_weird(object): + def __init__(self, cvae, corpus, args, evaluator, name, tune_pi_only): + if "gauss" in args.sv_config_path: + self.is_gauss = True + self.is_stochastic = args.is_stochastic + if not self.is_stochastic: + self.actor = DeterministicGaussianActor(cvae, corpus, args) + else: + self.actor = StochasticGaussianActor(cvae, corpus, args) + else: + self.is_gauss = False + self.is_stochastic = args.is_stochastic + self.actor = CatActor(cvae, corpus, args) + self.embed_z_for_critic = args.embed_z_for_critic + + self.actor_target = copy.deepcopy(self.actor) + self.opt = optim.SGD(self.actor.parameters(), lr = args.actor_rl_lr, weight_decay=0.01) + self.importance_threshold = args.importance_threshold + self.evaluator = evaluator + + if "z_loss" not in args: + args['z_loss'] = False + else: + self.z_loss = args.z_loss + if self.is_gauss: + if not self.is_stochastic: + self.z_lossf = nn.MSELoss() + else: + self.z_lossf = NormKLLoss(unit_average=True) + else: + if not self.is_stochastic: + self.z_lossf = nn.CrossEntropyLoss() + else: + self.z_lossf = CatKLLoss() + + if "critic_kl_loss" not in args: + args['critic_kl_loss'] = False + else: + if not self.is_stochastic: + z_loss = self.z_lossf(th.exp(log_pz_t), th.argmax(corpus_z_t, dim=1)) + else: + z_loss = self.z_lossf(log_pz_t, corpus_log_pz_t, batch_size, unit_average=True) + actor_loss += z_loss + + + self.opt.zero_grad() + actor_loss.backward() + actor_total_norm = 0 + parameters = [p for p in self.actor.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + actor_total_norm += param_norm.item() ** 2 + actor_total_norm = actor_total_norm ** 0.5 + nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip) + self.opt.step() + + # Update Target Networks + for param, target_param in zip(self.critic.parameters(), 
self.critic_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + loss_dict = {"critic_loss": critic_loss.item(), + "q1_loss": q1_loss.item(), + "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, + "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0, + "critic_grad_norm":critic_total_norm, + "actor_loss": actor_loss.item(), + "actor_grad_norm":actor_total_norm, + "z_loss":z_loss.item() if self.z_loss else 0.0} + + + report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()]) + + if verbose: + logger.info(report) + + return report, loss_dict + + def train_vae_model(self, batch_cnt): + de_tknize = get_detokenize() + + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + + vae_batch = copy.deepcopy(state) + vae_batch['contexts'] = np.expand_dims(action, 1) + vae_batch['outputs'] = action + vae_batch['context_lens'] = ctx_lens - 1 + + loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE) + self.cvae.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip) + self.vae_optimizer.step() + vae_loss = self.cvae.valid_loss(loss) + + return vae_loss + + def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False): + de_tknize = get_detokenize() + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu) + + if self.is_gauss: + corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts) + _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state) + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, corpus_z_t) + else: + current_Q1 = self.critic(state, corpus_z_t) + + else: + corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True) + soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False) + if self.embed_z_for_critic: + soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size) + if self.critic_dropout: + current_Q1 = self.critic(state, soft_corpus_z_t) + else: + current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t) + + _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp) + + if type(corpus_a_t[0]) == int: + corpus_a_t = [corpus_a_t] + corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu) + + + if sl: + # use return from data + y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1) + else: + with th.no_grad(): + # predict a_t+1 + if self.is_gauss: + z_t1, actor_mu_t1, actor_logvar_t1 = 
self.actor_target(next_state) + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + + if not self.fix_episode: + Qs = [self.critic_target(next_state, z_t1) for z_t1 in z_t1s] + else: + # special forward to avoid using pseudo-trajectory + # tic = time.perf_counter() + Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for z_t1 in z_t1s] + # toc = time.perf_counter() + # print(f"One critic pass in {toc - tic:0.4f} seconds") + + if self.args.critic_dropout_agg == "min": + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + elif self.args.critic_dropout_agg == "avg": + target_Q = th.mean(th.cat(Qs, dim=1), 1) + else: + z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state) + if self.embed_z_for_critic: + soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size) + + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)] + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + + + if self.args.critic_kl_loss: + critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach()) + else: + critic_kl_loss = 0.0 + + # critic only ever receive supervision on final state + if not self.critic_dropout: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + else: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + y = y.unsqueeze(1) + + + + q1_loss = self.q_lossf(current_Q1, y) + q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0 + + critic_loss = q1_loss + q2_loss + critic_kl_loss + + if debug: + + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id) + print("True_Sys: {}".format(true_output)) + corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id) + print("VAE_Sys: {}".format(corpus_output)) + print("pred: ", current_Q1[turn_id], "y: ", y[turn_id]) + + + self.critic_optimizer.zero_grad() + critic_loss.backward(retain_graph=True) + critic_total_norm = 0 + parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + critic_total_norm += param_norm.item() ** 2 + critic_total_norm = critic_total_norm ** 0.5 + + nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip) + self.critic_optimizer.step() + + # Update Target Networks + if sl: + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + 
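+                # Hard sync in supervised (sl) mode: the target critic is copied
+                # verbatim from the online critic. In RL mode (else-branch below) a
+                # Polyak soft update is used instead; a minimal sketch of that rule,
+                # assuming online parameter p and target parameter tp:
+                #   tp.data.copy_(tau * p.data + (1 - tau) * tp.data)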
target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict = {"critic_loss": critic_loss.item(),
+                     "q1_loss": q1_loss.item(),
+                     "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                     "critic_kl_loss": critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                     "critic_grad_norm": critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+
+        return report, loss_dict
+
+class LatentPLASAgent(object):
+    def __init__(self, cvae, corpus, args, evaluator, name, tune_pi_only):
+        if "gauss" in args.sv_config_path:
+            self.is_gauss = True
+            self.is_stochastic = args.is_stochastic
+            if not self.is_stochastic:
+                self.actor = DeterministicGaussianActor(cvae, corpus, args)
+            else:
+                self.actor = StochasticGaussianActor(cvae, corpus, args)
+        else:
+            self.is_gauss = False
+            self.is_stochastic = args.is_stochastic
+            self.actor = CatActor(cvae, corpus, args)
+        self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.actor_target = copy.deepcopy(self.actor)
+        self.opt = optim.SGD(self.actor.parameters(), lr=args.actor_rl_lr, weight_decay=0.01)
+        self.importance_threshold = args.importance_threshold
+
+        self.evaluator = evaluator
+
+        # default z_loss to False when absent, and always set the attribute:
+        # assigning self.z_loss only in the else-branch (as before) leaves it
+        # undefined when "z_loss" is missing from args, which crashes train()
+        if "z_loss" not in args:
+            args['z_loss'] = False
+        self.z_loss = args.z_loss
+        if self.is_gauss:
+            if not self.is_stochastic:
+                self.z_lossf = nn.MSELoss()
+            else:
+                self.z_lossf = NormKLLoss(unit_average=True)
+        else:
+            if not self.is_stochastic:
+                self.z_lossf = nn.CrossEntropyLoss()
+            else:
+                self.z_lossf = CatKLLoss()
+
+        self.critic_kl_loss = args.critic_kl_loss
+        if self.is_gauss:
+            self.kl_lossf = NormKLLoss(unit_average=True)
+        else:
+            self.kl_lossf = CatKLLoss()
+
+        self.args = args
+        self.discount = args.gamma
+        self.tau = args.tau
+        self.lmbda = args.lmbda
+        self.beta = args.beta
+        self.fix_episode = args.fix_episode
+
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            if self.fix_episode:
+                self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args)
+            else:
+                self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args)
+        else:
+            self.critic = RecurrentCritic(cvae, corpus, cvae.config, args)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01)
+
+        self.cvae = cvae  # plas cvae model
+        for n, p in self.cvae.named_parameters():
+            if "aux_encoder" not in n:
+                p.requires_grad = False
+
+        self.train_vae = args.train_vae
+        if self.train_vae:
+            self.vae_optimizer = optim.SGD([p for p in self.cvae.parameters() if p.requires_grad == True], lr=args.rl_lr)
+            self.vae_loss = LossManager()
+            self.cvae.shared_train = False
+            self.cvae.config.beta = args.vae_beta
+            if args.weighted_vae_nll:
+                # set NLL to be weighted on requestable tokens
+                req_tokens = []
+                for d in REQ_TOKENS.keys():
+                    req_tokens.extend(REQ_TOKENS[d])
+                nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. 
for token in self.cvae.vocab])) + print("req tokens assigned with special weights") + if args.use_gpu: + nll_weight = nll_weight.cuda() + self.cvae.nll.avg_type = "weighted" + self.cvae.nll.set_weight(nll_weight) + + + self.q_lossf = nn.MSELoss() + self.regf = nn.MSELoss() + + self.train_buffer = ReplayBuffer(args) + self.valid_buffer = ReplayBuffer(args) + self.test_buffer = ReplayBuffer(args) + self.corpus = corpus + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + self.n_z = args.n_z + + def print_dialog(self, dialog, reward, stats): + for t_id, turn in enumerate(dialog): + if t_id % 2 == 0: + print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>']))) + else: + print("Sys: {}".format(' '.join(turn))) + # report = ['{}: {}'.format(k, v) for k, v in stats.items()] + # print("Reward {}. {}".format(reward, report)) + print("Reward {}\n{}".format(reward, stats)) + + def run(self, batch): + """ + run one dialogue and compute success rate with pseudo trajectory + """ + self.logprobs = [] + self.dlg_history =[] + batch_size = len(batch['keys']) + de_tknize = get_detokenize() + + # ret_dict contains the z and the response + z = self.actor(batch) + _, outputs, _, sample_y = self.cvae.decode_z(z, batch_size, batch, self.args.max_words, self.args.temperature) + pred_labels = np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in outputs]) + + key = batch['keys'][0] + + sys_turns = [] + # construct the dialog history for printing and calculating reward + for turn_id, turn in enumerate(batch['contexts']): + # user_input = self.corpus.id2sent(turn[-1]) + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + self.dlg_history.append(user_input) + sys_output = get_sent(self.cvae.vocab, de_tknize, pred_labels, turn_id) + # sys_output = self.corpus.id2sent(outs[turn_id]) + self.dlg_history.append(sys_output) + sys_turns.append(sys_output) + + # for log_prob in logprobs: + # self.logprobs.extend(log_prob) + # return the reward here + generated_dialog = {key: sys_turns} + task_report, success, match = self.evaluator.evaluateModel(generated_dialog, mode="offline_rl", verbose=False) + + return sample_y,task_report, success, match + + def train(self, verbose=False, max_words=None, temp=0.1, debug=False, n=0): + de_tknize = get_detokenize() + self.logprobs = [] + generated_dialogs = defaultdict(list) + + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + if self.fix_episode: + key = state['keys'][0] + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu) + + # predict a_t + if self.is_gauss: + z_t, actor_mu_t, actor_logvar_t = self.actor(state) + else: + z_t, soft_z_t, log_pz_t, logits_pz_t = self.actor(state) + # soft_z_t = self.z_embedding(soft_z_t) + if self.embed_z_for_critic: + soft_z_t = self.actor.z_embedding(soft_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_z_t = soft_z_t.view(-1, self.actor.y_size * self.actor.k_size) + + logprobs_t, a_prime_t = self.cvae.decode_z(z_t, batch_size, state, max_words, temp) + if type(a_prime_t[0]) == int: + a_prime_t = [a_prime_t] + logprobs_t = [logprobs_t] + a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, 
use_gpu=self.args.use_gpu) + + # predict get at using VAE, and predict Q(st, at) + if self.is_gauss: + corpus_z_t, corpus_mu_t, corpus_logvar_t = self.cvae.get_z_via_vae(out_utts) + corpus_z_t1, corpus_mu_t1, corpus_logvar_t1 = self.cvae.get_z_via_vae(next_out_utts) + _, rg_mu_t, rg_logvar_t = self.cvae.get_z_via_rg(state) # for kl loss + _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state) # for kl loss + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, corpus_z_t) + else: + current_Q1 = self.critic(state, corpus_z_t) + else: + corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True) + soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False) + _, rg_log_pz_t, _ = self.cvae.get_z_via_rg(state) # for kl loss + _, rg_log_pz_t1, _ = self.cvae.get_z_via_rg(next_state) # for kl loss + + if self.embed_z_for_critic: + soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size) + if self.critic_dropout: + current_Q1 = self.critic(state, soft_corpus_z_t) + else: + current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t) + + + # for logging purposes only + _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp) + if type(corpus_a_t[0]) == int: + corpus_a_t = [corpus_a_t] + corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu) + + + with th.no_grad(): + # predict a_t+1 and Q(s_t+1, a_t+1) + if self.is_gauss: + z_t1, actor_mu_t1, actor_logvar_t1 = self.actor_target(next_state) + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + if not self.fix_episode: + Qs = [self.critic_target(next_state, z_t1) for _ in range(5)] + else: + # special forward to avoid using pseudo-trajectory + # tic = time.perf_counter() + Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for _ in range(5)] + # toc = time.perf_counter() + # print(f"One batch critic forward pass in {toc - tic:0.4f} seconds") + + if self.args.critic_dropout_agg == "min": + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + elif self.args.critic_dropout_agg == "avg": + target_Q = th.mean(th.cat(Qs, dim=1), 1) + + else: + z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state) + if self.embed_z_for_critic: + soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size) + + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. 
- self.lmbda) * th.max(target_Q1, target_Q2) + else: + Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)] + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + + + + if self.args.critic_kl_loss: + if self.is_gauss: + critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach()) + else: + critic_kl_loss = self.kl_lossf(log_pz_t1, rg_log_pz_t1.detach(), unit_average=True) + else: + critic_kl_loss = 0.0 + + if not self.critic_dropout: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + else: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + y = y.unsqueeze(1) + + q1_loss = self.q_lossf(current_Q1, y) + q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0 + + critic_loss = q1_loss + q2_loss + critic_kl_loss + + self.critic_optimizer.zero_grad() + critic_loss.backward(retain_graph=True) + critic_total_norm = 0 + parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + critic_total_norm += param_norm.item() ** 2 + critic_total_norm = critic_total_norm ** 0.5 + + nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip) + self.critic_optimizer.step() + + + # Update through DPG + actor_loss = 0 + if self.is_gauss: + if self.critic_dropout: + q_values = self.critic(state, z_t) + else: + q_values = self.critic.q1(state, z_t) + else: + if self.critic_dropout: + q_values = self.critic(state, soft_z_t) + else: + q_values = self.critic.q1(state, soft_z_t) + + if debug: + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + print("True_Sys: {}".format(true_output)) + corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id) + print("VAE_Sys: {}".format(corpus_output)) + sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id) + print("Pred_Sys: {}".format(sys_output)) + # print(q_values[turn_id], Return[turn_id]) + print("pred_Q: ", current_Q1[turn_id], "y_Q: ", y[turn_id], "target: ", target_Q[turn_id], "Return: ", Return[turn_id]) + print() + + if not self.is_stochastic: + actor_loss = - th.mean(q_values) + else: + if self.is_gauss: + # off-policy policy gradient + pi_a_s = self.cvae.gaussian_prob(actor_mu_t, actor_logvar_t, corpus_z_t) + 1e-15 # batch_size, 200 + pib_a_s = self.cvae.gaussian_prob(corpus_mu_t, corpus_logvar_t, corpus_z_t) # batch_size, 200 + threshold = th.ones([batch_size, self.cvae.y_size]) * self.importance_threshold + if self.args.use_gpu: + threshold = threshold.cuda() + importance_ratio = th.min(pi_a_s / pib_a_s, threshold) + + actor_loss = - th.mean(importance_ratio.detach() * th.log(pi_a_s) * q_values.detach()) + + else: + log_pi_a_s = self.cvae.categorical_logprob(logits_pz_t, corpus_z_t, temp=1.0) # batch_size, 200 + log_pib_a_s = self.cvae.categorical_logprob(corpus_logits_pz_t, corpus_z_t, temp=1.0) # batch_size, 200 + threshold = th.ones([batch_size]) * self.importance_threshold + if self.args.use_gpu: + threshold = threshold.cuda() + importance_ratio = 
th.min(th.exp(log_pi_a_s) / th.exp(log_pib_a_s), threshold) + actor_loss = - th.mean(importance_ratio.detach() * log_pi_a_s * q_values.detach()) + + + + if self.args.z_loss: + if self.is_gauss: + if not self.is_stochastic: + z_loss = self.z_lossf(z_t, corpus_z_t.detach()) + else: + z_loss = 0.01 * self.z_lossf(actor_mu_t, actor_logvar_t, corpus_mu_t.detach(), corpus_logvar_t.detach()) + actor_loss += z_loss + else: + if not self.is_stochastic: + z_loss = self.z_lossf(th.exp(log_pz_t), th.argmax(corpus_z_t, dim=1)) + else: + z_loss = self.z_lossf(log_pz_t, corpus_log_pz_t, batch_size, unit_average=True) + actor_loss += z_loss + + + self.opt.zero_grad() + actor_loss.backward() + actor_total_norm = 0 + parameters = [p for p in self.actor.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + actor_total_norm += param_norm.item() ** 2 + actor_total_norm = actor_total_norm ** 0.5 + nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip) + self.opt.step() + + + # Update Target Networks + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + loss_dict = {"critic_loss": critic_loss.item(), + "q1_loss": q1_loss.item(), + "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, + "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0, + "critic_grad_norm":critic_total_norm, + "actor_loss": actor_loss.item(), + "actor_grad_norm":actor_total_norm, + "z_loss":z_loss.item() if self.z_loss else 0.0} + + + report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()]) + + if verbose: + logger.info(report) + + return report, loss_dict + + def train_vae_model(self, batch_cnt): + de_tknize = get_detokenize() + + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + + vae_batch = copy.deepcopy(state) + vae_batch['contexts'] = np.expand_dims(action, 1) + vae_batch['outputs'] = action + vae_batch['context_lens'] = ctx_lens - 1 + + loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE) + self.cvae.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip) + self.vae_optimizer.step() + vae_loss = self.cvae.valid_loss(loss) + + return vae_loss + + def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False): + de_tknize = get_detokenize() + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu) + + if self.is_gauss: + corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts) + _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state) + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, corpus_z_t) + 
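+            # Twin critics (clipped double Q-learning) when critic_dropout is off;
+            # with critic_dropout a single critic is kept instead, and the target
+            # below is aggregated over repeated stochastic (MC-dropout) passes.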
else: + current_Q1 = self.critic(state, corpus_z_t) + + else: + corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True) + soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False) + if self.embed_z_for_critic: + # corpus_z_t = self.z_embedding(corpus_z_t) + soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size) + if self.critic_dropout: + current_Q1 = self.critic(state, soft_corpus_z_t) + else: + current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t) + + _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp) + + if type(corpus_a_t[0]) == int: + corpus_a_t = [corpus_a_t] + corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu) + + + if sl: + # use return from data + y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1) + else: + with th.no_grad(): + # predict a_t+1 + if self.is_gauss: + z_t1, actor_mu_t1, actor_logvar_t1 = self.actor_target(next_state) + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + + if not self.fix_episode: + Qs = [self.critic_target(next_state, z_t1) for _ in range(5)] + else: + # special forward to avoid using pseudo-trajectory + # tic = time.perf_counter() + Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for _ in range(5)] + # toc = time.perf_counter() + # print(f"One critic pass in {toc - tic:0.4f} seconds") + + if self.args.critic_dropout_agg == "min": + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + elif self.args.critic_dropout_agg == "avg": + target_Q = th.mean(th.cat(Qs, dim=1), 1) + else: + z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state) + if self.embed_z_for_critic: + soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0) + else: + soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size) + + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. 
- self.lmbda) * th.max(target_Q1, target_Q2) + else: + Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)] + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + + + if self.args.critic_kl_loss: + critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach()) + else: + critic_kl_loss = 0.0 + + # critic only ever receive supervision on final state + if not self.critic_dropout: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + else: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) + y = y.unsqueeze(1) + + + + q1_loss = self.q_lossf(current_Q1, y) + q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0 + + critic_loss = q1_loss + q2_loss + critic_kl_loss + + if debug: + + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id) + print("True_Sys: {}".format(true_output)) + corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id) + print("VAE_Sys: {}".format(corpus_output)) + print("pred: ", current_Q1[turn_id], "y: ", y[turn_id]) + + + self.critic_optimizer.zero_grad() + critic_loss.backward(retain_graph=True) + critic_total_norm = 0 + parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + critic_total_norm += param_norm.item() ** 2 + critic_total_norm = critic_total_norm ** 0.5 + + nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip) + self.critic_optimizer.step() + + # Update Target Networks + if sl: + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(param.data) + else: + for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): + target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) + + loss_dict = {"critic_loss": critic_loss.item(), + "q1_loss": q1_loss.item(), + "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, + "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0, + "critic_grad_norm":critic_total_norm} + + report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()]) + + if verbose: + logger.info(report) + + return report, loss_dict + +class LatentCriticAgent(object): + def __init__(self, cvae, corpus, args, evaluator, name, vae=None): + if "gauss" in args.config_path: + self.is_gauss = True + self.kl_lossf = NormKLLoss(unit_average=True) + else: + self.is_gauss = False + self.kl_lossf = CatKLLoss() + self.embed_z_for_critic = args.embed_z_for_critic + + self.evaluator = evaluator + self.corpus = corpus + + if "raw_response" in args: + self.raw_response = args.raw_response + else: + self.raw_response = False + + if self.raw_response: + self._read_responses(args.response_path) + + if "corpus_response" in args: + self.corpus_response = args.corpus_response + else: + self.corpus_response = False + + assert self.corpus_response != self.raw_response + + if 
self.corpus_response or self.raw_response: + self.vae = vae + else: + self.vae=None + + self.args = args + self.discount = args.gamma + self.tau = args.tau + self.lmbda = args.lmbda + self.beta = args.beta + self.fix_episode = args.fix_episode + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + if self.fix_episode: + self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args) + else: + self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args) + else: + self.critic = RecurrentCritic(cvae, corpus, cvae.config, args) + self.critic_target = copy.deepcopy(self.critic) + self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01) + + self.cvae = cvae #lava model + for n, p in self.cvae.named_parameters(): + p.requires_grad = False # no param is trained + + self.train_vae = args.train_vae + if self.train_vae: + self.vae_optimizer = optim.SGD([p for p in self.cvae.parameters() if p.requires_grad==True], lr = args.rl_lr) + self.vae_loss = LossManager() + self.cvae.shared_train = False + self.cvae.config.beta = args.vae_beta + if args.weighted_vae_nll: + # set NLL to be weighted on requestable tokens + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.cvae.vocab])) + print("req tokens assigned with special weights") + if args.use_gpu: + nll_weight = nll_weight.cuda() + self.cvae.nll.avg_type = "weighted" + self.cvae.nll.set_weight(nll_weight) + + self.q_lossf = nn.MSELoss() + self.regf = nn.MSELoss() + + self.train_buffer = ReplayBuffer(args) + self.valid_buffer = ReplayBuffer(args) + self.test_buffer = ReplayBuffer(args) + self.corpus = corpus + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + self.n_z = args.n_z + + def _read_responses(self, json_path): + """ + following shades of bleu format + """ + + self.raw_responses = defaultdict(list) + + # read responses + json_list = [json_path, json_path.replace("test", "val"), json_path.replace("test", "train")] + for j in json_list: + print(f"Reading responses from {j}...") + with open(j) as f: + raw_data = json.load(f) + + # convert to ID of critic encoder + for k in raw_data.keys(): + if "augpt" in json_path: + self.raw_responses[k] = [self.corpus._sent2id(replace_augpt_tokens(augpt_normalize(raw_data[k]['response'][i]), raw_data[k]['active_domain'][i]).split()) for i in range(len(raw_data[k]['response']))] + elif "HDSA" in json_path or "hdsa" in json_path: + self.raw_responses[k + ".json"] = [self.corpus._sent2id(replace_hdsa_tokens(raw_data[k][i])) for i in range(len(raw_data[k]))] + + print(len(raw_data.keys()), "dialogues read") + + print(f"responses from {len(self.raw_responses.keys())} dialogues read") + + + def train_vae_model(self, batch_cnt): + de_tknize = get_detokenize() + + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + + vae_batch = copy.deepcopy(state) + vae_batch['contexts'] = np.expand_dims(action, 1) + vae_batch['outputs'] = action + vae_batch['context_lens'] = ctx_lens - 1 + + loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE) + self.cvae.backward(loss, batch_cnt) + 
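+        # backward() populates gradients for the trainable cvae parameters;
+        # clip their global norm to config.grad_clip before the optimizer step.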
nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip) + self.vae_optimizer.step() + vae_loss = self.cvae.valid_loss(loss) + + return vae_loss + + def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False, n=0): + de_tknize = get_detokenize() + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + key = state['keys'][0] + + if self.raw_response: + while key not in self.raw_responses.keys(): + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + key = state['keys'][0] + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu) + if self.raw_response: + a_prime_t1 = self.cvae.np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in self.raw_responses[key][1:] + [[]]]), LONG) + elif self.corpus_response: + a_prime_t1 = next_out_utts + + if self.is_gauss: + corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts) + _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp) + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, corpus_z_t) + else: + current_Q1 = self.critic(state, corpus_z_t) + + else: + corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True) + soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False) + _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp) + if self.embed_z_for_critic: + soft_corpus_z_t = self.cvae.z_embedding(soft_corpus_z_t.view(1, -1, self.cvae.y_size * self.cvae.k_size)).squeeze(0) + corpus_z_t = self.cvae.z_embedding(corpus_z_t.view(1, -1, self.cvae.y_size * self.cvae.k_size)).squeeze(0) + else: + soft_corpus_z_t = soft_corpus_z_t.view(-1, self.cvae.y_size * self.cvae.k_size) + corpus_z_t = corpus_z_t.view(-1, self.cvae.y_size * self.cvae.k_size) + if self.critic_dropout: + current_Q1 = self.critic(state, soft_corpus_z_t) + else: + current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t) + + + + if type(corpus_a_t[0]) == int: + corpus_a_t = [corpus_a_t] + corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu) + + z_t, actor_mu_t, actor_logvar_t = self.cvae.get_z_via_rg(state) + logprobs_t, a_prime_t = self.cvae.decode_z(z_t, batch_size, state, max_words, temp) + if type(a_prime_t[0]) == int: + a_prime_t = [a_prime_t] + a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, use_gpu=self.args.use_gpu) + + + with th.no_grad(): + if sl: + # use return from data + y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1) + else: + # predict a_t+1 + if self.is_gauss: + if not self.raw_response: + z_t1, actor_mu_t1, actor_logvar_t1 = self.cvae.get_z_via_rg(next_state) + else: + z_t1, actor_mu_t1, actor_logvar_t1 = self.vae.get_z_via_vae(a_prime_t1) + + _, corpus_mu_t1, corpus_logvar_t1 = self.cvae.get_z_via_vae(next_out_utts) + + else: #categorical + z_t1, actor_log_pz_t1, actor_pz_t1 = self.cvae.get_z_via_rg(next_state, hard=True) + if self.embed_z_for_critic: + z_t1 = self.cvae.z_embedding(z_t1.view(1, -1, self.cvae.y_size * 
self.cvae.k_size)).squeeze(0)
+                    else:
+                        z_t1 = z_t1.view(-1, self.cvae.y_size * self.cvae.k_size)
+                    _, corpus_log_pz_t1, corpus_logits_pz_t1 = self.cvae.get_z_via_vae(next_out_utts)
+
+                # Critic Training
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, z_t1)
+                    # Soft Clipped Double Q-learning
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    # draw 5 MC-dropout samples from the target critic (the original
+                    # iterated over an undefined z_t1s; repeated sampling matches the
+                    # other agents' dropout ensembles)
+                    if not self.fix_episode:
+                        Qs = [self.critic_target(next_state, z_t1) for _ in range(5)]
+                    else:
+                        # special forward to avoid using pseudo-trajectory
+                        Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for _ in range(5)]
+
+                    if self.args.critic_dropout_agg == "min":
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                    elif self.args.critic_dropout_agg == "avg":
+                        target_Q = th.mean(th.cat(Qs, dim=1), 1)
+
+                if self.args.critic_kl_loss:
+                    if self.is_gauss:
+                        critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, corpus_mu_t1.detach(), corpus_logvar_t1.detach())
+                    else:
+                        critic_kl_loss = self.kl_lossf(actor_log_pz_t1, corpus_log_pz_t1.detach(), unit_average=True)
+                else:
+                    critic_kl_loss = 0.0
+
+                # the critic only ever receives supervision on the final state
+                if not self.critic_dropout:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+                else:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+                    y = y.unsqueeze(1)
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        if sl:
+            # no KL term when regressing directly to the data return; without
+            # this default, critic_kl_loss would be undefined on the sl path
+            critic_kl_loss = 0.0
+        critic_loss = q1_loss + q2_loss + critic_kl_loss
+
+        if debug:
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                print("pred: ", current_Q1[turn_id], "y: ", y[turn_id])
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+        critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict = {"critic_loss": critic_loss.item(),
+                     "q1_loss": q1_loss.item(),
+                     "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                     # guard on the value itself: critic_kl_loss is 0.0 (a float) on
+                     # the sl path even when the flag is set
+                     "critic_kl_loss": critic_kl_loss.item() if th.is_tensor(critic_kl_loss) else 0.0,
+                     "critic_grad_norm": critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in 
loss_dict.items()]) + + if verbose: + logger.info(report) + + return report, loss_dict + +class CriticAgent(LatentCriticAgent): + + def __init__(self, cvae, corpus, args, evaluator, name): + if "gauss" in args.config_path: + self.is_gauss = True + self.kl_lossf = NormKLLoss(unit_average=True) + else: + self.is_gauss = False + self.kl_lossf = CatKLLoss() + self.embed_z_for_critic = args.embed_z_for_critic + + self.evaluator = evaluator + self.corpus = corpus + + if "raw_response" in args: + self.raw_response = args.raw_response + else: + self.raw_response = False + + if self.raw_response: + self._read_responses(args.response_path) + + if "corpus_response" in args: + self.corpus_response = args.corpus_response + else: + self.corpus_response = False + + if self.corpus_response or self.raw_response: + assert self.corpus_response != self.raw_response + + self.args = args + self.discount = args.gamma + self.tau = args.tau + self.lmbda = args.lmbda + self.beta = args.beta + self.fix_episode = args.fix_episode + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + if self.fix_episode: + self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args) + else: + self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args) + else: + self.critic = RecurrentCritic(cvae, corpus, cvae.config, args) + self.critic_target = copy.deepcopy(self.critic) + self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01) + + self.cvae = cvae #lava model + for n, p in self.cvae.named_parameters(): + p.requires_grad = False # no param is trained + + if "actor_path" in args and args.actor_path is not None: + with open(args.actor_config, "r") as f: + actor_config = Pack(json.load(f)) + self.is_stochastic = actor_config.is_stochastic + if self.is_gauss: + if not self.is_stochastic: + self.actor = DeterministicGaussianActor(cvae, corpus, args) + else: + self.actor = StochasticGaussianActor(cvae, corpus, args) + else: + self.actor = CatActor(cvae, corpus, args) + self.embed_z_for_critic = args.embed_z_for_critic + + actor_dict = th.load(args.actor_path, map_location=lambda storage, location: storage) + self.actor.load_state_dict(actor_dict) + else: + self.actor = None + + + self.q_lossf = nn.MSELoss() + self.regf = nn.MSELoss() + + self.train_buffer = ReplayBuffer(args) + self.valid_buffer = ReplayBuffer(args) + self.test_buffer = ReplayBuffer(args) + self.name = name + self.raw_goal = None + self.vec_goals_list = None + self.logprobs = None + self.n_z = args.n_z + + def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False, n=0): + de_tknize = get_detokenize() + # Sample replay buffer / batch + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + key = state['keys'][0] + + if self.raw_response: + while key not in self.raw_responses.keys(): + experiences = self.train_buffer.sample() + state, action, reward, next_state, expert_next_action, done, Return = experiences + key = state['keys'][0] + + ctx_lens = state['context_lens'] # (batch_size, ) + batch_size = len(state['context_lens']) + + out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu) + next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu) + corpus_a_t = action + corpus_a_t1 = expert_next_action + + if type(corpus_a_t[0]) == int: + corpus_a_t = [corpus_a_t] + corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in 
corpus_a_t]), LONG, use_gpu=self.args.use_gpu) + + if not self.critic_dropout: + current_Q1, current_Q2 = self.critic(state, corpus_a_t) + else: + current_Q1 = self.critic(state, corpus_a_t) + + # compute target + with th.no_grad(): + if sl: + # use return from data + y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1) + else: + if self.raw_response: + a_prime_t1 = self.cvae.np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in self.raw_responses[key][1:] + [[]]]), LONG) + elif self.corpus_response: + a_prime_t1 = next_out_utts + else: + # predict a_t+1 + if self.actor is not None: + z_t1, actor_mu_t1, actor_logvar_t1 = self.actor(next_state) + _, a_prime_t1 = self.cvae.decode_z(z_t1, batch_size, next_state, self.args.max_words, self.args.temperature) + else: + logprobs, a_prime_t1, joint_logpz, sample_z = self.cvae.forward_rl(next_state, max_words=max_words, temp=temp) + + if type(a_prime_t1[0]) == int: + a_prime_t1 = [a_prime_t1] + a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu) + + # Critic Training + if not self.critic_dropout: + target_Q1, target_Q2 = self.critic_target(next_state, a_prime_t1) + # Soft Clipped Double Q-learning + target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2) + else: + if not self.fix_episode: + Qs = [self.critic_target(next_state, a_prime_t1) for _ in range(5)] + else: + # special forward to avoid using pseudo-trajectory + # tic = time.perf_counter() + Qs = [self.critic_target.forward_target(next_state, a_prime_t1.unsqueeze(1), corpus_a_t) for _ in range(5)] + # toc = time.perf_counter() + # print(f"One batch of critic pass in {toc - tic:0.4f} seconds") + if self.args.critic_dropout_agg == "min": + target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0] + elif self.args.critic_dropout_agg == "avg": + target_Q = th.mean(th.cat(Qs, dim=1), 1) + + # critic only ever receive supervision on final state + if not self.critic_dropout: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q) + else: + y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q) + y = y.unsqueeze(1) + + + + q1_loss = self.q_lossf(current_Q1, y) + q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0 + + critic_loss = q1_loss + q2_loss + + if debug: + + print("===TURN T===") + for turn_id, turn in enumerate(state['contexts']): + user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1) + print("Usr: {}".format(user_input)) + if turn_id > 0 and not sl: + sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t1, turn_id - 1) + print("Pred_Sys: {}".format(sys_output)) + true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id) + print("True_Sys: {}".format(true_output)) + corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id) + print("VAE_Sys: {}".format(corpus_output)) + print("pred: ", current_Q1[turn_id], "y: ", y[turn_id]) + + + self.critic_optimizer.zero_grad() + critic_loss.backward() + critic_total_norm = 0 + parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad] + for p in parameters: + param_norm = p.grad.detach().data.norm(2) + critic_total_norm += param_norm.item() ** 2 + critic_total_norm = critic_total_norm ** 0.5 + + 
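+        # critic_total_norm above is the pre-clipping global L2 norm of the critic
+        # gradients (logged below as critic_grad_norm); clipping to rl_clip and the
+        # optimizer step follow.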
nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict = {"critic_loss": critic_loss.item(),
+                     "q1_loss": q1_loss.item(),
+                     "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                     # this agent computes no KL term, so log a constant 0.0 here
+                     # (the original expression referenced an undefined critic_kl_loss)
+                     "critic_kl_loss": 0.0,
+                     "critic_grad_norm": critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+
+        return report, loss_dict
+
diff --git a/latent_dialog/augpt_utils.py b/latent_dialog/augpt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5c4c1db7358b6a7e321f56bb05a6db0ed637e6
--- /dev/null
+++ b/latent_dialog/augpt_utils.py
@@ -0,0 +1,458 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2021 lubis <lubis@hilbert242>
+#
+# Distributed under terms of the MIT license.
+
+"""
+utils from AuGPT codebase
+"""
+import re
+import os
+import sys
+import types
+import shutil
+import logging
+import requests
+import torch
+import zipfile
+import bisect
+import random
+import copy
+import json
+from collections import OrderedDict, defaultdict
+from typing import Callable, Union, Set, Optional, List, Dict, Any, Tuple, MutableMapping # noqa: 401
+from dataclasses import dataclass
+import pdb
+
+DATASETS_PATH = os.path.join(os.path.expanduser(os.environ.get('DATASETS_PATH', '~/datasets')), 'augpt')
+
+pricepat = re.compile("\d{1,3}[.]\d{1,2}")
+
+fin = open(os.path.join(DATASETS_PATH, 'mapping.pair'))
+replacements = []
+for line in fin.readlines():
+    tok_from, tok_to = line.replace('\n', '').split('\t')
+    replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
+
+
+def insertSpace(token, text):
+    sidx = 0
+    while True:
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
+                re.match('[0-9]', text[sidx + 1]):
+            sidx += 1
+            continue
+        if text[sidx - 1] != ' ':
+            text = text[:sidx] + ' ' + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
+            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
+        sidx += 1
+    return text
+
+
+class AutoDatabase:
+    @staticmethod
+    def load(pretrained_model_name_or_path):
+        database_file = os.path.join(pretrained_model_name_or_path, 'database.zip')
+
+        with zipfile.ZipFile(database_file) as zipf:
+            def _build_database():
+                module = types.ModuleType('database')
+                exec(zipf.read('database.py').decode('utf-8'), module.__dict__)
+                return module.Database(zipf)
+            database = _build_database()
+
+        return database
+
+class BeliefParser:
+    def __init__(self):
+        self.slotval_re = re.compile(r"(\w[\w ]*\w) = ([\w\d: |']+)")
+        self.domain_re = re.compile(r"(\w+) {\s*([\w,= :\d|']*)\s*}", re.IGNORECASE)
+
+    def __call__(self, raw_belief: str):
+        belief = OrderedDict()
+        for match in self.domain_re.finditer(raw_belief):
+            domain, domain_bs = match.group(1), match.group(2)
+            belief[domain] = {}
+            for slot_match in self.slotval_re.finditer(domain_bs):
+                slot, val = slot_match.group(1), slot_match.group(2)
+                belief[domain][slot] = val
+        return belief
+
+class AutoLexicalizer:
+    
@staticmethod + def load(pretrained_model_name_or_path): + lexicalizer_file = os.path.join(pretrained_model_name_or_path, 'lexicalizer.zip') + + with zipfile.ZipFile(lexicalizer_file) as zipf: + def _build_lexicalizer(): + module = types.ModuleType('lexicalizer') + exec(zipf.read('lexicalizer.py').decode('utf-8'), module.__dict__) + return module.Lexicalizer(zipf) + + lexicalizer = _build_lexicalizer() + + + return lexicalizer + +def build_blacklist(items, domains=None): + for i, (dialogue, items) in enumerate(items): + if domains is not None and set(dialogue['domains']).difference(domains): + yield i + elif items[-1]['speaker'] != 'system': + yield i + +class BlacklistItemsWrapper: + def __init__(self, items, blacklist): + self.items = items + self.key2idx = items.key2idx + self._indexmap = [] + blacklist_pointer = 0 + for i in range(len(items)): + if blacklist_pointer >= len(blacklist): + self._indexmap.append(i) + elif i < blacklist[blacklist_pointer]: + self._indexmap.append(i) + elif i == blacklist[blacklist_pointer]: + blacklist_pointer += 1 + assert len(self._indexmap) == len(items) - len(blacklist) + + def __getitem__(self, idx): + if isinstance(idx, str): + idx = self.key2idx[idx] + return self.items[self._indexmap[idx]] + + def __len__(self): + return len(self._indexmap) +def split_name(dataset_name: str): + split = dataset_name.rindex('/') + return dataset_name[:split], dataset_name[split + 1:] + +@dataclass +class DialogDatasetItem: + context: Union[List[str], str] + belief: Union[Dict[str, Dict[str, str]], str] = None + database: Union[List[Tuple[str, int]], List[Tuple[str, int, Any]], None, str] = None + response: str = None + positive: bool = True + raw_belief: Any = None + raw_response: str = None + key: str = None + + def __getattribute__(self, name): + val = object.__getattribute__(self, name) + if name == 'belief' and val is None and self.raw_belief is not None: + val = format_belief(self.raw_belief) + self.belief = val + return val + +@dataclass +class DialogDataset(torch.utils.data.Dataset): + items: List[any] + database: Any = None + domains: List[str] = None + lexicalizer: Any = None + transform: Callable[[Any], Any] = None + normalize_input: Callable[[str], str] = None + ontology: Dict[Tuple[str, str], Set[str]] = None + + @staticmethod + def build_dataset_without_database(items, *args, **kwargs): + return DialogDataset(items, FakeDatabase(), *args, **kwargs) + + def __getitem__(self, index): + item = self.items[index] + if self.transform is not None: + item = self.transform(item) + return item + + def __len__(self): + return len(self.items) + + def map(self, transformation): + def trans(x): + x = self.transform(x) + x = transformation(x) + return x + return dataclasses.replace(self, transform=trans) + + def finish(self, progressbar: Union[str, bool] = False): + if self.transform is None: + return self + + ontology = defaultdict(lambda: set()) + domains = set(self.domains) if self.domains else set() + + items = [] + for i in trange(len(self), + desc=progressbar if isinstance(progressbar, str) else 'loading dataset', + disable=not progressbar): + item = self[i] + for k, bs in item.raw_belief.items(): + domains.add(k) + for k2, val in bs.items(): + ontology[(k, k2)].add(val) + items.append(item) + if self.ontology: + ontology = merge_ontologies((self.ontology, ontology)) + return dataclasses.replace(self, items=items, transform=None, domains=domains, ontology=ontology) + +class DialogueItems: + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: 
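+            # running cumulative sum over dialogue lengths, e.g.
+            # cumsum([2, 3, 1]) -> [2, 5, 6]; __getitem__ bisects this list
+            # to map a flat turn index back to its dialogue.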
+ r.append(e + s) + s += e + return r + + def __init__(self, dialogues): + lengths = [len(x['items']) for x in dialogues] + self.keys = [x['name'] for x in dialogues] + self.key2idx = {k:i for (i, k) in enumerate(self.keys)} + self.cumulative_sizes = DialogueItems.cumsum(lengths) + self.dialogues = dialogues + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dialogue_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dialogue_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dialogue_idx - 1] + return self.dialogues[dialogue_idx], self.dialogues[dialogue_idx]['items'][:sample_idx + 1] + + def __len__(self): + if not self.cumulative_sizes: + return 0 + return self.cumulative_sizes[-1] + + +def load_dataset(name, use_goal=False, context_window_size=15, domains=None, **kwargs) -> DialogDataset: + name, split = split_name(name) + path = os.path.join(DATASETS_PATH, name) + with open(os.path.join(path, f'{split}.json'), 'r') as f: + data = json.load(f, object_pairs_hook=OrderedDict) + dialogues = data['dialogues'] + items = DialogueItems(dialogues) + items = BlacklistItemsWrapper(items, list(build_blacklist(items, domains))) + + def transform(x): + dialogue, items = x + context = [s['text'] for s in items[:-1]] + if context_window_size is not None and context_window_size > 0: + context = context[-context_window_size:] + belief = items[-1]['belief'] + database = items[-1]['database'] + item = DialogDatasetItem(context, + raw_belief=belief, + database=database, + response=items[-1]['delexicalised_text'], + raw_response=items[-1]['text'], + key=dialogue['name']) + if use_goal: + setattr(item, 'goal', dialogue['goal']) + # MultiWOZ evaluation uses booked domains property + if 'booked_domains' in items[-1]: + setattr(item, 'booked_domains', items[-1]['booked_domains']) + setattr(item, 'dialogue_act', items[-1]['dialogue_act']) + setattr(item, 'active_domain', items[-1]['active_domain']) + return item + + dataset = DialogDataset(items, transform=transform, domains=data['domains']) + if os.path.exists(os.path.join(path, 'database.zip')): + dataset.database = AutoDatabase.load(path) + + if os.path.exists(os.path.join(path, 'lexicalizer.zip')): + dataset.lexicalizer = AutoLexicalizer.load(path) + + return dataset + +def format_belief(belief: OrderedDict) -> str: + assert isinstance(belief, OrderedDict) + str_bs = [] + for domain, domain_bs in belief.items(): + domain_bs = ', '.join([f'{slot} = {val}' for slot, val in sorted(domain_bs.items(), key=lambda x: x[0])]) + str_bs.extend([domain, '{' + domain_bs + '}']) + return ' '.join(str_bs) + +def augpt_normalize(text, delexicalize=True): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. 
]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + if delexicalize: + text = re.sub(pricepat, ' [price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + if delexicalize: + text = re.sub('[\":\<>@\(\)]', '', text) + else: + text = re.sub('[\"\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + for fromx, tox in replacements: + text = ' ' + text + ' ' + text = text.replace(fromx, tox)[1:-1] + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tmp = text + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + +def replace_augpt_tokens(text, active_domains): + if type(active_domains) == list: + if len(active_domains) > 1: + pdb.set_trace() + elif len(active_domains) == 0: + active_domain = None + else: + active_domain = active_domains[0] + else: + active_domain = active_domains + + tokens = text.split() + + ret = [] + for t in tokens: + if t[0] == "[" and len(t) > 1: + if t in ["[name]", "[address]", "[phone]", "[postcode]", "[reference]", "[id]"]: + if active_domain is None: + continue + else: + ret.append(f'[{active_domain}_{t[1:]}') + elif t == "[car]": + ret.append("[taxi_type]") + elif t in ["[departure]", "[destination]"]: + ret.append("[value_place]") + elif t in ["[leave", "[arrive"]: + ret.append("[value_time]") + elif t in ["at]", "by]", "range]"]: + continue + elif t in ["[food]", "[area]", "[time]", "[day]", "[price]"]: + ret.append(f"[value_{t[1:]}") + elif t == "[price": + ret.append("[value_pricerange]") + elif t in ["[duration]", "[stars]", "[stay]", "[people]"]: + ret.append("[value_count]") + elif t == "[type]": + ret.append("hotel") + else: + pdb.set_trace() + else: + ret.append(t) + + ret = " ".join(ret) + ret = re.sub("([0-9])+", "[value_count]", ret) + + return ret + + +def replace_hdsa_tokens(text): + tokens = text.split() + ret = [] + for t in tokens: + if t[0] == "[" and len(t) > 1: + if t == "[train_trainid]": + ret.append("[train_id]") + elif t in ["[train_arriveby]", "[train_leaveat]"]: + ret.append("[value_time]") + elif t in ["[train_departure]", "[train_destination]"]: + ret.append("[value_place]") + elif t in ["[hotel_pricerange]", "[restaurant_pricerange]"]: + ret.append("[value_pricerange]") + elif t in ["[hotel_area]", "[attraction_area]", "[restaurant_area]"]: + ret.append("[value_area]") + elif t in ["[train_price]"]: + ret.append("[value_price]") + elif t == "[train_day]": + ret.append("[value_day]") + elif t == "[restaurant_food]": + ret.append("[value_food]") + else: + ret.append(t) + # skipped.append(t) + else: + ret.append(t) + + # ret = " ".join(ret) + # ret = 
re.sub("([0-9])+", "[value_count]", ret) + + # return ret, skipped + return ret + diff --git a/latent_dialog/base_data_loaders.py b/latent_dialog/base_data_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..e96372acf9fcdbbf17bf393dc58158751212dacc --- /dev/null +++ b/latent_dialog/base_data_loaders.py @@ -0,0 +1,183 @@ +import numpy as np +import logging + + +class BaseDataLoaders(object): + def __init__(self, name): + self.data_size = None + self.indexes = None + self.name = name + + def _shuffle_indexes(self): + np.random.shuffle(self.indexes) + + def _shuffle_batch_indexes(self): + np.random.shuffle(self.batch_indexes) + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + + if verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + + if shuffle and not fix_batch: + self._shuffle_indexes() + + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i*self.batch_size: (i+1)*self.batch_size]) + + if shuffle and fix_batch: + self._shuffle_batch_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def next_batch(self): + if self.ptr < self.num_batch: + selected_ids = self.batch_indexes[self.ptr] + self.ptr += 1 + return self._prepare_batch(selected_index=selected_ids) + else: + return None + + def _prepare_batch(self, *args, **kwargs): + raise NotImplementedError('Have to override _prepare_batch()') + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + +class LongDataLoader(object): + """A special efficient data loader for TBPTT. Assume the data contains + N long sequences, each sequence has length k_i + + :ivar batch_size: the size of a minibatch + :ivar backward_size: how many steps in time to do BP + :ivar step_size: how fast we move the window + :ivar ptr: the current idx of batch + :ivar num_batch: the total number of batch + :ivar batch_indexes: a list of list. Each item is the IDs in this batch + :ivar grid_indexes: a list of (b_id, s_id, e_id). b_id is the index of + batch, s_id is the starting time id in that batch and e_id is the ending + time id. + :ivar indexes: a list, the ordered of sequences ID it should go through + :ivar data_size: the number of sequences, N. 
+ :ivar data_lens: a list containing k_i + :ivar prev_alive_size: + :ivar name: the name of this data loader + """ + logger = logging.getLogger() + + def __init__(self, name): + self.batch_size = 0 + self.backward_size = 0 + self.step_size = 0 + self.ptr = 0 + self.num_batch = None + self.batch_indexes = None # one batch is a dialog + self.grid_indexes = None # grid is the tokenized version + self.indexes = None + self.data_lens = None + self.data_size = None + self.name = name + + def _shuffle_batch_indexes(self): + np.random.shuffle(self.batch_indexes) + + def _shuffle_grid_indexes(self): + np.random.shuffle(self.grid_indexes) + + def _prepare_batch(self, cur_grid, prev_grid): + raise NotImplementedError("Have to override prepare batch") + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + + assert len(self.indexes) == self.data_size and \ + len(self.data_lens) == self.data_size + # make sure backward_size can be divided by step size + assert config.backward_size % config.step_size == 0 + + self.ptr = 0 + self.batch_size = config.batch_size + self.backward_size = config.backward_size + self.step_size = config.step_size + + # create batch indexes + temp_num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(temp_num_batch): + self.batch_indexes.append( + self.indexes[i * self.batch_size:(i + 1) * self.batch_size]) + + left_over = self.data_size - temp_num_batch * config.batch_size + if shuffle: + self._shuffle_batch_indexes() + + # create grid indexes + self.grid_indexes = [] + for idx, b_ids in enumerate(self.batch_indexes): + # assume the b_ids are sorted + all_lens = [self.data_lens[i] for i in b_ids] + max_len = self.data_lens[b_ids[0]] + min_len = self.data_lens[b_ids[-1]] + assert np.max(all_lens) == max_len + assert np.min(all_lens) == min_len + num_seg = (max_len - self.backward_size - self.step_size) // self.step_size + cut_start, cut_end = [], [] + if num_seg > 1: + cut_start = list(range(config.step_size, num_seg * config.step_size, config.step_size)) + cut_end = list(range(config.backward_size + config.step_size, + num_seg * config.step_size + config.backward_size, + config.step_size)) + assert cut_end[-1] < max_len + + actual_size = min(max_len, config.backward_size) + temp_end = list(range(2, actual_size, config.step_size)) + temp_start = [0] * len(temp_end) + + cut_start = temp_start + cut_start + cut_end = temp_end + cut_end + + assert len(cut_end) == len(cut_start) + new_grids = [(idx, s_id, e_id) for s_id, e_id in + zip(cut_start, cut_end) if s_id < min_len - 1] + + self.grid_indexes.extend(new_grids) + + # shuffle batch indexes + if shuffle: + self._shuffle_grid_indexes() + + self.num_batch = len(self.grid_indexes) + if verbose: + self.logger.info("%s init with %d batches with %d left over samples" % + (self.name, self.num_batch, left_over)) + + def next_batch(self): + if self.ptr < self.num_batch: + current_grid = self.grid_indexes[self.ptr] + if self.ptr > 0: + prev_grid = self.grid_indexes[self.ptr - 1] + else: + prev_grid = None + self.ptr += 1 + return self._prepare_batch(cur_grid=current_grid, + prev_grid=prev_grid) + else: + return None + + def pad_to(self, max_len, tokens, do_pad=True): + if len(tokens) >= max_len: + return tokens[0:max_len - 1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens diff --git a/latent_dialog/base_models.py b/latent_dialog/base_models.py new file mode 100644 index
0000000000000000000000000000000000000000..c793e8c732bbdb4220b1adfae375e2beb69808bc --- /dev/null +++ b/latent_dialog/base_models.py @@ -0,0 +1,116 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +import numpy as np +from latent_dialog.utils import INT, FLOAT, LONG, cast_type +import pdb + + +class BaseModel(nn.Module): + def __init__(self, config): + super(BaseModel, self).__init__() + self.use_gpu = config.use_gpu + self.config = config + self.kl_w = 0.0 + + def np2var(self, inputs, dtype): + if inputs is None: + return None + return cast_type(Variable(th.from_numpy(inputs)), + dtype, + self.use_gpu) + + def forward(self, *inputs): + raise NotImplementedError + + def backward(self, loss, batch_cnt): + total_loss = self.valid_loss(loss, batch_cnt) + total_loss.backward() + + def valid_loss(self, loss, batch_cnt=None): + total_loss = 0.0 + for k, l in loss.items(): + if l is not None: + total_loss += l + return total_loss + + def get_optimizer(self, config, verbose=True): + if config.op == 'adam': + if verbose: + print('Use Adam') + return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=config.init_lr, weight_decay=config.l2_norm) + elif config.op == 'sgd': + print('Use SGD') + return optim.SGD(self.parameters(), lr=config.init_lr, momentum=config.momentum) + elif config.op == 'rmsprop': + print('Use RMSProp') + return optim.RMSprop(self.parameters(), lr=config.init_lr, momentum=config.momentum) + + def get_clf_optimizer(self, config): + params = [] + params.extend(self.gru_attn_encoder.parameters()) + params.extend(self.feat_projecter.parameters()) + params.extend(self.sel_classifier.parameters()) + + if config.fine_tune_op == 'adam': + print('Use Adam') + return optim.Adam(params, lr=config.fine_tune_lr) + elif config.fine_tune_op == 'sgd': + print('Use SGD') + return optim.SGD(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum) + elif config.fine_tune_op == 'rmsprop': + print('Use RMSProp') + return optim.RMSprop(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum) + + + def model_sel_loss(self, loss, batch_cnt): + return self.valid_loss(loss, batch_cnt) + + + def extract_short_ctx(self, context, context_lens, backward_size=1): + utts = [] + for b_id in range(context.shape[0]): + utts.append(context[b_id, context_lens[b_id]-1]) + # utt = [] + # for i in range(context_lens[b_id]): + # utt.extend(context[b_id, i]) + # utts.append(utt) + return np.array(utts) + + def flatten_context(self, context, context_lens, align_right=False): + utts = [] + temp_lens = [] + for b_id in range(context.shape[0]): + temp = [] + for t_id in range(context_lens[b_id]): + for token in context[b_id, t_id]: + if token != 0: + temp.append(token) + temp_lens.append(len(temp)) + utts.append(temp) + max_temp_len = np.max(temp_lens) + results = np.zeros((context.shape[0], max_temp_len)) + for b_id in range(context.shape[0]): + if align_right: + results[b_id, -temp_lens[b_id]:] = utts[b_id] + else: + results[b_id, 0:temp_lens[b_id]] = utts[b_id] + + return results + + +def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5): + L = np.ones(n_iter) * stop + period = n_iter/n_cycle + step = (stop-start)/(period*ratio) # linear schedule + + for c in range(n_cycle): + v, i = start, 0 + while v <= stop and (int(i+c*period) < n_iter): + L[int(i+c*period)] = v + v += step + i += 1 + return L + diff --git a/latent_dialog/corpora.py b/latent_dialog/corpora.py new 
file mode 100644 index 0000000000000000000000000000000000000000..dd1f8ca4e2bd46e2de1c574b65a4c2d53aa0609f --- /dev/null +++ b/latent_dialog/corpora.py @@ -0,0 +1,573 @@ +from __future__ import unicode_literals +import numpy as np +from collections import Counter +from latent_dialog.utils import Pack, get_tokenize, get_chat_tokenize, missingdict +from latent_dialog.augpt_utils import AutoDatabase, AutoLexicalizer, augpt_normalize +import json +from nltk.tokenize import WordPunctTokenizer +import logging +from collections import defaultdict +import pdb +import os +import re + +PAD = '<pad>' +UNK = '<unk>' +USR = 'YOU:' +SYS = 'THEM:' +BOD = '<d>' +EOD = '</d>' +BOS = '<s>' +EOS = '<eos>' +SEL = '<selection>' +SEP = "|" +REQ = "<requestable>" +INF = "<informable>" +WILD = "%s" +SPECIAL_TOKENS = [PAD, UNK, USR, SYS, BOS, BOD, EOS, EOD] +STOP_TOKENS = [EOS, SEL] +DECODING_MASKED_TOKENS = [PAD, UNK, USR, SYS, BOD] + +REQ_TOKENS = {} +DOMAIN_REQ_TOKEN = ['restaurant', 'hospital', 'hotel', 'attraction', 'train', 'police', 'taxi'] +ACTIVE_BS_IDX = [13, 30, 35, 61, 72, 91, 93] # indexes in the BS indicating whether a domain is active +NO_MATCH_DB_IDX = [-1, 0, -1, 6, 12, 18, -1] # indexes in the DB pointer indicating 0 matches were found for that domain; -1 means the domain has no DB +REQ_TOKENS['attraction'] = ["[attraction_address]", "[attraction_name]", "[attraction_phone]", "[attraction_postcode]", "[attraction_reference]", "[attraction_type]"] +REQ_TOKENS['hospital'] = ["[hospital_address]", "[hospital_department]", "[hospital_name]", "[hospital_phone]", "[hospital_postcode]"] #, "[hospital_reference]" +REQ_TOKENS['hotel'] = ["[hotel_address]", "[hotel_name]", "[hotel_phone]", "[hotel_postcode]", "[hotel_reference]", "[hotel_type]"] +REQ_TOKENS['restaurant'] = ["[restaurant_name]", "[restaurant_address]", "[restaurant_phone]", "[restaurant_postcode]", "[restaurant_reference]"] +REQ_TOKENS['train'] = ["[train_id]", "[train_reference]"] +REQ_TOKENS['police'] = ["[police_address]", "[police_phone]", "[police_postcode]"] #"[police_name]", +REQ_TOKENS['taxi'] = ["[taxi_phone]", "[taxi_type]"] + +GENERIC_TOKENS = ["[value_area]", "[value_count]", "[value_day]", "[value_food]", "[value_place]", "[value_price]", "[value_pricerange]", "[value_time]"] + + +class NormMultiWozCorpus(object): + logger = logging.getLogger() + + def __init__(self, config): + self.bs_size = 94 + self.db_size = 30 + self.goal_size = 77 + self.bs_types = ['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b'] + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt'] + self.config = config + self.tokenize = lambda x: x.split() + self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config) + self._extract_vocab() + self._extract_goal_vocab() + self.logger.info('Loading corpus finished.') + + def _read_file(self, config): + train_data = json.load(open(config.train_path)) + valid_data = json.load(open(config.valid_path)) + test_data = json.load(open(config.test_path)) + + train_data =
self._process_dialogue(train_data) + valid_data = self._process_dialogue(valid_data) + test_data = self._process_dialogue(test_data) + + return train_data, valid_data, test_data + + def _process_dialogue(self, data): + new_dlgs = [] + all_sent_lens = [] + all_dlg_lens = [] + + for key, raw_dlg in data.items(): + norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)] + for t_id in range(len(raw_dlg['db'])): + usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS] + sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS] + norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + all_sent_lens.extend([len(usr_utt), len(sys_utt)]) + # To stop dialog + norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + # if self.config.to_learn == 'usr': + # norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + all_dlg_lens.append(len(raw_dlg['db'])) + processed_goal = self._process_goal_simple(raw_dlg['goal']) + new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key)) + + self.logger.info('Max utt len = %d, mean utt len = %.2f' % ( + np.max(all_sent_lens), float(np.mean(all_sent_lens)))) + self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % ( + np.max(all_dlg_lens), float(np.mean(all_dlg_lens)))) + return new_dlgs + + def _extract_vocab(self): + all_words = [] + for dlg in self.train_corpus: + for turn in dlg.dlg: + all_words.extend(turn.utt) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size) + oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words)) + + self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0], + vocab_count[keep_vocab_size - 1][1]) + + 'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100)) + + vocab_count = vocab_count[0:keep_vocab_size] + self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS] + self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)} + self.unk_id = self.vocab_dict[UNK] + self.logger.info("Raw vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab))) + + def _process_goal(self, raw_goal): + res = {} + for domain in self.domains: + all_words = [] + d_goal = raw_goal[domain] + if d_goal: + for info_type in self.info_types: + sv_info = d_goal.get(info_type, dict()) + if info_type == 'reqt' and isinstance(sv_info, list): + all_words.extend([info_type + '|' + item for item in sv_info]) + elif isinstance(sv_info, dict): + all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()]) + else: + print('Fatal Error!') + exit(-1) + res[domain] = all_words + return res + + def _process_goal_simple(self, raw_goal): + res = {} + for domain in self.domains: + all_words = [] + if domain in raw_goal: + d_goal = raw_goal[domain] + for info_type in ['book', 'info', 'reqt']: + sv_info = d_goal.get(info_type, dict()) + if info_type == 'reqt' and isinstance(sv_info, list): + all_words.extend([info_type + '|' + item for item in sv_info]) + elif isinstance(sv_info, dict): + all_words.extend([info_type + '|' + k for k, v in sv_info.items()]) + else: + print('Fatal Error!') + exit(-1) + res[domain] = all_words + return res + 
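+ # Goal annotations are flattened into "info_type|slot" tokens per domain (see _process_goal_simple above); + # _extract_goal_vocab below builds a per-domain bag-of-words vocabulary with an <unk> entry from these tokens.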
+ + def _extract_goal_vocab(self): + self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {} + for domain in self.domains: + all_words = [] + for dlg in self.train_corpus: + all_words.extend(dlg.goal[domain]) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + discard_wc = np.sum([c for t, c in vocab_count]) + + self.logger.info('================= domain = {}, \n'.format(domain) + + 'goal vocab size of train set = %d, \n' % (raw_vocab_size,) + + 'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) + + 'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),)) + + self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count] + self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])} + self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK] + + def get_corpus(self): + id_train = self._to_id_corpus('Train', self.train_corpus) + id_val = self._to_id_corpus('Valid', self.val_corpus) + id_test = self._to_id_corpus('Test', self.test_corpus) + return id_train, id_val, id_test + + def _to_id_corpus(self, name, data): + results = [] + for dlg in data: + if len(dlg.dlg) < 1: + continue + id_dlg = [] + for turn in dlg.dlg: + id_turn = Pack(utt=self._sent2id(turn.utt), + speaker=turn.speaker, + db=turn.db, bs=turn.bs) + id_dlg.append(id_turn) + id_goal = self._goal2id(dlg.goal) + results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key)) + return results + + def _sent2id(self, sent): + return [self.vocab_dict.get(t, self.unk_id) for t in sent] + + def _goal2id(self, goal): + res = {} + for domain in self.domains: + d_bow = [0.0] * len(self.goal_vocab[domain]) + for word in goal[domain]: + word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain]) + d_bow[word_id] += 1.0 + res[domain] = d_bow + return res + + def id2sent(self, id_list): + return [self.vocab[i] for i in id_list] + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + +def get_summary_bstate(bstate): + """Based on the mturk annotations we form multi-domain belief state""" + domains = [u'taxi', u'restaurant', u'hospital', + u'hotel', u'attraction', u'train', u'police'] + book_keys = { + 'taxi': ['booked'], + 'police': ['booked'], + 'hospital': ['booked'], + 'restaurant': ['booked', 'time', 'day', 'people'], + 'hotel': ['booked', 'stay', 'day', 'people'], + 'attraction': ['booked'], + 'train': ['booked', 'people'] + } + + semi_keys = { + 'taxi': ['leave at', 'destination', 'departure', 'arrive by'], + 'police': [], + 'restaurant': ['food', 'price range', 'name', 'area'], + 'hospital': ['department'], + 'hotel': ['name', 'area', 'parking', 'price range', 'stars', 'internet', 'type'], + 'attraction': ['type', 'name', 'area'], + 'train': ['leave at', 'destination', 'day', 'arriveby', 'departure'], + } + # {'taxi': {'book': {'booked': []}, 'semi': {'leaveAt': '', 'destination': '', 'departure': '', 'arriveBy': ''}}, + # 'police': {'book': {'booked': []}, 'semi': {}}, + # 'restaurant': {'book': {'booked': [], 'time': '', 'day': '', 'people': ''}, 'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}}, + # 'hospital': {'book': {'booked': []}, 'semi': {'department': ''}}, + # 'hotel': {'book': {'booked': [], 'stay': '', 'day': '', 'people': ''}, 'semi': {'name': 'not mentioned', 'area': 'not mentioned', 'parking': 'not mentioned', 'pricerange': 
'cheap', 'stars': 'not mentioned', 'internet': 'not mentioned', 'type': 'hotel'}}, + #'attraction': {'book': {'booked': []}, 'semi': {'type': '', 'name': '', 'area': ''}}, + #'train': {'book': {'booked': [], 'people': ''}, 'semi': {'leaveAt': '', 'destination': '', 'day': '', 'arriveBy': '', 'departure': ''}}} + # build the 94-dim summary: per-domain booking bits, 3-way slot encodings, and an active-domain flag + summary_bstate = [] + for domain in domains: + domain_active = False + + booking = [] + if domain in bstate: + # for slot in sorted(bstate[domain]['book'].keys()): + for slot in sorted(book_keys[domain]): + if slot == 'booked': + if bstate[domain]['book']['booked']: + booking.append(1) + else: + booking.append(0) + else: + if bstate[domain]['book'][slot] != "": + booking.append(1) + else: + booking.append(0) + if domain == 'train': + if 'people' not in bstate[domain]['book'].keys(): + booking.append(0) + if 'ticket' not in bstate[domain]['book'].keys(): + booking.append(0) + else: + if domain == "train": + booking = [0, 0, 0] + else: + booking = [0] + summary_bstate += booking + + # for slot in bstate[domain]['semi']: + for slot in semi_keys[domain]: + slot_enc = [0, 0, 0] + if slot in bstate[domain]: + slot_enc[2] = 1 + else: + slot_enc[0] = 1 + # if bstate[domain]['semi'][slot] == 'not mentioned': + # slot_enc[0] = 1 + # elif bstate[domain]['semi'][slot] == 'dont care' or bstate[domain]['semi'][slot] == 'dontcare' or bstate[domain]['semi'][slot] == "don't care": + # slot_enc[1] = 1 + # elif bstate[domain]['semi'][slot]: + # slot_enc[2] = 1 + if slot_enc != [0, 0, 0]: + domain_active = True + summary_bstate += slot_enc + + # quasi domain-tracker + if domain_active: + summary_bstate += [1] + else: + summary_bstate += [0] + + + # print(len(summary_bstate)) + assert len(summary_bstate) == 94 + return summary_bstate + +def addDBPointer(db_text, booked_text): + """Create database pointer for all related domains.""" + matches = { + "restaurant": [0, 0, 0, 0, 0, 1], + "hotel" : [0, 0, 0, 0, 0, 1], + "attraction" : [0, 0, 0, 0, 0, 1], + "train" : [0, 0, 0, 0, 0, 1] + } + + booked = { + "restaurant": [1, 0], + "hotel": [1, 0], + "train": [1, 0] + } + + + if booked_text[0] != "none": + for domain in booked_text: + booked[domain] = [0, 1] + + + if len(db_text) > 0: + for i in range(int(len(db_text)/2)): + dom = db_text[i * 2] + num = db_text[i * 2 + 1] + if dom != "train": + if num == "0": + matches[dom] = [1, 0, 0, 0, 0, 0] + elif num == "1": + matches[dom] = [0, 1, 0, 0, 0, 0] + elif num == "2": + matches[dom] = [0, 0, 1, 0, 0, 0] + elif num == "3": + matches[dom] = [0, 0, 0, 1, 0, 0] + elif num == "4": + matches[dom] = [0, 0, 0, 0, 1, 0] + else: + if num == "0": + matches[dom] = [1, 0, 0, 0, 0, 0] + elif int(num) <= 2: + matches[dom] = [0, 1, 0, 0, 0, 0] + elif int(num) <= 5: + matches[dom] = [0, 0, 1, 0, 0, 0] + elif int(num) <= 10: + matches[dom] = [0, 0, 0, 1, 0, 0] + elif int(num) <= 40: + matches[dom] = [0, 0, 0, 0, 1, 0] + + + metadata = matches["restaurant"] + matches["hotel"] + matches["attraction"] + matches["train"] + booked["restaurant"] + booked["hotel"] + booked["train"] + + return metadata + +class NormMultiWozCorpusAE(object): + logger = logging.getLogger() + + def __init__(self, config): + self.bs_size = 94 + self.db_size = 30 + self.goal_size = 77 + self.bs_types = ['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c',
'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b'] + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt'] + # self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome'] + # self.act2id = {a:i for i, a in enumerate(self.act_types)} + # self.id2act = {i:a for i, a in enumerate(self.act_types)} + # self.act_size = len(self.act_types) #domain agnostic act + # self.act_size = len(domain) * len(self.act_types) #domain dependent act + self.config = config + self.tokenize = lambda x: x.split() + self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config) + self._extract_vocab() + self._extract_goal_vocab() + self.logger.info('Loading corpus finished.') + + def _read_file(self, config): + train_data = json.load(open(config.train_path)) + valid_data = json.load(open(config.valid_path)) + test_data = json.load(open(config.test_path)) + + train_data = self._process_dialogue(train_data) + valid_data = self._process_dialogue(valid_data) + test_data = self._process_dialogue(test_data) + + return train_data, valid_data, test_data + + def _process_dialogue(self, data): + new_dlgs = [] + all_sent_lens = [] + all_dlg_lens = [] + + for key, raw_dlg in data.items(): + norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)] + for t_id in range(len(raw_dlg['db'])): + usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS] + sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS] + norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id])) + all_sent_lens.extend([len(usr_utt), len(sys_utt)]) + # To stop dialog + norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + # if self.config.to_learn == 'usr': + # norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)) + all_dlg_lens.append(len(raw_dlg['db'])) + processed_goal = self._process_goal(raw_dlg['goal']) + new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key)) + + self.logger.info('Max utt len = %d, mean utt len = %.2f' % ( + np.max(all_sent_lens), float(np.mean(all_sent_lens)))) + self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % ( + np.max(all_dlg_lens), float(np.mean(all_dlg_lens)))) + return new_dlgs + + def _extract_vocab(self): + all_words = [] + for dlg in self.train_corpus: + for turn in dlg.dlg: + all_words.extend(turn.utt) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size) + oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words)) + + self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0], + vocab_count[keep_vocab_size - 1][1]) + + 'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100)) + + vocab_count = vocab_count[0:keep_vocab_size] + self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS] + self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)} + self.unk_id = self.vocab_dict[UNK] + self.logger.info("Raw 
vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab))) + + def _process_goal(self, raw_goal): + res = {} + for domain in self.domains: + all_words = [] + d_goal = raw_goal[domain] + if d_goal: + for info_type in self.info_types: + sv_info = d_goal.get(info_type, dict()) + if info_type == 'reqt' and isinstance(sv_info, list): + all_words.extend([info_type + '|' + item for item in sv_info]) + elif isinstance(sv_info, dict): + all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()]) + else: + print('Fatal Error!') + exit(-1) + res[domain] = all_words + return res + + def _process_multidomain_summary_acts(self, dact): + """ + process the dialogue action dictionary into a binary vector representation: + each domain has its own vector, and the final output is the flattened representation of each domain's actions + """ + # NOTE: relies on self.act_types, which is commented out in __init__ above + res = {} + # dact = {domain:{action:[slot]}, domain:{action:[slot]}} + for domain in self.domains: + res[domain] = np.zeros(len(self.act_types)) + if domain in dact.keys(): + for i in range(len(self.act_types)): + if self.act_types[i] in dact[domain].keys(): + res[domain][i] = 1 + + + # multiwoz dact = {domain-act:[[slot, value], [slot, value]]} + # for domain in self.domains: + # res[domain] = np.zeros(len(self.act_types)) + # for k in dact.keys(): + # d = k.split("-")[0].lower() + # a = k.split("-")[1].lower() + + # res[d][self.act2id[a]] = 1 + + flat_res = [act for domain in sorted(self.domains) for act in res[domain]] + return flat_res + + def _process_summary_acts(self, dact): + """ + process the dialogue action dictionary into a binary vector representation, ignoring domain information + """ + res = np.zeros(len(self.act_types)) + # damd dact = {domain:{action:[slot]}, domain:{action:[slot]}} + for domain in self.domains: + if domain in dact.keys(): + for i in range(len(self.act_types)): + if self.act_types[i] in dact[domain].keys(): + res[i] = 1 + + # multiwoz dact = {domain-act:[[slot, value], [slot, value]]} + # for k in dact.keys(): + # # d = k.split("-")[0].lower() + # a = k.split("-")[1].lower() + + # res[self.act2id[a]] = 1 + + return list(res) + + def _extract_goal_vocab(self): + self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {} + for domain in self.domains: + all_words = [] + for dlg in self.train_corpus: + all_words.extend(dlg.goal[domain]) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = len(vocab_count) + discard_wc = np.sum([c for t, c in vocab_count]) + + self.logger.info('================= domain = {}, \n'.format(domain) + + 'goal vocab size of train set = %d, \n' % (raw_vocab_size,) + + 'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) + + 'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),)) + + self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count] + self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])} + self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK] + + def get_corpus(self): + id_train = self._to_id_corpus('Train', self.train_corpus) + id_val = self._to_id_corpus('Valid', self.val_corpus) + id_test = self._to_id_corpus('Test', self.test_corpus) + return id_train, id_val, id_test + + def _to_id_corpus(self, name, data): + results = [] + for dlg in data: + if len(dlg.dlg) < 1: + continue + id_dlg = [] + for turn in dlg.dlg: + id_turn = Pack(utt=self._sent2id(turn.utt), + speaker=turn.speaker, + db=turn.db, bs=turn.bs) #, act=turn.act) + id_dlg.append(id_turn) + id_goal
= self._goal2id(dlg.goal) + results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key)) + return results + + def _sent2id(self, sent): + return [self.vocab_dict.get(t, self.unk_id) for t in sent] + + def _goal2id(self, goal): + res = {} + for domain in self.domains: + d_bow = [0.0] * len(self.goal_vocab[domain]) + for word in goal[domain]: + word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain]) + d_bow[word_id] += 1.0 + res[domain] = d_bow + return res + + def id2sent(self, id_list): + return [self.vocab[i] for i in id_list] + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + diff --git a/latent_dialog/criterions.py b/latent_dialog/criterions.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d1e78b8aa8aeb80fe5108fb481a95c804822f3 --- /dev/null +++ b/latent_dialog/criterions.py @@ -0,0 +1,184 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +import numpy as np +import pdb +import math +from latent_dialog import domain +from latent_dialog.utils import LONG + +class NLLEntropy(_Loss): + def __init__(self, padding_idx, avg_type): + super(NLLEntropy, self).__init__() + self.padding_idx = padding_idx + self.avg_type = avg_type + + def set_weight(self, weight): + self.weight = weight + + def forward(self, net_output, labels, weight=None): + batch_size = net_output.size(0) + pred = net_output.view(-1, net_output.size(-1)) + target = labels.view(-1) + + if self.avg_type is None: + loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx) + elif self.avg_type == 'seq': + loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx) + loss = loss / batch_size + elif self.avg_type == 'real_word': + loss = F.nll_loss(pred, target, ignore_index=self.padding_idx, reduce=False) + loss = loss.view(-1, net_output.size(1)) + loss = th.sum(loss, dim=1) + word_cnt = th.sum(th.sign(labels), dim=1).float() + loss = loss / word_cnt + loss = th.mean(loss) + elif self.avg_type == 'word': + loss = F.nll_loss(pred, target, reduction='mean', ignore_index=self.padding_idx) + elif self.avg_type == 'weighted': + loss = F.nll_loss(pred, target, weight=self.weight, reduction='mean', ignore_index=self.padding_idx) + else: + raise ValueError('Unknown average type') + + return loss + +class NLLEntropy4CLF(_Loss): + def __init__(self, dictionary, bad_tokens=['<disconnect>', '<disagree>'], reduction='elementwise_mean'): + super(NLLEntropy4CLF, self).__init__() + w = th.Tensor(len(dictionary)).fill_(1) + for token in bad_tokens: + w[dictionary[token]] = 0.0 + self.crit = nn.CrossEntropyLoss(w, reduction=reduction) + + def forward(self, preds, labels): + # preds: (batch_size, outcome_len, outcome_vocab_size) + # labels: (batch_size, outcome_len) + preds = preds.view(-1, preds.size(-1)) + labels = labels.view(-1) + return self.crit(preds, labels) + +class CombinedNLLEntropy4CLF(_Loss): + def __init__(self, dictionary, corpus, np2var, bad_tokens=['<disconnect>', '<disagree>']): + super(CombinedNLLEntropy4CLF, self).__init__() + self.dictionary = dictionary + self.domain = domain.get_domain('object_division') + self.corpus = corpus + self.np2var = np2var + self.bad_tokens = bad_tokens + + def forward(self, preds, goals_id, outcomes_id): + # preds: (batch_size, outcome_len, outcome_vocab_size) + # goals_id: list of list, 
id, batch_size*goal_len + # outcomes_id: list of list, id, batch_size*outcome_len + batch_size = len(goals_id) + losses = [] + for bth in range(batch_size): + pred = preds[bth] # (outcome_len, outcome_vocab_size) + goal = goals_id[bth] # list, id, len=goal_len + goal_str = self.corpus.id2goal(goal) # list, str, len=goal_len + outcome = outcomes_id[bth] # list, id, len=outcome_len + outcome_str = self.corpus.id2outcome(outcome) # list, str, len=outcome_len + + if outcome_str[0] in self.bad_tokens: + continue + + # get all the possible choices + choices = self.domain.generate_choices(goal_str) + sel_outs = [pred[i] for i in range(pred.size(0))] # outcome_len*(outcome_vocab_size, ) + + choices_logits = [] # outcome_len*(option_amount, 1) + for i in range(self.domain.selection_length()): + idxs = np.array([self.dictionary[c[i]] for c in choices]) + idxs_var = self.np2var(idxs, LONG) # (option_amount, ) + choices_logits.append(th.gather(sel_outs[i], 0, idxs_var).unsqueeze(1)) + + choice_logit = th.sum(th.cat(choices_logits, 1), 1, keepdim=False) # (option_amount, ) + choice_logit = choice_logit.sub(choice_logit.max().item()) # (option_amount, ) + prob = F.softmax(choice_logit, dim=0) # (option_amount, ) + + label = choices.index(outcome_str) + target_prob = prob[label] + losses.append(-th.log(target_prob)) + return sum(losses) / float(len(losses)) + +class CatKLLoss(_Loss): + def __init__(self): + super(CatKLLoss, self).__init__() + + def forward(self, log_qy, log_py, batch_size=None, unit_average=False): + """ + qy * log(q(y)/p(y)) + """ + qy = th.exp(log_qy) + y_kl = th.sum(qy * (log_qy - log_py), dim=1) + if unit_average: + return th.mean(y_kl) + else: + return th.sum(y_kl)/batch_size + +class Entropy(_Loss): + def __init__(self): + super(Entropy, self).__init__() + + def forward(self, log_qy, batch_size=None, unit_average=False): + """ + -qy log(qy) + """ + if log_qy.dim() > 2: + log_qy = log_qy.squeeze() + qy = th.exp(log_qy) + h_q = th.sum(-1 * log_qy * qy, dim=1) + if unit_average: + return th.mean(h_q) + else: + return th.sum(h_q) / batch_size + +class GaussianEntropy(_Loss): + def __init__(self): + super(GaussianEntropy, self).__init__() + + def forward(self, mu, logvar): + """ + 0.5 * log(2 * pi * var) + 0.5 + """ + # a Gaussian's entropy depends only on its variance, so mu is not used + std = th.exp(0.5 * logvar) + var = th.square(std) + h_q = 0.5 * (th.log(2 * math.pi * var)) + 0.5 + + return th.mean(h_q) + +class BinaryNLLEntropy(_Loss): + + def __init__(self, size_average=True): + super(BinaryNLLEntropy, self).__init__() + self.size_average = size_average + + def forward(self, net_output, label_output): + """ + :param net_output: batch_size x + :param labels: + :return: + """ + batch_size = net_output.size(0) + loss = F.binary_cross_entropy_with_logits(net_output, label_output, size_average=self.size_average) + if self.size_average is False: + loss /= batch_size + return loss + +class NormKLLoss(_Loss): + def __init__(self, unit_average=False): + super(NormKLLoss, self).__init__() + self.unit_average = unit_average + + def forward(self, recog_mu, recog_logvar, prior_mu, prior_logvar): + # find the KL divergence between two Gaussian distributions + loss = 1.0 + (recog_logvar - prior_logvar) + loss -= th.div(th.pow(prior_mu - recog_mu, 2), th.exp(prior_logvar)) + loss -= th.div(th.exp(recog_logvar), th.exp(prior_logvar)) + if self.unit_average: + kl_loss = -0.5 * th.mean(loss, dim=1) + else: + kl_loss = -0.5 * th.sum(loss, dim=1) + avg_kl_loss = th.mean(kl_loss) + return avg_kl_loss diff --git a/latent_dialog/data_loaders.py b/latent_dialog/data_loaders.py new
file mode 100644 index 0000000000000000000000000000000000000000..212d4fe871bdfc8a6f587219284d7d3e7aee0a87 --- /dev/null +++ b/latent_dialog/data_loaders.py @@ -0,0 +1,434 @@ +import numpy as np +import copy +import random +import traceback +from enum import Enum +from collections import defaultdict +import pdb +from latent_dialog.utils import Pack +from latent_dialog.base_data_loaders import BaseDataLoaders, LongDataLoader +from latent_dialog.corpora import USR, SYS, UNK, SPECIAL_TOKENS +import json + + +class NoiseType(Enum): + REMOVAL = "tokens_removal" + CHANGE = "tokens_switch" + +class BeliefDbDataLoaders(BaseDataLoaders): + def __init__(self, name, data, config, pad_context=True, noise=None, ind_voc=None, logger=None): + super(BeliefDbDataLoaders, self).__init__(name) + self.max_utt_len = config.max_utt_len + self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size) + self.data_size = len(self.data) + self.pad_context = pad_context + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + + self.noise_type = noise + + if self.noise_type is not None: + self.noise_p = config.noise_p + self.remove_tokens = config.remove_tokens + + self.ind_voc = copy.deepcopy(ind_voc) + if config.no_special: + for token in SPECIAL_TOKENS: + if token != UNK: + del self.ind_voc[token] + self.tokens = list(self.ind_voc.keys()) + + self.num_noised_tokens = 0 + self.noised_tokens_dist = defaultdict(int) + self.num_examples = 0 + + self.logger = logger + + + def flatten_dialog( self, data, backward_size): + results = [] + indexes = [] + batch_indexes = [] + resp_set = set() + for dlg in data: + goal = dlg.goal + key = dlg.key + batch_index = [] + for i in range(1, len(dlg.dlg)): + if dlg.dlg[i].speaker == USR or dlg.dlg[i].speaker == "user": + continue + e_idx = i + s_idx = max(0, e_idx - backward_size) + response = dlg.dlg[i].copy() + response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False) + resp_set.add(json.dumps(response.utt)) + context = [] + for turn in dlg.dlg[s_idx: e_idx]: + turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False) + context.append(turn) + results.append(Pack(context=context, response=response, goal=goal, key=key)) + indexes.append(len(indexes)) + batch_index.append(indexes[-1]) + if len(batch_index) > 0: + batch_indexes.append(batch_index) + print("Unique resp {}".format(len(resp_set))) + return results, indexes, batch_indexes + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + if fix_batch: + self.batch_size = None + self.num_batch = len(self.batch_indexes) + else: + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size]) + if verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + if shuffle: + if fix_batch: + self._shuffle_batch_indexes() + else: + self._shuffle_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def _prepare_batch(self, selected_index): + rows = [self.data[idx] for idx in selected_index] + + ctx_utts, ctx_lens = [], [] + out_utts, out_lens = [], [] + + out_bs, out_db = [] , [] + raw_bs = [] + goals, goal_lens = [], [[] for _ in range(len(self.domains))] + keys = [] + + for row in rows: + in_row, out_row, goal_row = row.context, 
row.response, row.goal + + # source context + keys.append(row.key) + batch_ctx = [] + for turn in in_row: + ctx_utt = copy.deepcopy(turn.utt) + + if self.noise_type is not None: + try: + ctx_utt = self._add_noise(ctx_utt) + except Exception as error: + self.logger.warning(traceback.format_exc()) + + batch_ctx.append(self.pad_to(self.max_utt_len, ctx_utt, do_pad=self.pad_context)) + + ctx_utts.append(batch_ctx) + ctx_lens.append(len(batch_ctx)) + + # target response + out_utt = [t for idx, t in enumerate(out_row.utt)] + out_utts.append(out_utt) + out_lens.append(len(out_utt)) + + out_bs.append(out_row.bs) + out_db.append(out_row.db) + + if "raw_bs" in out_row: + raw_bs.append(out_row["raw_bs"]) + + # goal + goals.append(goal_row) + for i, d in enumerate(self.domains): + goal_lens[i].append(len(goal_row[d])) + + batch_size = len(ctx_lens) + vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns + max_ctx_len = np.max(vec_ctx_lens) + if self.pad_context: + vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32) + else: + vec_ctx_utts = [] + vec_out_bs = np.array(out_bs) # (batch_size, 94) + vec_out_db = np.array(out_db) # (batch_size, 30) + vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens + max_out_len = np.max(vec_out_lens) + vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32) + + max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens] + if max_goal_lens != min_goal_lens: + print('Fatal Error!') + exit(-1) + self.goal_lens = max_goal_lens + vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens] + + for b_id in range(batch_size): + if self.pad_context: + vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id] + else: + vec_ctx_utts.append(ctx_utts[b_id]) + + vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id] + for i, d in enumerate(self.domains): + vec_goals_list[i][b_id, :] = goals[b_id][d] + + return Pack(context_lens=vec_ctx_lens, # (batch_size, ) + contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len) + output_lens=vec_out_lens, # (batch_size, ) + outputs=vec_out_utts, # (batch_size, max_out_len) + bs=vec_out_bs, # (batch_size, 94) + raw_bs=raw_bs, + db=vec_out_db, # (batch_size, 30) + goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t.
domain + keys=keys) + + def _add_noise(self, sequence): + num_tokens = np.ceil(self.noise_p * (len(sequence)-2)) + indices = np.random.choice(range(1, len(sequence)-1), size=int(num_tokens), replace=False) + self.num_noised_tokens += num_tokens + self.noised_tokens_dist[num_tokens] += 1 + self.num_examples += 1 + + if self.noise_type == NoiseType.REMOVAL.value: + return self._remove_tokens(sequence, indices) + elif self.noise_type == NoiseType.CHANGE.value: + return self._change_tokens(sequence, indices) + + def _remove_tokens(self, sequence, indices): + for index in indices: + sequence[index] = self.ind_voc[UNK] + + if self.remove_tokens: + sequence = list(filter(lambda x: x != self.ind_voc[UNK], sequence)) + + return sequence + + def _change_tokens(self, sequence, indices): + for index in indices: + new_token_index = sequence[index] + + while new_token_index == sequence[index]: + ind = int(len(self.tokens) * random.random()) + new_token_index = self.ind_voc[self.tokens[ind]] + + sequence[index] = new_token_index + + return sequence + +class BeliefDbDataLoadersAE(BaseDataLoaders): + def __init__(self, name, data, config, noise=None, ind_voc=None, logger=None): + super(BeliefDbDataLoadersAE, self).__init__(name) + self.max_utt_len = config.max_utt_len + self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size) + self.data_size = len(self.data) + self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi'] + # self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome'] + if "ae_zero_pad" in config.keys(): + self.zero_pad = config.ae_zero_pad + else: + self.zero_pad = False + + self.noise_type = noise + + if self.noise_type is not None: + self.noise_p = config.noise_p + self.remove_tokens = config.remove_tokens + + self.ind_voc = copy.deepcopy(ind_voc) + if config.no_special: + for token in SPECIAL_TOKENS: + if token != UNK: + del self.ind_voc[token] + self.tokens = list(self.ind_voc.keys()) + + self.num_noised_tokens = 0 + self.noised_tokens_dist = defaultdict(int) + self.num_examples = 0 + + self.logger = logger + + + + def flatten_dialog(self, data, backward_size): + results = [] + indexes = [] + batch_indexes = [] + resp_set = set() + for dlg in data: + goal = dlg.goal + key = dlg.key + batch_index = [] + for i in range(1, len(dlg.dlg)): + if dlg.dlg[i].speaker == USR: + continue + e_idx = i + s_idx = max(0, e_idx - backward_size) + response = dlg.dlg[i].copy() + response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False) + resp_set.add(json.dumps(response.utt)) + context = [] + for turn in dlg.dlg[s_idx: e_idx]: + turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False) + context.append(turn) + results.append(Pack(context=context, response=response, goal=goal, key=key)) + indexes.append(len(indexes)) + batch_index.append(indexes[-1]) + if len(batch_index) > 0: + batch_indexes.append(batch_index) + print("Unique resp {}".format(len(resp_set))) + return results, indexes, batch_indexes + + def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False): + self.ptr = 0 + if fix_batch: + self.batch_size = None + self.num_batch = len(self.batch_indexes) + else: + self.batch_size = config.batch_size + self.num_batch = self.data_size // config.batch_size + self.batch_indexes = [] + for i in range(self.num_batch): + self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size]) + if 
verbose: + print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch)) + if shuffle: + if fix_batch: + self._shuffle_batch_indexes() + else: + self._shuffle_indexes() + + if verbose: + print('%s begins with %d batches' % (self.name, self.num_batch)) + + def _prepare_batch(self, selected_index): + rows = [self.data[idx] for idx in selected_index] + + ctx_utts, ctx_lens = [], [] + out_utts, out_lens = [], [] + # out_act = [] + out_bs, out_db = [], [] + goals, goal_lens = [], [[] for _ in range(len(self.domains))] + keys = [] + + raw_bs = [] + + for row in rows: + in_row, out_row, goal_row = row.context, row.response, row.goal + + # source context + keys.append(row.key) + + # batch_ctx = [] + # for turn in in_row: + # batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True)) + # ctx_utts.append(batch_ctx) + # ctx_lens.append(len(batch_ctx)) + + # target response + out_utt = [t for idx, t in enumerate(out_row.utt)] + out_utts.append(out_utt) + out_lens.append(len(out_utt)) + + if not self.zero_pad: + out_bs.append(out_row.bs) + out_db.append(out_row.db) + else: + out_bs.append([0] * 94) + out_db.append([0] * 30) + # out_act.append(out_row.act) + + # for AE, input = output + # print(out_row.utt) + ctx_utt = copy.deepcopy(out_row.utt) + if self.noise_type is not None: + try: + ctx_utt = self._add_noise(ctx_utt) + except Exception as error: + self.logger.warning(traceback.format_exc()) + + batch_ctx = self.pad_to(self.max_utt_len, ctx_utt, do_pad=True) + + # print(out_row.utt) + # pdb.set_trace() + + ctx_utts.append(batch_ctx) + ctx_lens.append(len(batch_ctx)) + + + if "raw_bs" in out_row: + raw_bs.append(out_row["raw_bs"]) + + # goal + goals.append(goal_row) + for i, d in enumerate(self.domains): + goal_lens[i].append(len(goal_row[d])) + + + batch_size = len(ctx_lens) + vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns + max_ctx_len = np.max(vec_ctx_lens) + vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32) + vec_out_bs = np.array(out_bs) # (batch_size, 94) + vec_out_db = np.array(out_db) # (batch_size, 30) + # vec_out_act = np.array(out_act) # (batch_size, 11) + vec_out_lens = np.array(out_lens) # (batch_size, ), number of tokens + max_out_len = np.max(vec_out_lens) + vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32) + + max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens] + if max_goal_lens != min_goal_lens: + print('Fatal Error!') + exit(-1) + self.goal_lens = max_goal_lens + vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens] + + for b_id in range(batch_size): + vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id] + vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id] + for i, d in enumerate(self.domains): + vec_goals_list[i][b_id, :] = goals[b_id][d] + + return Pack(context_lens=vec_ctx_lens, # (batch_size, ) + contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len) + output_lens=vec_out_lens, # (batch_size, ) + outputs=vec_out_utts, # (batch_size, max_out_len) + raw_bs=raw_bs, + bs=vec_out_bs, # (batch_size, 94) + db=vec_out_db, # (batch_size, 30) + # act=vec_out_act, #(batch_size, 11) + goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t.
domain + keys=keys) + + def _add_noise(self, sequence): + num_tokens = np.ceil(self.noise_p * (len(sequence)-2)) + indices = np.random.choice(range(1, len(sequence)-1), size=int(num_tokens), replace=False) + self.num_noised_tokens += num_tokens + self.noised_tokens_dist[num_tokens] += 1 + self.num_examples += 1 + + if self.noise_type == NoiseType.REMOVAL.value: + return self._remove_tokens(sequence, indices) + elif self.noise_type == NoiseType.CHANGE.value: + return self._change_tokens(sequence, indices) + + def _remove_tokens(self, sequence, indices): + for index in indices: + sequence[index] = self.ind_voc[UNK] + + if self.remove_tokens: + sequence = list(filter(lambda x: x != self.ind_voc[UNK], sequence)) + + return sequence + + def _change_tokens(self, sequence, indices): + for index in indices: + new_token_index = sequence[index] + + while new_token_index == sequence[index]: + ind = int(len(self.tokens) * random.random()) + new_token_index = self.ind_voc[self.tokens[ind]] + + sequence[index] = new_token_index + + return sequence + diff --git a/latent_dialog/dialog_task.py b/latent_dialog/dialog_task.py new file mode 100644 index 0000000000000000000000000000000000000000..f985d83bb985bb2a48626616ba34d34f714e0b72 --- /dev/null +++ b/latent_dialog/dialog_task.py @@ -0,0 +1,129 @@ +from latent_dialog.metric import MetricsContainer +from latent_dialog.corpora import EOD, EOS +from latent_dialog import evaluators + + +class Dialog(object): + """Dialogue runner.""" + def __init__(self, agents, args): + assert len(agents) == 2 + self.agents = agents + self.system, self.user = agents + self.args = args + self.metrics = MetricsContainer() + self.dlg_evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', args) + self._register_metrics() + + def _register_metrics(self): + """Registers valuable metrics.""" + self.metrics.register_average('dialog_len') + self.metrics.register_average('sent_len') + self.metrics.register_average('reward') + self.metrics.register_time('time') + + def _is_eod(self, out): + return len(out) == 2 and out[0] == EOD and out[1] == EOS + + def _eval_dialog(self, conv, g_key, goal): + generated_dialog = dict() + generated_dialog[g_key] = {'goal': goal, 'log': list()} + for t_id, (name, utt) in enumerate(conv): + # assert utt[-1] == EOS, utt + if t_id % 2 == 0: + assert name == 'Baozi' + utt = ' '.join(utt[:-1]) + if utt == EOD: + continue + generated_dialog[g_key]['log'].append({'text': utt}) + report, success_r, match_r = self.dlg_evaluator.evaluateModel(generated_dialog, mode='rollout') + return success_r + match_r + + def show_metrics(self): + return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()]) + + def run(self, g_key, goal): + """Runs one instance of the dialogue.""" + # initialize agents by feeding in the goal + # initialize BOD utterance for each agent + for agent in self.agents: + agent.feed_goal(goal) + agent.bod_init() + + # role assignment + reader, writer = self.system, self.user + begin_name = writer.name + print('begin_name = {}'.format(begin_name)) + + conv = [] + # reset metrics + self.metrics.reset() + nturn = 0 + while True: + nturn += 1 + # produce an utterance + out_words = writer.write() # out: list of word, str, len = max_words + print('\t{} out_words = {}'.format(writer.name, ' '.join(out_words))) + + self.metrics.record('sent_len', len(out_words)) + # self.metrics.record('%s_unique' % writer.name, out_words) + + # append the utterance to the conversation + conv.append((writer.name, out_words)) + # make the other agent to read it + 
reader.read(out_words)
+            # check whether the end of the conversation was generated
+            if self._is_eod(out_words):
+                break
+
+            if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
+                break
+
+            writer, reader = reader, writer
+
+        # evaluate the dialogue and compute the reward
+        reward = self._eval_dialog(conv, g_key, goal)
+        print('Reward = {}'.format(reward))
+        # perform the policy update
+        self.system.update(reward)
+        self.metrics.record('time')
+        self.metrics.record('dialog_len', len(conv))
+        self.metrics.record('reward', int(reward))
+
+        print('='*50)
+        print(self.show_metrics())
+        print('='*50)
+        return conv, reward
+
+
+class DialogEval(Dialog):
+    def run(self, g_key, goal):
+        """Runs one instance of the dialogue."""
+        # initialize agents by feeding in the goal
+        # initialize BOD utterance for each agent
+        for agent in self.agents:
+            agent.feed_goal(goal)
+            agent.bod_init()
+
+        # role assignment
+        reader, writer = self.system, self.user
+        conv = []
+        nturn = 0
+        while True:
+            nturn += 1
+            # produce an utterance
+            out_words = writer.write()  # out: list of word strings, len <= max_words
+            conv.append((writer.name, out_words))
+            # have the other agent read it
+            reader.read(out_words)
+            # check whether the end of the conversation was generated
+            if self._is_eod(out_words):
+                break
+
+            writer, reader = reader, writer
+            if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
+                return conv, 0
+
+        # evaluate the dialogue and compute the reward
+        reward = self._eval_dialog(conv, g_key, goal)
+        return conv, reward
diff --git a/latent_dialog/domain.py b/latent_dialog/domain.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d3ff99bcf99042bc65ec3f3fc0da96734b391b
--- /dev/null
+++ b/latent_dialog/domain.py
@@ -0,0 +1,124 @@
+import re
+import random
+import json
+
+
+def get_domain(name):
+    if name == 'object_division':
+        return ObjectDivisionDomain()
+    raise ValueError('unknown domain: {}'.format(name))
+
+
+class ObjectDivisionDomain(object):
+    def __init__(self):
+        self.item_pattern = re.compile(r'^item([0-9])=([0-9\-])+$')
+
+    def input_length(self):
+        return 3
+
+    def selection_length(self):
+        return 6
+
+    def generate_choices(self, inpt):
+        cnts, _ = self.parse_context(inpt)
+
+        def gen(cnts, idx=0, choice=[]):
+            if idx >= len(cnts):
+                left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)]
+                right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))]
+                return [left_choice + right_choice]
+            choices = []
+            for c in range(cnts[idx] + 1):
+                choice.append(c)
+                choices += gen(cnts, idx + 1, choice)
+                choice.pop()
+            return choices
+        choices = gen(cnts)
+        choices.append(['<no_agreement>'] * self.selection_length())
+        choices.append(['<disconnect>'] * self.selection_length())
+        return choices
+
+    def parse_context(self, ctx):
+        cnts = [int(n) for n in ctx[0::2]]
+        vals = [int(v) for v in ctx[1::2]]
+        return cnts, vals
+
+    def _to_int(self, x):
+        try:
+            return int(x)
+        except (TypeError, ValueError):
+            return 0
+
+    def score_choices(self, choices, ctxs):
+        assert len(choices) == len(ctxs)
+        cnts = [int(x) for x in ctxs[0][0::2]]
+        agree, scores = True, [0 for _ in range(len(ctxs))]
+        for i, n in enumerate(cnts):
+            for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)):
+                taken = self._to_int(choice[i][-1])
+                n -= taken
+                scores[agent_id] += int(ctx[2 * i + 1]) * taken
+            agree = agree and (n == 0)
+        return agree, scores
+
+
+class ContextGenerator(object):
+    """Dialogue context
generator. Generates contexes from the file.""" + def __init__(self, context_file): + self.ctxs = [] + with open(context_file, 'r') as f: + ctx_pair = [] + for line in f: + ctx = line.strip().split() + ctx_pair.append(ctx) + if len(ctx_pair) == 2: + self.ctxs.append(ctx_pair) + ctx_pair = [] + + def sample(self): + return random.choice(self.ctxs) + + def iter(self, nepoch=1): + for e in range(nepoch): + random.shuffle(self.ctxs) + for ctx in self.ctxs: + yield ctx + + def total_size(self, nepoch): + return nepoch*len(self.ctxs) + + +class ContextGeneratorEval(object): + """Dialogue context generator. Generates contexes from the file.""" + def __init__(self, context_file): + self.ctxs = [] + with open(context_file, 'r') as f: + ctx_pair = [] + for line in f: + ctx = line.strip().split() + ctx_pair.append(ctx) + if len(ctx_pair) == 2: + self.ctxs.append(ctx_pair) + ctx_pair = [] + + +class TaskGoalGenerator(object): + def __init__(self, goal_file): + self.goals = [] + data = json.load(open(goal_file)) + for key, raw_dlg in data.items(): + self.goals.append((key, raw_dlg['goal'])) + + def sample(self): + return random.choice(self.goals) + + def iter(self, nepoch=1): + for e in range(nepoch): + random.shuffle(self.goals) + for goal in self.goals: + yield goal + + diff --git a/latent_dialog/enc2dec/__init__.py b/latent_dialog/enc2dec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6caafab6a74e286872551db1a2edcc848eb00fd0 --- /dev/null +++ b/latent_dialog/enc2dec/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Author: Tiancheng Zhao +# Date: 9/15/18 diff --git a/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e5601a38d9b18e633e2c8eb076ee990266f342a Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc differ diff --git a/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc8e9504000fe70600b605639cc96f6a2bc5197 Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc differ diff --git a/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c65603128e11acfb9750a779b4dae7be502e9f79 Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc differ diff --git a/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cf032d687ea874ca05a9f135d616435cd10234 Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc differ diff --git a/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..887fe1642e54c64711c41b5f6aeb07360d40607b Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc differ diff --git a/latent_dialog/enc2dec/base_modules.py b/latent_dialog/enc2dec/base_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7d98601596f147a25aa47332483f5502eaf9d3 --- /dev/null +++ 
b/latent_dialog/enc2dec/base_modules.py
@@ -0,0 +1,68 @@
+import torch as th
+import torch.nn as nn
+import numpy as np
+from torch.nn.modules.module import _addindent
+
+def summary(model, show_weights=True, show_parameters=True):
+    """
+    Summarizes a torch model by listing its trainable parameters and weight shapes.
+    """
+    tmpstr = model.__class__.__name__ + ' (\n'
+    total_params = 0
+    for key, module in model._modules.items():
+        # if the module itself contains layers, summarize it recursively
+        # to collect its parameters and weights
+        if type(module) in [
+            th.nn.modules.container.Container,
+            th.nn.modules.container.Sequential
+        ]:
+            modstr = summary(module)
+        else:
+            modstr = module.__repr__()
+        modstr = _addindent(modstr, 2)
+
+        params = sum([np.prod(p.size()) for p in module.parameters()])
+        weights = tuple([tuple(p.size()) for p in module.parameters()])
+        total_params += params
+
+        tmpstr += ' (' + key + '): ' + modstr
+        if show_weights:
+            tmpstr += ', weights={}'.format(weights)
+        if show_parameters:
+            tmpstr += ', parameters={}'.format(params)
+        tmpstr += '\n'
+
+    tmpstr = tmpstr + ') Total Parameters={}'.format(total_params)
+    return tmpstr
+
+
+class BaseRNN(nn.Module):
+    KEY_ATTN_SCORE = 'attention_score'
+    KEY_SEQUENCE = 'sequence'
+
+    def __init__(self, input_dropout_p, rnn_cell,
+                 input_size, hidden_size, num_layers,
+                 output_dropout_p, bidirectional):
+        super(BaseRNN, self).__init__()
+        self.input_dropout = nn.Dropout(p=input_dropout_p)
+        if rnn_cell.lower() == 'lstm':
+            self.rnn_cell = nn.LSTM
+        elif rnn_cell.lower() == 'gru':
+            self.rnn_cell = nn.GRU
+        else:
+            raise ValueError('Unsupported RNN Cell Type: {0}'.format(rnn_cell))
+        self.rnn = self.rnn_cell(input_size=input_size,
+                                 hidden_size=hidden_size,
+                                 num_layers=num_layers,
+                                 batch_first=True,
+                                 dropout=output_dropout_p,
+                                 bidirectional=bidirectional)
+
+        # initialization trick: set the LSTM forget-gate biases to 1 so the cell
+        # retains information early in training (PyTorch orders each bias vector
+        # as input/forget/cell/output gates, so the forget gate is n//4:n//2)
+        if rnn_cell.lower() == 'lstm':
+            for names in self.rnn._all_weights:
+                for name in filter(lambda n: 'bias' in n, names):
+                    bias = getattr(self.rnn, name)
+                    n = bias.size(0)
+                    start, end = n // 4, n // 2
+                    bias.data[start:end].fill_(1.)
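A minimal usage sketch for the summary() helper above (the module path comes from this diff; the toy network and its layer sizes are invented for illustration):

    import torch.nn as nn
    from latent_dialog.enc2dec.base_modules import summary

    # a throwaway two-layer net, just to show the report format
    net = nn.Sequential(nn.Linear(16, 32), nn.Tanh(), nn.Linear(32, 4))
    print(summary(net, show_weights=False))
    # one '(key): module, parameters=N' row per child, then
    # 'Total Parameters=676' (544 for the first Linear, 0 for Tanh, 132 for the second)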
diff --git a/latent_dialog/enc2dec/classifier.py b/latent_dialog/enc2dec/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..359d500a3db66136a3063300b5b8f562f6d975ab --- /dev/null +++ b/latent_dialog/enc2dec/classifier.py @@ -0,0 +1,103 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from latent_dialog.enc2dec.base_modules import BaseRNN + + +class EncoderGRUATTN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths): + super(EncoderGRUATTN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + self.variable_lengths = variable_lengths + self.nhid_attn = hidden_size + self.output_size = hidden_size*2 if bidirectional else hidden_size + + # attention to combine selection hidden states + self.attn = nn.Sequential( + nn.Linear(2 * hidden_size, hidden_size), + nn.Tanh(), + nn.Linear(hidden_size, 1) + ) + + def forward(self, residual_var, input_var, turn_feat, mask=None, init_state=None, input_lengths=None): + # residual_var: (batch_size, max_dlg_len, 2*utt_cell_size) + # input_var: (batch_size, max_dlg_len, dlg_cell_size) + + # TODO switch of mask + # mask = None + + require_embed = True + if require_embed: + # input_cat = th.cat([input_var, residual_var], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size) + input_cat = th.cat([input_var, residual_var, turn_feat], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size) + else: + # input_cat = th.cat([input_var], 2) + input_cat = th.cat([input_var, turn_feat], 2) + if mask is not None: + input_mask = mask.view(input_cat.size(0), input_cat.size(1), 1) # (batch_size, max_dlg_len*max_utt_len, 1) + input_cat = th.mul(input_cat, input_mask) + embedded = self.input_dropout(input_cat) + + require_rnn = True + if require_rnn: + if init_state is not None: + h, _ = self.rnn(embedded, init_state) + else: + h, _ = self.rnn(embedded) # (batch_size, max_dlg_len, 2*nhid_attn) + + logit = self.attn(h.contiguous().view(-1, 2*self.nhid_attn)).view(h.size(0), h.size(1)) # (batch_size, max_dlg_len) + # if mask is not None: + # logit_mask = mask.view(input_cat.size(0), input_cat.size(1)) + # logit_mask = -999.0 * logit_mask + # logit = logit_mask + logit + + prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(h) # (batch_size, max_dlg_len, 2*nhid_attn) + attn = th.sum(th.mul(h, prob), 1) # (batch_size, 2*nhid_attn) + + return attn + + else: + logit = self.attn(embedded.contiguous().view(input_cat.size(0)*input_cat.size(1), -1)).view(input_cat.size(0), input_cat.size(1)) + if mask is not None: + logit_mask = mask.view(input_cat.size(0), input_cat.size(1)) + logit_mask = -999.0 * logit_mask + logit = logit_mask + logit + + prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(embedded) # (batch_size, max_dlg_len, 2*nhid_attn) + attn = th.sum(th.mul(embedded, prob), 1) # (batch_size, 2*nhid_attn) + + return attn + + +class FeatureProjecter(nn.Module): + def __init__(self, input_dropout_p, input_size, output_size): + super(FeatureProjecter, self).__init__() + self.input_dropout = nn.Dropout(p=input_dropout_p) + self.sel_encoder = nn.Sequential( + nn.Linear(input_size, output_size), + nn.Tanh() + ) + + def forward(self, goals_h, attn_outs): + h = th.cat([attn_outs, goals_h], 1) # (batch_size, 2*nhid_attn+goal_nhid) + h = self.input_dropout(h) 
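+        # dropout is applied to the concatenated (attention summary, goal
+        # embedding) vector before it is projected down to the selection space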
+ h = self.sel_encoder.forward(h) # (batch_size, nhid_sel) + return h + + +class SelectionClassifier(nn.Module): + def __init__(self, selection_length, input_size, output_size): + super(SelectionClassifier, self).__init__() + self.sel_decoders = nn.ModuleList() + for _ in range(selection_length): + self.sel_decoders.append(nn.Linear(input_size, output_size)) + + def forward(self, proj_outs): + outs = [decoder.forward(proj_outs).unsqueeze(1) for decoder in self.sel_decoders] # outcome_len*(batch_size, 1, outcome_vocab_size) + outs = th.cat(outs, 1) # (batch_size, outcome_len, outcome_vocab_size) + return outs diff --git a/latent_dialog/enc2dec/decoders.py b/latent_dialog/enc2dec/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..b08a3492506fbf7aa5ae9f5325a096e8e0698c82 --- /dev/null +++ b/latent_dialog/enc2dec/decoders.py @@ -0,0 +1,575 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +import numpy as np +from latent_dialog.enc2dec.base_modules import BaseRNN +from latent_dialog.utils import cast_type, LONG, FLOAT +from latent_dialog.corpora import DECODING_MASKED_TOKENS, EOS +import pdb + + +TEACH_FORCE = 'teacher_forcing' +TEACH_GEN = 'teacher_gen' +GEN = 'gen' +GEN_VALID = 'gen_valid' + + +class Attention(nn.Module): + def __init__(self, dec_cell_size, ctx_cell_size, attn_mode, project): + super(Attention, self).__init__() + self.dec_cell_size = dec_cell_size + self.ctx_cell_size = ctx_cell_size + self.attn_mode = attn_mode + if project: + self.linear_out = nn.Linear(dec_cell_size+ctx_cell_size, dec_cell_size) + else: + self.linear_out = None + + if attn_mode == 'general': + self.dec_w = nn.Linear(dec_cell_size, ctx_cell_size) + elif attn_mode == 'cat': + self.dec_w = nn.Linear(dec_cell_size, dec_cell_size) + self.attn_w = nn.Linear(ctx_cell_size, dec_cell_size) + self.query_w = nn.Linear(dec_cell_size, 1) + + def forward(self, output, context): + # output: (batch_size, output_seq_len, dec_cell_size) + # context: (batch_size, max_ctx_len, ctx_cell_size) + batch_size = output.size(0) + max_ctx_len = context.size(1) + + if self.attn_mode == 'dot': + attn = th.bmm(output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len) + elif self.attn_mode == 'general': + mapped_output = self.dec_w(output) # (batch_size, output_seq_len, ctx_cell_size) + attn = th.bmm(mapped_output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len) + elif self.attn_mode == 'cat': + mapped_output = self.dec_w(output) # (batch_size, output_seq_len, dec_cell_size) + mapped_attn = self.attn_w(context) # (batch_size, max_ctx_len, dec_cell_size) + tiled_output = mapped_output.unsqueeze(2).repeat(1, 1, max_ctx_len, 1) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size) + tiled_attn = mapped_attn.unsqueeze(1) # (batch_size, 1, max_ctx_len, dec_cell_size) + fc1 = th.tanh(tiled_output+tiled_attn) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size) + attn = self.query_w(fc1).squeeze(-1) # (batch_size, otuput_seq_len, max_ctx_len) + else: + raise ValueError('Unknown attention mode') + + # TODO mask + # if self.mask is not None: + + attn = F.softmax(attn.view(-1, max_ctx_len), dim=1).view(batch_size, -1, max_ctx_len) # (batch_size, output_seq_len, max_ctx_len) + mix = th.bmm(attn, context) # (batch_size, output_seq_len, ctx_cell_size) + combined = th.cat((mix, output), dim=2) # (batch_size, output_seq_len, dec_cell_size+ctx_cell_size) + if self.linear_out is 
None: + return combined, attn + else: + output = th.tanh( + self.linear_out(combined.view(-1, self.dec_cell_size+self.ctx_cell_size))).view( + batch_size, -1, self.dec_cell_size) # (batch_size, output_seq_len, dec_cell_size) + return output, attn + + +class DecoderRNN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, + bidirectional, vocab_size, use_attn, ctx_cell_size, attn_mode, sys_id, eos_id, use_gpu, + max_dec_len, embedding=None): + + super(DecoderRNN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + + # TODO embedding is None or not + if embedding is None: + self.embedding = nn.Embedding(vocab_size, input_size) + else: + self.embedding = embedding + + # share parameters between encoder and decoder + # self.rnn = ctx_encoder.rnn + # self.FC = nn.Linear(input_size, utt_encoder.output_size) + + self.use_attn = use_attn + if self.use_attn: + self.attention = Attention(dec_cell_size=hidden_size, + ctx_cell_size=ctx_cell_size, + attn_mode=attn_mode, + project=True) + + self.dec_cell_size = hidden_size + self.output_size = vocab_size + self.project = nn.Linear(self.dec_cell_size, self.output_size) + self.log_softmax = F.log_softmax + + self.sys_id = sys_id + self.eos_id = eos_id + self.use_gpu = use_gpu + self.max_dec_len = max_dec_len + + def forward(self, batch_size, dec_inputs, dec_init_state, attn_context, mode, gen_type, beam_size, goal_hid=None): + # dec_inputs: (batch_size, response_size-1) + # attn_context: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + + ret_dict = dict() + + if self.use_attn: + ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list() + + if mode == GEN: + dec_inputs = None + + if gen_type != 'beam': + beam_size = 1 + + if dec_inputs is not None: + decoder_input = dec_inputs + else: + # prepare the BOS inputs + with th.no_grad(): + bos_var = Variable(th.LongTensor([self.sys_id])) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size*beam_size, 1) # (batch_size, 1) + + if mode == GEN and gen_type == 'beam': + # TODO if beam search, repeat the initial states of the RNN + pass + else: + decoder_hidden_state = dec_init_state + + prob_outputs = [] # list of logprob | max_dec_len*(batch_size, 1, vocab_size) + symbol_outputs = [] # list of word ids | max_dec_len*(batch_size, 1) + # back_pointers = [] + # lengths = blabla... 
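+        # the decode() closure below handles one step of free-running decoding:
+        # it stores the step's log-probs and attention, then picks the next
+        # symbol; only greedy top-1 is implemented here ('sample' and 'beam'
+        # are stubs that would leave `symbols` unbound), and the caller feeds
+        # the chosen symbol back in as the next decoder_input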
+ + def decode(step, cum_sum, step_output, step_attn): + prob_outputs.append(step_output) + step_output_slice = step_output.squeeze(1) # (batch_size, vocab_size) + if self.use_attn: + ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn) + + if gen_type == 'greedy': + _, symbols = step_output_slice.topk(1) # (batch_size, 1) + elif gen_type == 'sample': + # TODO FIXME + # symbols = self.gumbel_max(step_output_slice) + pass + elif gen_type == 'beam': + # TODO + pass + else: + raise ValueError('Unsupported decoding mode') + + symbol_outputs.append(symbols) + + return cum_sum, symbols + + if mode == TEACH_FORCE: + prob_outputs, decoder_hidden_state, attn = self.forward_step(input_var=decoder_input, hidden_state=decoder_hidden_state, encoder_outputs=attn_context, goal_hid=goal_hid) + else: + # do free running here + cum_sum = None + for step in range(self.max_dec_len): + # Input: + # decoder_input: (batch_size, 1) + # decoder_hidden_state: tuple: (h, c) + # attn_context: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + # Output: + # decoder_output: (batch_size, 1, vocab_size) + # decoder_hidden_state: tuple: (h, c) + # step_attn: (batch_size, 1, max_ctx_len) + decoder_output, decoder_hidden_state, step_attn = self.forward_step(decoder_input, decoder_hidden_state, attn_context, goal_hid=goal_hid) + cum_sum, symbols = decode(step, cum_sum, decoder_output, step_attn) + decoder_input = symbols + + prob_outputs = th.cat(prob_outputs, dim=1) # (batch_size, max_dec_len, vocab_size) + + # back tracking to recover the 1-best in beam search + # if gen_type == 'beam': + + ret_dict[DecoderRNN.KEY_SEQUENCE] = symbol_outputs + + # prob_outputs: (batch_size, max_dec_len, vocab_size) + # decoder_hidden_state: tuple: (h, c) + # ret_dict[DecoderRNN.KEY_ATTN_SCORE]: max_dec_len*(batch_size, 1, max_ctx_len) + # ret_dict[DecoderRNN.KEY_SEQUENCE]: max_dec_len*(batch_size, 1) + return prob_outputs, decoder_hidden_state, ret_dict + + def forward_step(self, input_var, hidden_state, encoder_outputs, goal_hid): + # input_var: (batch_size, response_size-1 i.e. 
output_seq_len) + # hidden_state: tuple: (h, c) + # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size) + # goal_hid: (batch_size, goal_nhid) + batch_size, output_seq_len = input_var.size() + embedded = self.embedding(input_var) # (batch_size, output_seq_len, embedding_dim) + + # add goals + if goal_hid is not None: + goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (batch_size, 1, goal_nhid) + goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (batch_size, output_seq_len, goal_nhid) + embedded = th.cat([embedded, goal_rep], dim=2) # (batch_size, output_seq_len, embedding_dim+goal_nhid) + + embedded = self.input_dropout(embedded) + + # ############ + # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1) + + # output: (batch_size, output_seq_len, dec_cell_size) + # hidden: tuple: (h, c) + output, hidden_s = self.rnn(embedded, hidden_state) + + attn = None + if self.use_attn: + # output: (batch_size, output_seq_len, dec_cell_size) + # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size) + # attn: (batch_size, output_seq_len, max_ctx_len) + output, attn = self.attention(output, encoder_outputs) + + logits = self.project(output.contiguous().view(-1, self.dec_cell_size)) # (batch_size*output_seq_len, vocab_size) + prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size) + return prediction, hidden_s, attn + + # special for rl + def _step(self, input_var, hidden_state, encoder_outputs, goal_hid): + # input_var: (1, 1) + # hidden_state: tuple: (h, c) + # encoder_outputs: (1, max_dlg_len, dlg_cell_size) + # goal_hid: (1, goal_nhid) + batch_size, output_seq_len = input_var.size() + embedded = self.embedding(input_var) # (1, 1, embedding_dim) + + if goal_hid is not None: + goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (1, 1, goal_nhid) + goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (1, 1, goal_nhid) + embedded = th.cat([embedded, goal_rep], dim=2) # (1, 1, embedding_dim+goal_nhid) + + embedded = self.input_dropout(embedded) + + # ############ + # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1) + + # output: (1, 1, dec_cell_size) + # hidden: tuple: (h, c) + output, hidden_s = self.rnn(embedded, hidden_state) + + attn = None + if self.use_attn: + # output: (1, 1, dec_cell_size) + # encoder_outputs: (1, max_dlg_len, dlg_cell_size) + # attn: (1, 1, max_dlg_len) + output, attn = self.attention(output, encoder_outputs) + + logits = self.project(output.view(-1, self.dec_cell_size)) # (1*1, vocab_size) + prediction = logits.view(batch_size, output_seq_len, -1) # (1, 1, vocab_size) + # prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size) + return prediction, hidden_s + + # special for rl + def write(self, input_var, hidden_state, encoder_outputs, max_words, vocab, stop_tokens, goal_hid=None, mask=True, + decoding_masked_tokens=DECODING_MASKED_TOKENS): + # input_var: (1, 1) + # hidden_state: tuple: (h, c) + # encoder_outputs: max_dlg_len*(1, 1, dlg_cell_size) + # goal_hid: (1, goal_nhid) + logprob_outputs = [] # list of logprob | max_dec_len*(1, ) + symbol_outputs = [] # list of word ids | max_dec_len*(1, ) + decoder_input = input_var + decoder_hidden_state = hidden_state + if type(encoder_outputs) is list: + encoder_outputs = th.cat(encoder_outputs, 1) # (1, max_dlg_len, dlg_cell_size) + # 
print('encoder_outputs.size() = {}'.format(encoder_outputs.size())) + + if mask: + special_token_mask = Variable(th.FloatTensor([-999. if token in decoding_masked_tokens else 0. for token in vocab])) + special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, ) + + def _sample(dec_output, num_i): + # dec_output: (1, 1, vocab_size), need to softmax and log_softmax + dec_output = dec_output.view(-1) # (vocab_size, ) + # TODO temperature + prob = F.softmax(dec_output/0.6, dim=0) # (vocab_size, ) + logprob = F.log_softmax(dec_output, dim=0) # (vocab_size, ) + symbol = prob.multinomial(num_samples=1).detach() # (1, ) + # _, symbol = prob.topk(1) # (1, ) + _, tmp_symbol = prob.topk(1) # (1, ) + # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()])) + # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()])) + logprob = logprob.gather(0, symbol) # (1, ) + return logprob, symbol + + for i in range(max_words): + decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid) + # disable special tokens from being generated in a normal turn + if mask: + decoder_output += special_token_mask.expand(1, 1, -1) + logprob, symbol = _sample(decoder_output, i) + logprob_outputs.append(logprob) + symbol_outputs.append(symbol) + decoder_input = symbol.view(1, -1) + + if vocab[symbol.item()] in stop_tokens: + break + + assert len(logprob_outputs) == len(symbol_outputs) + # logprob_list = [t.item() for t in logprob_outputs] + logprob_list = logprob_outputs + symbol_list = [t.item() for t in symbol_outputs] + return logprob_list, symbol_list + + # For MultiWoz RL + def forward_rl(self, batch_size, dec_init_state, attn_context, vocab, max_words, goal_hid=None, mask=True, temp=0.1): + # prepare the BOS inputs + with th.no_grad(): + bos_var = Variable(th.LongTensor([self.sys_id])) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size, 1) # (1, 1) + decoder_hidden_state = dec_init_state # tuple: (h, c) + encoder_outputs = attn_context # (1, ctx_len, ctx_cell_size) + + logprob_outputs = [] # list of logprob | max_dec_len*(1, ) + symbol_outputs = [] # list of word ids | max_dec_len*(1, ) + + if mask: + special_token_mask = Variable(th.FloatTensor([-999. if token in DECODING_MASKED_TOKENS else 0. 
for token in vocab])) + special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, ) + + def _sample(dec_output, num_i): + # dec_output: (1, 1, vocab_size), need to softmax and log_softmax + dec_output = dec_output.view(batch_size, -1) # (batch_size, vocab_size, ) + prob = F.softmax(dec_output/temp, dim=1) # (batch_size, vocab_size, ) + logprob = F.log_softmax(dec_output, dim=1) # (batch_size, vocab_size, ) + symbol = prob.multinomial(num_samples=1).detach() # (batch_size, 1) + # _, symbol = prob.topk(1) # (1, ) + _, tmp_symbol = prob.topk(1) # (1, ) + # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()])) + # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()])) + logprob = logprob.gather(1, symbol) # (1, ) + return logprob, symbol + + stopped_samples = set() + for i in range(max_words): + decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid) + # disable special tokens from being generated in a normal turn + if mask: + decoder_output += special_token_mask.expand(1, 1, -1) + logprob, symbol = _sample(decoder_output, i) + logprob_outputs.append(logprob) + symbol_outputs.append(symbol) + decoder_input = symbol.view(batch_size, -1) + for b_id in range(batch_size): + if vocab[symbol[b_id].item()] == EOS: + stopped_samples.add(b_id) + + if len(stopped_samples) == batch_size: + break + + assert len(logprob_outputs) == len(symbol_outputs) + symbol_outputs = th.cat(symbol_outputs, dim=1).cpu().data.numpy().tolist() + logprob_outputs = th.cat(logprob_outputs, dim=1) + logprob_list = [] + symbol_list = [] + for b_id in range(batch_size): + b_logprob = [] + b_symbol = [] + for t_id in range(logprob_outputs.shape[1]): + symbol = symbol_outputs[b_id][t_id] + if vocab[symbol] == EOS and t_id != 0: + break + + b_symbol.append(symbol_outputs[b_id][t_id]) + b_logprob.append(logprob_outputs[b_id][t_id]) + + logprob_list.append(b_logprob) + symbol_list.append(b_symbol) + + # TODO backward compatible, if batch_size == 1, we remove the nested structure + if batch_size == 1: + logprob_list = logprob_list[0] + symbol_list = symbol_list[0] + + return logprob_list, symbol_list + +class DecoderPointerGen(BaseRNN): + + def __init__(self, vocab_size, max_len, input_size, hidden_size, sos_id, + eos_id, n_layers=1, rnn_cell='lstm', input_dropout_p=0, + dropout_p=0, attn_mode='cat', attn_size=None, use_gpu=True, + embedding=None): + + super(DecoderPointerGen, self).__init__(vocab_size, input_size, + hidden_size, input_dropout_p, + dropout_p, n_layers, rnn_cell, False) + + self.output_size = vocab_size + self.max_length = max_len + self.eos_id = eos_id + self.sos_id = sos_id + self.use_gpu = use_gpu + self.attn_size = attn_size + + if embedding is None: + self.embedding = nn.Embedding(self.output_size, self.input_size) + else: + self.embedding = embedding + + self.attention = Attention(self.hidden_size, attn_size, attn_mode, + project=True) + + self.project = nn.Linear(self.hidden_size, self.output_size) + self.sentinel = nn.Parameter(torch.randn((1, 1, attn_size)), requires_grad=True) + self.register_parameter('sentinel', self.sentinel) + + def forward_step(self, input_var, hidden, attn_ctxs, attn_words, ctx_embed=None): + """ + attn_size: number of context to attend + :param input_var: + :param hidden: + :param attn_ctxs: batch_size x attn_size+1 x ctx_size. 
If None, then leave it empty + :param attn_words: batch_size x attn_size + :return: + """ + # we enable empty attention context + batch_size = input_var.size(0) + seq_len = input_var.size(1) + embedded = self.embedding(input_var) + if ctx_embed is not None: + embedded += ctx_embed + + embedded = self.input_dropout(embedded) + output, hidden = self.rnn(embedded, hidden) + + if attn_ctxs is None: + # pointer network here + logits = self.project(output.contiguous().view(-1, self.hidden_size)) + predicted_softmax = F.log_softmax(logits, dim=1) + return predicted_softmax, None, hidden, None, None + else: + attn_size = attn_words.size(1) + combined_output, attn = self.attention(output, attn_ctxs) + + # output: batch_size x seq_len x hidden_size + # attn: batch_size x seq_len x (attn_size+1) + + # pointer network here + rnn_softmax = F.softmax(self.project(output.view(-1, self.hidden_size)), dim=1) + g = attn[:, :, 0].contiguous() + ptr_attn = attn[:, :, 1:].contiguous() + ptr_softmax = Variable(torch.zeros((batch_size * seq_len * attn_size, self.vocab_size))) + ptr_softmax = cast_type(ptr_softmax, FLOAT, self.use_gpu) + + # convert words and ids into 1D + flat_attn_words = attn_words.unsqueeze(1).repeat(1, seq_len, 1).view(-1, 1) + flat_attn = ptr_attn.view(-1, 1) + + # fill in the attention into ptr_softmax + ptr_softmax = ptr_softmax.scatter_(1, flat_attn_words, flat_attn) + ptr_softmax = ptr_softmax.view(batch_size * seq_len, attn_size, self.vocab_size) + ptr_softmax = torch.sum(ptr_softmax, dim=1) + + # mix the softmax from rnn and pointer + mixture_softmax = rnn_softmax * g.view(-1, 1) + ptr_softmax + + # take the log to get logsoftmax + logits = torch.log(mixture_softmax.clamp(min=1e-8)) + predicted_softmax = logits.view(batch_size, seq_len, -1) + ptr_softmax = ptr_softmax.view(batch_size, seq_len, -1) + + return predicted_softmax, ptr_softmax, hidden, ptr_attn, g + + def forward(self, batch_size, attn_context, attn_words, + inputs=None, init_state=None, mode=TEACH_FORCE, + gen_type='greedy', ctx_embed=None): + + # sanity checks + ret_dict = dict() + + if mode == GEN: + inputs = None + + if inputs is not None: + decoder_input = inputs + else: + # prepare the BOS inputs + bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True) + bos_var = cast_type(bos_var, LONG, self.use_gpu) + decoder_input = bos_var.expand(batch_size, 1) + + # append sentinel to the attention + if attn_context is not None: + attn_context = torch.cat([self.sentinel.expand(batch_size, 1, self.attn_size), + attn_context], dim=1) + + decoder_hidden = init_state + decoder_outputs = [] # a list of logprob + sequence_symbols = [] # a list word ids + attentions = [] + pointer_gs = [] + pointer_outputs = [] + lengths = np.array([self.max_length] * batch_size) + + def decode(step, step_output): + decoder_outputs.append(step_output) + step_output_slice = step_output.squeeze(1) + + if gen_type == 'greedy': + symbols = step_output_slice.topk(1)[1] + elif gen_type == 'sample': + symbols = self.gumbel_max(step_output_slice) + else: + raise ValueError("Unsupported decoding mode") + + sequence_symbols.append(symbols) + + eos_batches = symbols.data.eq(self.eos_id) + if eos_batches.dim() > 0: + eos_batches = eos_batches.cpu().view(-1).numpy() + update_idx = ((lengths > di) & eos_batches) != 0 + lengths[update_idx] = len(sequence_symbols) + return symbols + + # Manual unrolling is used to support random teacher forcing. 
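+        # (sketch, not in the original code: with a scalar probability p one
+        #  would draw `use_gold = random.random() < p` at every step and feed
+        #  either the gold token or the generated symbol as decoder_input)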
+ # If teacher_forcing_ratio is True or False instead of a probability, + # the unrolling can be done in graph + if mode == TEACH_FORCE: + pred_softmax, ptr_softmax, decoder_hidden, attn, step_g = self.forward_step( + decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed) + + # in teach forcing mode, we don't need symbols. + attentions = attn + decoder_outputs = pred_softmax + pointer_gs = step_g + pointer_outputs = ptr_softmax + + else: + # do free running here + for di in range(self.max_length): + pred_softmax, ptr_softmax, decoder_hidden, step_attn, step_g = self.forward_step( + decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed) + + symbols = decode(di, pred_softmax) + + # append the results into ctx dictionary + attentions.append(step_attn) + pointer_gs.append(step_g) + pointer_outputs.append(ptr_softmax) + decoder_input = symbols + + # make list be a tensor + decoder_outputs = torch.cat(decoder_outputs, dim=1) + pointer_outputs = torch.cat(pointer_outputs, dim=1) + pointer_gs = torch.cat(pointer_gs, dim=1) + + # save the decoded sequence symbols and sequence length + ret_dict[self.KEY_ATTN_SCORE] = attentions + ret_dict[self.KEY_SEQUENCE] = sequence_symbols + ret_dict[self.KEY_LENGTH] = lengths + ret_dict[self.KEY_G] = pointer_gs + ret_dict[self.KEY_PTR_SOFTMAX] = pointer_outputs + ret_dict[self.KEY_PTR_CTX] = attn_words + + return decoder_outputs, decoder_hidden, ret_dict diff --git a/latent_dialog/enc2dec/encoders.py b/latent_dialog/enc2dec/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..c73bbfd6488a17683f891cd8cefa699b015131a1 --- /dev/null +++ b/latent_dialog/enc2dec/encoders.py @@ -0,0 +1,215 @@ +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from latent_dialog.enc2dec.base_modules import BaseRNN + + +class EncoderRNN(BaseRNN): + def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths): + super(EncoderRNN, self).__init__(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional) + self.variable_lengths = variable_lengths + self.output_size = hidden_size*2 if bidirectional else hidden_size + + def forward(self, input_var, init_state=None, input_lengths=None, goals=None): + # add goals + if goals is not None: + batch_size, max_ctx_len, ctx_nhid = input_var.size() + goals = goals.view(goals.size(0), 1, goals.size(1)) + goals_rep = goals.repeat(1, max_ctx_len, 1).view(batch_size, max_ctx_len, -1) # (batch_size, max_ctx_len, goal_nhid) + input_var = th.cat([input_var, goals_rep], dim=2) + + embedded = self.input_dropout(input_var) + + if self.variable_lengths: + embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, + batch_first=True) + if init_state is not None: + output, hidden = self.rnn(embedded, init_state) + else: + output, hidden = self.rnn(embedded) + if self.variable_lengths: + output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) + + return output, hidden + + +class RnnUttEncoder(nn.Module): + def __init__(self, vocab_size, embedding_dim, feat_size, goal_nhid, rnn_cell, + utt_cell_size, num_layers, input_dropout_p, output_dropout_p, + bidirectional, variable_lengths, use_attn, embedding=None): + super(RnnUttEncoder, self).__init__() + if embedding is None: + self.embedding = 
nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim) + else: + self.embedding = embedding + + self.rnn = EncoderRNN(input_dropout_p=input_dropout_p, + rnn_cell=rnn_cell, + input_size=embedding_dim+feat_size+goal_nhid, + hidden_size=utt_cell_size, + num_layers=num_layers, + output_dropout_p=output_dropout_p, + bidirectional=bidirectional, + variable_lengths=variable_lengths) + + self.utt_cell_size = utt_cell_size + self.multiplier = 2 if bidirectional else 1 + self.output_size = self.multiplier * self.utt_cell_size + self.use_attn = use_attn + if self.use_attn: + self.key_w = nn.Linear(self.output_size, self.utt_cell_size) + self.query = nn.Linear(self.utt_cell_size, 1) + + def forward(self, utterances, feats=None, init_state=None, goals=None): + batch_size, max_ctx_len, max_utt_len = utterances.size() + # get word embeddings + flat_words = utterances.view(-1, max_utt_len) # (batch_size*max_ctx_len, max_utt_len) + word_embeddings = self.embedding(flat_words) # (batch_size*max_ctx_len, max_utt_len, embedding_dim) + flat_mask = th.sign(flat_words).float() + # add features + if feats is not None: + flat_feats = feats.view(-1, 1) # (batch_size*max_ctx_len, 1) + flat_feats = flat_feats.unsqueeze(1).repeat(1, max_utt_len, 1) # (batch_size*max_ctx_len, max_utt_len, 1) + word_embeddings = th.cat([word_embeddings, flat_feats], dim=2) # (batch_size*max_ctx_len, max_utt_len, embedding_dim+1) + + # add goals + if goals is not None: + goals = goals.view(goals.size(0), 1, 1, goals.size(1)) + goals_rep = goals.repeat(1, max_ctx_len, max_utt_len, 1).view(batch_size*max_ctx_len, max_utt_len, -1) # (batch_size*max_ctx_len, max_utt_len, goal_nhid) + word_embeddings = th.cat([word_embeddings, goals_rep], dim=2) + + # enc_outs: (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size) + # enc_last: (num_layers*num_directions, batch_size*max_ctx_len, utt_cell_size) + enc_outs, enc_last = self.rnn(word_embeddings, init_state=init_state) + + if self.use_attn: + fc1 = th.tanh(self.key_w(enc_outs)) # (batch_size*max_ctx_len, max_utt_len, utt_cell_size) + attn = self.query(fc1).squeeze(2) + # (batch_size*max_ctx_len, max_utt_len) + attn = F.softmax(attn, attn.dim()-1) # (batch_size*max_ctx_len, max_utt_len, 1) + attn = attn * flat_mask + attn = (attn / (th.sum(attn, dim=1, keepdim=True)+1e-10)).unsqueeze(2) + utt_embedded = attn * enc_outs # (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size) + utt_embedded = th.sum(utt_embedded, dim=1) # (batch_size*max_ctx_len, num_directions*utt_cell_size) + else: + # FIXME bug for multi-layer + attn = None + utt_embedded = enc_last.transpose(0, 1).contiguous() # (batch_size*max_ctx_lens, num_layers*num_directions, utt_cell_size) + utt_embedded = utt_embedded.view(-1, self.output_size) # (batch_size*max_ctx_len*num_layers, num_directions*utt_cell_size) + + utt_embedded = utt_embedded.view(batch_size, max_ctx_len, self.output_size) + return utt_embedded, word_embeddings.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1), \ + enc_outs.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1) + + +class MlpGoalEncoder(nn.Module): + def __init__(self, goal_vocab_size, k, nembed, nhid, init_range): + super(MlpGoalEncoder, self).__init__() + + # create separate embedding for counts and values + self.cnt_enc = nn.Embedding(goal_vocab_size, nembed) + self.val_enc = nn.Embedding(goal_vocab_size, nembed) + + self.encoder = nn.Sequential( + nn.Tanh(), + nn.Linear(k*nembed, nhid) + ) + + self.cnt_enc.weight.data.uniform_(-init_range, 
init_range) + self.val_enc.weight.data.uniform_(-init_range, init_range) + self._init_cont(self.encoder, init_range) + + def _init_cont(self, cont, init_range): + """initializes a container uniformly.""" + for m in cont: + if hasattr(m, 'weight'): + m.weight.data.uniform_(-init_range, init_range) + if hasattr(m, 'bias'): + m.bias.data.fill_(0) + + def forward(self, goal): + # goal: (batch_size, goal_len) + goal = goal.transpose(0, 1).contiguous() # (goal_len, batch_size) + idx = np.arange(goal.size(0) // 2) + + # extract counts and values + cnt_idx = Variable(th.from_numpy(2 * idx + 0)) + val_idx = Variable(th.from_numpy(2 * idx + 1)) + + if goal.is_cuda: + cnt_idx = cnt_idx.type(th.cuda.LongTensor) + val_idx = val_idx.type(th.cuda.LongTensor) + else: + cnt_idx = cnt_idx.type(th.LongTensor) + val_idx = val_idx.type(th.LongTensor) + + cnt = goal.index_select(0, cnt_idx) # (3, batch_size) + val = goal.index_select(0, val_idx) # (3, batch_size) + + # embed counts and values + cnt_emb = self.cnt_enc(cnt) # (3, batch_size, nembed) + val_emb = self.val_enc(val) # (3, batch_size, nembed) + + # element wise multiplication to get a hidden state + h = th.mul(cnt_emb, val_emb) # (3, batch_size, nembed) + # run the hidden state through the MLP + h = h.transpose(0, 1).contiguous().view(goal.size(1), -1) # (batch_size, 3*nembed) + goal_h = self.encoder(h) # (batch_size, nhid) + + return goal_h + + +class TaskMlpGoalEncoder(nn.Module): + def __init__(self, goal_vocab_sizes, nhid, init_range): + super(TaskMlpGoalEncoder, self).__init__() + + self.encoder = nn.ModuleList() + for v_size in goal_vocab_sizes: + domain_encoder = nn.Sequential( + nn.Linear(v_size, nhid), + nn.Tanh() + ) + self._init_cont(domain_encoder, init_range) + self.encoder.append(domain_encoder) + + def _init_cont(self, cont, init_range): + """initializes a container uniformly.""" + for m in cont: + if hasattr(m, 'weight'): + m.weight.data.uniform_(-init_range, init_range) + if hasattr(m, 'bias'): + m.bias.data.fill_(0) + + def forward(self, goals_list): + # goals_list: list of tensor, 7*(batch_size, goal_len), goal_len varies among differnet domains + outs = [encoder.forward(goal) for goal, encoder in zip(goals_list, self.encoder)] # 7*(batch_size, goal_nhid) + outs = th.sum(th.stack(outs), dim=0) # (batch_size, goal_nhid) + return outs + + +class SelfAttn(nn.Module): + def __init__(self, hidden_size): + super(SelfAttn, self).__init__() + self.query = nn.Linear(hidden_size, 1) + + def forward(self, keys, values, attn_mask=None): + """ + :param attn_inputs: batch_size x time_len x hidden_size + :param attn_mask: batch_size x time_len + :return: summary state + """ + alpha = F.softmax(self.query(keys), dim=1) + if attn_mask is not None: + alpha = alpha * attn_mask.unsqueeze(2) + alpha = alpha / th.sum(alpha, dim=1, keepdim=True) + + summary = th.sum(values * alpha, dim=1) + return summary \ No newline at end of file diff --git a/latent_dialog/evaluators.py b/latent_dialog/evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a699cb2312fbfcc5ed047d9f008c965c8fae8c --- /dev/null +++ b/latent_dialog/evaluators.py @@ -0,0 +1,1230 @@ +from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction +import math +import numpy as np +import scipy.stats as st +import latent_dialog.normalizer.delexicalize as delex +from latent_dialog.utils import get_tokenize, get_detokenize +from collections import Counter, defaultdict +from nltk.util import ngrams +from latent_dialog.corpora import SYS, USR, BOS, EOS +from 
sklearn.feature_extraction.text import CountVectorizer +import json +from latent_dialog.normalizer.delexicalize import normalize +from latent_dialog.augpt_utils import * +import json +import sqlite3 +import os +import random +import logging +import pdb +import re +from tqdm import tqdm +from sklearn.multiclass import OneVsRestClassifier +from sklearn.linear_model import SGDClassifier +from sklearn import metrics +from nltk.translate import bleu_score +from nltk.translate.bleu_score import SmoothingFunction +from scipy.stats import gmean + + +class BaseEvaluator(object): + def initialize(self): + raise NotImplementedError + + def add_example(self, ref, hyp): + raise NotImplementedError + + def get_report(self, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def _get_prec_recall(tp, fp, fn): + precision = tp / (tp + fp + 10e-20) + recall = tp / (tp + fn + 10e-20) + f1 = 2 * precision * recall / (precision + recall + 1e-20) + return precision, recall, f1 + + @staticmethod + def _get_tp_fp_fn(label_list, pred_list): + tp = len([t for t in pred_list if t in label_list]) + fp = max(0, len(pred_list) - tp) + fn = max(0, len(label_list) - tp) + return tp, fp, fn + +class BLEUScorer(object): + ## BLEU score calculator via GentScorer interface + ## it calculates the BLEU-4 by taking the entire corpus in + ## Calulate based multiple candidates against multiple references + def score(self, hypothesis, corpus, n=1): + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in zip(hypothesis, corpus): + # if type(hyps[0]) is list: + # hyps = [hyp.split() for hyp in hyps[0]] + # else: + # hyps = [hyp.split() for hyp in hyps] + + # refs = [ref.split() for ref in refs] + hyps = [hyps] + # Shawn's evaluation + # refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + # hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + + for idx, hyp in enumerate(hyps): + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + if n == 1: + break + # computing bleu score + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu + +class BleuEvaluator(BaseEvaluator): + def __init__(self, data_name): + self.data_name = data_name + self.labels = list() + self.hyps = list() + + def initialize(self): + self.labels = list() + self.hyps = list() + + def add_example(self, ref, hyp): + self.labels.append(ref) + self.hyps.append(hyp) + + def get_report(self): + tokenize = get_tokenize() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + for label, hyp in zip(self.labels, self.hyps): + # label = label.replace(EOS, '') + # hyp = hyp.replace(EOS, '') + # 
ref_tokens = tokenize(label)[1:] + # hyp_tokens = tokenize(hyp)[1:] + ref_tokens = tokenize(label) + hyp_tokens = tokenize(hyp) + refs.append([ref_tokens]) + hyps.append(hyp_tokens) + bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + report = '\n===== BLEU = %f =====\n' % (bleu,) + return '\n===== REPORT FOR DATASET {} ====={}'.format(self.data_name, report) + +class MultiWozDB(object): + # loading databases + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'] # , 'police'] + dbs = {} + CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '') + + for domain in domains: + db = os.path.join(CUR_DIR, 'data/norm-multi-woz/db/{}-dbase.db'.format(domain)) + conn = sqlite3.connect(db) + c = conn.cursor() + dbs[domain] = c + + def queryResultVenues(self, domain, turn, real_belief=False): + # query the db + sql_query = "select * from {}".format(domain) + + if real_belief == True: + items = turn.items() + else: + items = turn['metadata'][domain]['semi'].items() + + flag = True + for key, val in items: + if val == "" or val == "dontcare" or val == 'not mentioned' or val == "don't care" or val == "dont care" or val == "do n't care": + pass + else: + if flag: + sql_query += " where " + val2 = val.replace("'", "''") + val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r" " + key + " > " + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r" " + key + " < " + r"'" + val2 + r"'" + else: + sql_query += r" " + key + "=" + r"'" + val2 + r"'" + flag = False + else: + val2 = val.replace("'", "''") + val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r" and " + key + " > " + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r" and " + key + " < " + r"'" + val2 + r"'" + else: + sql_query += r" and " + key + "=" + r"'" + val2 + r"'" + + try: # "select * from attraction where name = 'queens college'" + return self.dbs[domain].execute(sql_query).fetchall() + except: + return [] + +class MultiWozEvaluator(BaseEvaluator): + CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '') + logger = logging.getLogger() + def __init__(self, data_name, config): + self.data_name = data_name + self.slot_dict = delex.prepareSlotValuesIndependent() + delex_path = config.train_path.replace("train_dials.json", "") + [p for p in os.listdir(config.train_path.replace("train_dials.json", "")) if 'delex' in p][0] + self.delex_dialogues = json.load(open(delex_path)) + self.db = MultiWozDB() + self.labels = list() + self.hyps = list() + + def initialize(self): + self.labels = list() + self.hyps = list() + + def add_example(self, ref, hyp): + self.labels.append(ref) + self.hyps.append(hyp) + + def _parseGoal(self, goal, d, domain): + """Parses user goal into dictionary format.""" + goal[domain] = {} + goal[domain] = {'informable': [], 'requestable': [], 'booking': [], 'failed': {}} + if 'info' in d['goal'][domain]: + # if d['goal'][domain].has_key('info'): + if domain == 'train': + # we consider dialogues only where train had to be booked! 
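+                # only the booking reference (plus the train id, when it was
+                # explicitly requested) is treated as a requestable here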
+ if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]['requestable'].append('reference') + if 'reqt' in d['goal'][domain]: + # if d['goal'][domain].has_key('reqt'): + if 'trainID' in d['goal'][domain]['reqt']: + goal[domain]['requestable'].append('id') + else: + if 'reqt' in d['goal'][domain]: + # if d['goal'][domain].has_key('reqt'): + for s in d['goal'][domain]['reqt']: # addtional requests: + if s in ['phone', 'address', 'postcode', 'reference', 'id']: + # ones that can be easily delexicalized + goal[domain]['requestable'].append(s) + if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]['requestable'].append("reference") + + goal[domain]["informable"] = d['goal'][domain]['info'] + if 'book' in d['goal'][domain]: + # if d['goal'][domain].has_key('book'): + goal[domain]["booking"] = d['goal'][domain]['book'] + + return goal + + def _evaluateGeneratedDialogue(self, dialog, goal, realDialogue, real_requestables, soft_acc=False): + """Evaluates the dialogue created by the model. + First we load the user goal of the dialogue, then for each turn + generated by the system we look for key-words. + For the Inform rate we look whether the entity was proposed. + For the Success rate we look for requestables slots""" + # for computing corpus success + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + domains_in_goal = [] + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + domains_in_goal.append(domain) + + for t, sent_t in enumerate(dialog): + for domain in goal.keys(): + # for computing success + if '[' + domain + '_name]' in sent_t or '_id' in sent_t: # undo delexicalization if system generates [domain_name] or [domain_id] + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + # in this case, look for the actual offered venues based on true belief state + venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1]) + # venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + + # if venue has changed + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'train_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][ + -1] == 1: # if pointer was allowing for that? 
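+                                # the tail of db_pointer flags whether booking was
+                                # possible (restaurant: -5, hotel: -3, train: -1)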
+ provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if '[' + domain + '_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + # assumption doesn't always hold, maybe it's better if name is provided by user that it is ignored? + if 'info' in realDialogue['goal'][domain]: + if 'name' in realDialogue['goal'][domain]['info']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + if domain == 'train': + if not venue_offered[domain]: + # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']: + if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']: + venue_offered[domain] = '[' + domain + '_name]' + + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match = 0 + success = 0 + # MATCH (between offered venue by generated dialogue and venue actually fitting to the criteria) + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + # if venue offered is not dict + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: # yields false positive, does not match what is offered with real dialogue? + match += 1 + match_stat = 1 + # if venue offered is dict + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: # actually checks the offered venue + match += 1 + match_stat = 1 + # other domains + else: + if domain + '_name]' in venue_offered[domain]: # yields false positive, in terms of occurence and correctness + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match)/len(goal.keys()) + else: + # only count success if all domain has matches + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + # SUCCESS (whether the requestable info in realDialogue is generated by the system) + # if no match, then success is assumed to be 0 + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: # if there is no requestable, assume to be succesful. incorrect, cause does not count false positives. 
+ success += 1
+ success_stat = 1
+ stats[domain][1] = success_stat
+ continue
+ # if the values in the sentences are a superset of the requestables
+ for request in set(provided_requestables[domain]):
+ if request in real_requestables[domain]:
+ domain_success += 1
+
+ if domain_success >= len(real_requestables[domain]):
+ success += 1
+ success_stat = 1
+
+ stats[domain][1] = success_stat
+
+ # final eval
+ if soft_acc:
+ success = float(success)/len(real_requestables)
+ else:
+ if success >= len(real_requestables):
+ success = 1
+ else:
+ success = 0
+
+ # print(requests, 'DIFF', requests_real, 'SUCC', success)
+ return success, match, stats
+
+ def _evaluateGeneratedDialogue_new(self, dialog, goal, realDialogue, real_requestables, soft_acc=False):
+ """Evaluates the dialogue created by the model.
+ First we load the user goal of the dialogue, then for each turn
+ generated by the system we look for key words.
+ For the Inform rate we check whether the right entity was proposed.
+ For the Success rate we check the requestable slots."""
+ # for computing corpus success
+ requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+ # CHECK IF MATCH HAPPENED
+ provided_requestables = {}
+ venue_offered = {}
+ domains_in_goal = []
+
+ for domain in goal.keys():
+ venue_offered[domain] = []
+ provided_requestables[domain] = []
+ domains_in_goal.append(domain)
+
+ for t, sent_t in enumerate(dialog):
+ for domain in goal.keys():
+ # for computing success
+ if '[' + domain + '_name]' in sent_t or '_id' in sent_t:
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+ venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1])
+ # venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+
+ # if venue has changed
+ if len(venue_offered[domain]) == 0 and venues:
+ venue_offered[domain] = random.sample(venues, 1)
+ else:
+ flag = False
+ for ven in venues:
+ if venue_offered[domain][0] == ven:
+ flag = True
+ break
+ if not flag and venues: # sometimes there are no results, so random.sample would fail
+ venue_offered[domain] = random.sample(venues, 1)
+ else: # not limited so we can provide one
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # ATTENTION: assumption here - the system does not provide the same requestable (phone, address, etc.) twice
+ for requestable in requestables:
+ if requestable == 'reference':
+ if domain + '_reference' in sent_t:
+ if 'restaurant_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-5] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'hotel_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-3] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'train_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-1] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ else:
+ provided_requestables[domain].append('reference')
+ else:
+ if domain + '_' + requestable + ']' in sent_t:
+ provided_requestables[domain].append(requestable)
+
+ # if name was given in the task
+ for domain in goal.keys():
+ # if the name was provided for the user, the match is credited automatically
+ if 'info' in realDialogue['goal'][domain]:
+ if 'name' in realDialogue['goal'][domain]['info']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # special domains - entity does not need to be provided
+ if domain in ['taxi', 'police', 'hospital']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # the original method
+ # if domain == 'train':
+ # if not venue_offered[domain]:
+ # if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']:
+ # venue_offered[domain] = '[' + domain + '_name]'
+
+ # Wrong one in HDSA
+ # if domain == 'train':
+ # if not venue_offered[domain]:
+ # if goal[domain]['requestable'] and 'id' not in goal[domain]['requestable']:
+ # venue_offered[domain] = '[' + domain + '_name]'
+
+ # if the id was not requested, accept the name placeholder - but never override a train that was actually found, so we can check that the right one was booked
+ if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ """
+ Given all inform and requestable slots
+ we go through each domain from the user goal
+ and check whether the right entity was provided and
+ all requestable slots were given to the user.
+ The dialogue is successful if that's the case for all domains.
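+ Each stats[domain] is a triple [match_stat, success_stat, domain_in_goal];
+ e.g. stats['hotel'] == [1, 0, 1] means hotel was part of the goal and the
+ right venue was offered, but at least one requestable was never provided.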
+ """ + # HARD EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match = 0 + success = 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: + match += 1 + match_stat = 1 + else: + if domain + '_name]' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match)/len(goal.keys()) + else: + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + # SUCCESS + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success)/len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + # print requests, 'DIFF', requests_real, 'SUCC', success + + + + return success, match, stats + + def _evaluateRealDialogue(self, dialog, filename, soft=False): + """Evaluation of the real dialogue from corpus. + First we loads the user goal and then go through the dialogue history. 
+ Similar to _evaluateGeneratedDialogue above."""
+ domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
+ requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+ # get the list of domains in the goal
+ domains_in_goal = []
+ goal = {}
+ for domain in domains:
+ if dialog['goal'][domain]:
+ goal = self._parseGoal(goal, dialog, domain)
+ domains_in_goal.append(domain)
+
+ # compute corpus success
+ real_requestables = {}
+ provided_requestables = {}
+ venue_offered = {}
+ for domain in goal.keys():
+ provided_requestables[domain] = []
+ venue_offered[domain] = []
+ real_requestables[domain] = goal[domain]['requestable']
+
+ # iterate over each turn
+ m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1]
+ for t in range(len(m_targetutt)):
+ for domain in domains_in_goal:
+ sent_t = m_targetutt[t]
+ # for computing match - where there are limited entities
+ if domain + '_name' in sent_t or '_id' in sent_t:
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+ venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1])
+
+ # if venue has changed
+ if len(venue_offered[domain]) == 0 and venues:
+ venue_offered[domain] = random.sample(venues, 1)
+ else:
+ flag = False
+ for ven in venues:
+ if venue_offered[domain][0] == ven:
+ flag = True
+ break
+ if not flag and venues: # sometimes there are no results, so random.sample would fail
+ venue_offered[domain] = random.sample(venues, 1)
+ else: # not limited so we can provide one
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ for requestable in requestables:
+ # check if a reference could be issued
+ if requestable == 'reference':
+ if domain + '_reference' in sent_t:
+ if 'restaurant_reference' in sent_t:
+ if dialog['log'][t * 2]['db_pointer'][-5] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'hotel_reference' in sent_t:
+ if dialog['log'][t * 2]['db_pointer'][-3] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'train_reference' in sent_t:
+ if dialog['log'][t * 2]['db_pointer'][-1] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ else:
+ provided_requestables[domain].append('reference')
+ else:
+ if domain + '_' + requestable in sent_t:
+ provided_requestables[domain].append(requestable)
+
+ # offer was made?
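+ # Three fallbacks below mark a venue as offered without a DB lookup:
+ # the goal itself names the venue, the domain has no venue database
+ # (taxi/police/hospital), or a train was offered while its id was never
+ # requested, in which case the name placeholder stands in for the offer.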
+ for domain in domains_in_goal:
+ # if the name was provided for the user, the match is credited automatically
+ if 'info' in dialog['goal'][domain]:
+ if 'name' in dialog['goal'][domain]['info']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # special domains - entity does not need to be provided
+ if domain in ['taxi', 'police', 'hospital']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # if the id was not requested, accept the name placeholder - but never override a train that was actually found, so we can check that the right one was booked
+ if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # HARD (0-1) EVAL
+ stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+ 'taxi': [0, 0, 0],
+ 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+ match, success = 0, 0
+ # MATCH
+ for domain in goal.keys():
+ match_stat = 0
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True)
+ if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]:
+ match += 1
+ match_stat = 1
+ elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues:
+ match += 1
+ match_stat = 1
+ else:
+ if domain + '_name' in venue_offered[domain]:
+ match += 1
+ match_stat = 1
+
+ stats[domain][0] = match_stat
+ stats[domain][2] = 1
+
+ if not soft:
+ if match == len(goal.keys()):
+ match = 1
+ else:
+ match = 0
+ else:
+ match = float(match) / len(goal.keys())
+
+ # SUCCESS
+ if match == 1: # this, or match > 0?
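+ # note that success is only computed once every goal domain has matched,
+ # so under hard evaluation Success can never exceed Inform (Match)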
+ for domain in domains_in_goal: + domain_success = 0 + success_stat = 0 + if len(real_requestables[domain]) == 0: + # check that + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if success >= len(real_requestables): + success = 1 + else: + if soft: + success = float(success) / len(real_requestables) + else: + success = 0 + + return goal, success, match, real_requestables, stats + + def _evaluateRolloutDialogue(self, dialog): + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + # get the list of domains in the goal + domains_in_goal = [] + goal = {} + for domain in domains: + if dialog['goal'][domain]: + goal = self._parseGoal(goal, dialog, domain) + domains_in_goal.append(domain) + + # compute corpus success + real_requestables = {} + provided_requestables = {} + venue_offered = {} + for domain in goal.keys(): + provided_requestables[domain] = [] + venue_offered[domain] = [] + real_requestables[domain] = goal[domain]['requestable'] + + # iterate each turn + m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1] + for t in range(len(m_targetutt)): + for domain in domains_in_goal: + sent_t = m_targetutt[t] + # for computing match - where there are limited entities + if domain + '_name' in sent_t or domain+'_id' in sent_t: + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + venue_offered[domain] = '[' + domain + '_name]' + """ + venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1]) + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = random.sample(venues, 1) + else: + flag = False + for ven in venues: + if venue_offered[domain][0] == ven: + flag = True + break + if not flag and venues: # sometimes there are no results so sample won't work + # print venues + venue_offered[domain] = random.sample(venues, 1) + """ + else: # not limited so we can provide one + venue_offered[domain] = '[' + domain + '_name]' + + for requestable in requestables: + # check if reference could be issued + if requestable == 'reference': + if domain + '_reference' in sent_t: + if 'restaurant_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-3] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + # return goal, 0, match, real_requestables + elif 'train_reference' in sent_t: + if True or dialog['log'][t * 2]['db_pointer'][-1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if domain + '_' + requestable in sent_t: + provided_requestables[domain].append(requestable) + + # offer was made? 
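+ # rollout turns carry no ground-truth db_pointer, which is why the reference
+ # checks above are short-circuited with 'if True or ...' and venue matching
+ # falls back to the delexicalized name placeholder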
+ for domain in domains_in_goal:
+ # if the name was provided for the user, the match is credited automatically
+ if 'info' in dialog['goal'][domain]:
+ if 'name' in dialog['goal'][domain]['info']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # special domains - entity does not need to be provided
+ if domain in ['taxi', 'police', 'hospital']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # if the id was not requested, accept the name placeholder - but never override a train that was actually found, so we can check that the right one was booked
+ if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # REWARD CALCULATION
+ stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+ 'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+ match, success = 0.0, 0.0
+ # MATCH
+ for domain in goal.keys():
+ match_stat = 0
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True)
+ if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]:
+ match += 1
+ match_stat = 1
+ elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues:
+ match += 1
+ match_stat = 1
+ else:
+ if domain + '_name' in venue_offered[domain]:
+ match += 1
+ match_stat = 1
+
+ stats[domain][0] = match_stat
+ stats[domain][2] = 1
+
+ match = min(1.0, float(match) / len(goal.keys()))
+
+ # SUCCESS
+ if match:
+ for domain in domains_in_goal:
+ domain_success = 0
+ success_stat = 0
+ if len(real_requestables[domain]) == 0:
+ # domains without requestables are credited automatically
+ success += 1
+ success_stat = 1
+ stats[domain][1] = success_stat
+ continue
+ # if the values in the sentences are a superset of the requestables
+ for request in set(provided_requestables[domain]):
+ if request in real_requestables[domain]:
+ domain_success += 1
+
+ if domain_success >= len(real_requestables[domain]):
+ success += 1
+ success_stat = 1
+
+ stats[domain][1] = success_stat
+
+ # final eval
+ success = min(1.0, float(success) / len(real_requestables))
+
+ return success, match, stats
+
+ def _parse_entities(self, tokens):
+ entities = []
+ for t in tokens:
+ if '[' in t and ']' in t:
+ entities.append(t)
+ return entities
+
+ def evaluateModel(self, dialogues, mode='valid', new_version=False, verbose=True):
+ """Gathers statistics for the whole set."""
+ delex_dialogues = self.delex_dialogues
+ successes, matches = 0, 0
+ corpus_successes, corpus_matches = 0, 0
+ total = 0
+ data_succ, data_match = [], []
+ corpus_succ, corpus_match = [], []
+
+ gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+ 'taxi': [0, 0, 0],
+ 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+ sng_gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+ 'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+ for filename, dial in dialogues.items():
+ if mode == 'rollout':
+ success, match, stats = self._evaluateRolloutDialogue(dial)
+ else:
+ # data is the ground truth, dial is generated
+ data = delex_dialogues[filename]
+ goal, success, match, requestables, _ = self._evaluateRealDialogue(data, filename) # only goal and requestables are kept
+ corpus_successes += success
+ corpus_succ.append(success)
+ corpus_matches += match
+ corpus_match.append(match)
+ if
new_version: + success, match, stats = self._evaluateGeneratedDialogue_new(dial, goal, data, requestables, + soft_acc=mode =='offline_rl') + else: + success, match, stats = self._evaluateGeneratedDialogue(dial, goal, data, requestables, + soft_acc=mode =='offline_rl') + + # if success == 0: + # pdb.set_trace() + + successes += success + data_succ.append(success) + matches += match + data_match.append(match) + total += 1 + + for domain in gen_stats.keys(): + gen_stats[domain][0] += stats[domain][0] + gen_stats[domain][1] += stats[domain][1] + gen_stats[domain][2] += stats[domain][2] + + if 'SNG' in filename: + for domain in gen_stats.keys(): + sng_gen_stats[domain][0] += stats[domain][0] + sng_gen_stats[domain][1] += stats[domain][1] + sng_gen_stats[domain][2] += stats[domain][2] + + report = "" + report += '{} Corpus Matches : {:2.2f}%, Groundtruth {} Matches : {:2.2f}%'.format(mode, (matches / float(total) * 100), mode, (corpus_matches / float(total) * 100)) + "\n" + report += '{} Corpus Success : {:2.2f}%, Groundtruth {} Success : {:2.2f}%'.format(mode, (successes / float(total) * 100), mode, (corpus_successes / float(total) * 100)) + "\n" + report += 'Total number of dialogues: %s, new version=%s ' % (total, new_version) + # compute 95% confidence interval + # for name, data in zip(["match", "success", "corpus_match", "corpus_success"], [data_match, data_succ, corpus_match, corpus_succ]): + # mean = np.mean(data) + # std = np.std(data) + # lower, upper = st.t.interval(0.95, len(data)-1, loc=np.mean(data), scale=st.sem(data)) + # print(f"{name}: {mean * 100} +- {(mean-lower) * 100}") + + + if verbose: + self.logger.info(report) + return report, successes/float(total), matches/float(total) + + def get_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + hyp_tokens = [BOS] + tokenize(hyp.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + hyps.append(hyp_tokens) + + ref_entities = self._parse_entities(ref_tokens) + hyp_entities = self._parse_entities(hyp_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, hyp_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + bleu = BLEUScorer().score(hyps, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + report = "\nBLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + return report, bleu, prec, rec, f1 + + def get_groundtruth_report(self): + tokenize = lambda x: x.split() + print('Generate report for {} samples'.format(len(self.hyps))) + refs, hyps = [], [] + tp, fp, fn = 0, 0, 0 + for label, hyp in zip(self.labels, self.hyps): + ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS] + refs.append([ref_tokens]) + + ref_entities = self._parse_entities(ref_tokens) + tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, ref_entities) + tp += tpp + fp += fpp + fn += fnn + + # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1) + # bleu = BLEUScorer().score(refs, refs) + prec, rec, f1 = self._get_prec_recall(tp, fp, fn) + # report = "\nGroundtruth BLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1) + report = "\nGroundtruth\nEntity precision {:.4f} 
recall {:.4f} and f1 {:.4f}\n".format(prec, rec, f1)
+ return report, 0, prec, rec, f1
+
+class MultiWozEvaluatorwPenalty(MultiWozEvaluator):
+ CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '')
+ logger = logging.getLogger()
+ def __init__(self, data_name, config):
+ super(MultiWozEvaluatorwPenalty, self).__init__(data_name, config)
+
+ def _parseGoal(self, goal, d, domain):
+ """Parses the user goal into dictionary format."""
+ goal[domain] = {'informable': [], 'requestable': [], 'booking': [], 'failed': {}}
+ if 'info' in d['goal'][domain]:
+ if domain == 'train':
+ # we consider dialogues only where the train had to be booked!
+ if 'book' in d['goal'][domain]:
+ goal[domain]['requestable'].append('reference')
+ if 'reqt' in d['goal'][domain]:
+ if 'trainID' in d['goal'][domain]['reqt']:
+ goal[domain]['requestable'].append('id')
+ goal[domain]['failed'] = d['goal'][domain]['fail_info']
+ else:
+ if 'reqt' in d['goal'][domain]:
+ for s in d['goal'][domain]['reqt']: # additional requests
+ if s in ['phone', 'address', 'postcode', 'reference', 'id']:
+ # ones that can be easily delexicalized
+ goal[domain]['requestable'].append(s)
+ goal[domain]['failed'] = d['goal'][domain]['fail_info']
+ if 'book' in d['goal'][domain]:
+ goal[domain]['requestable'].append("reference")
+
+ goal[domain]["informable"] = d['goal'][domain]['info']
+ if 'book' in d['goal'][domain]:
+ goal[domain]["booking"] = d['goal'][domain]['book']
+
+ return goal
+
+ def _evaluateGeneratedDialogue(self, dialog, goal, realDialogue, real_requestables, soft_acc=False):
+ """Evaluates the dialogue created by the model.
+ First we load the user goal of the dialogue, then for each turn
+ generated by the system we look for key words.
+ For the Inform rate we check whether the right entity was proposed.
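+ In this penalty variant we also track, per DB-backed domain, whether entity
+ info was offered at a turn where the database returned no matching venue
+ (only for goals that contain fail_info); that count is later subtracted from Success.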
+ For the Success rate we check the requestable slots."""
+ # for computing corpus success
+ requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+ # CHECK IF MATCH HAPPENED
+ provided_requestables = {}
+ venue_offered = {}
+ fail_info_penalty = {}
+ domains_in_goal = []
+
+ for domain in goal.keys():
+ venue_offered[domain] = []
+ provided_requestables[domain] = []
+ fail_info_penalty[domain] = 0
+ domains_in_goal.append(domain)
+
+ for t, sent_t in enumerate(dialog): # go turn by turn
+ for domain in goal.keys(): # for each domain in the goal
+ # for computing success
+ if '[' + domain + '_name]' in sent_t or '_id' in sent_t: # the system generated a delexicalized [domain_name] or [domain_id] placeholder
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+ # in this case, look for the actually offered venues based on the true belief state
+ venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1])
+
+ # if venue has changed
+ if len(venue_offered[domain]) == 0 and venues:
+ venue_offered[domain] = random.sample(venues, 1)
+ else:
+ flag = False
+ for ven in venues:
+ if venue_offered[domain][0] == ven:
+ flag = True
+ break
+ if not flag and venues: # sometimes there are no results, so random.sample would fail
+ venue_offered[domain] = random.sample(venues, 1)
+ if goal[domain]['failed']:
+ # penalize offering entity info while the current DB query returns no venue
+ if any([req in sent_t for req in ['[' + domain + "_" + req + ']' for req in requestables + ['name']]]) and len(venues) == 0:
+ fail_info_penalty[domain] = 1
+
+ else: # not limited so we can provide one
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # ATTENTION: assumption here - the system does not provide the same requestable (phone, address, etc.) twice
+ for requestable in requestables:
+ if requestable == 'reference':
+ if domain + '_reference' in sent_t:
+ if 'restaurant_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-5] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'hotel_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-3] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ elif 'train_reference' in sent_t:
+ if realDialogue['log'][t * 2]['db_pointer'][-1] == 1: # if the pointer was allowing for that
+ provided_requestables[domain].append('reference')
+
+ else:
+ provided_requestables[domain].append('reference')
+ else:
+ if '[' + domain + '_' + requestable + ']' in sent_t:
+ provided_requestables[domain].append(requestable)
+
+ # if name was given in the task
+ for domain in goal.keys():
+ # if the name was provided for the user, the match is credited automatically;
+ # this assumption doesn't always hold - maybe the match should be ignored when the name comes from the user?
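+ # e.g. a goal whose restaurant 'info' already names a specific venue credits
+ # the match below even if the system never mentioned that venue itself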
+ if 'info' in realDialogue['goal'][domain]:
+ if 'name' in realDialogue['goal'][domain]['info']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ # special domains - entity does not need to be provided
+ if domain in ['taxi', 'police', 'hospital']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ if domain == 'train':
+ if not venue_offered[domain]:
+ if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']:
+ venue_offered[domain] = '[' + domain + '_name]'
+
+ """
+ Given all inform and requestable slots
+ we go through each domain from the user goal
+ and check whether the right entity was provided and
+ all requestable slots were given to the user.
+ The dialogue is successful if that's the case for all domains.
+ """
+ # HARD EVAL
+ # DB-backed domains carry a fourth slot that records the fail_info penalty
+ stats = {'restaurant': [0, 0, 0, 0], 'hotel': [0, 0, 0, 0], 'attraction': [0, 0, 0, 0], 'train': [0, 0, 0, 0],
+ 'taxi': [0, 0, 0],
+ 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+ match = 0
+ success = 0
+ # MATCH (between the venue offered in the generated dialogue and a venue actually fitting the criteria)
+ for domain in goal.keys():
+ match_stat = 0
+ if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+ goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+ # if venue_offered is a string, i.e. a '[domain_name]' placeholder
+ if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: # yields false positives - it does not compare the offer against the real dialogue
+ match += 1
+ match_stat = 1
+ # if venue_offered is still a list of venues
+ elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: # actually checks the offered venue
+ match += 1
+ match_stat = 1
+ stats[domain][3] = fail_info_penalty[domain]
+ # other domains
+ else:
+ if domain + '_name]' in venue_offered[domain]: # yields false positives, in terms of both occurrence and correctness
+ match += 1
+ match_stat = 1
+
+ stats[domain][0] = match_stat
+ stats[domain][2] = 1
+
+ if soft_acc:
+ match = float(match)/len(goal.keys())
+ else:
+ # only count a full match if every domain has a match
+ if match == len(goal.keys()):
+ match = 1.0
+ else:
+ match = 0.0
+
+ # SUCCESS (whether the requestable info in realDialogue is generated by the system)
+ # if there is no match, success is assumed to be 0
+ if match == 1.0:
+ for domain in domains_in_goal:
+ success_stat = 0
+ domain_success = 0
+ if len(real_requestables[domain]) == 0: # if there is no requestable, assume the domain is successful; this is not quite correct because it does not count false positives
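+ # as in the base class, domains without requestables are credited immediately;
+ # the fail_info penalty is subtracted after the final tally below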
+ success += 1
+ success_stat = 1
+ stats[domain][1] = success_stat
+ continue
+ # if the values in the sentences are a superset of the requestables
+ for request in set(provided_requestables[domain]):
+ if request in real_requestables[domain]:
+ domain_success += 1
+
+ if domain_success >= len(real_requestables[domain]):
+ success += 1
+ success_stat = 1
+
+ stats[domain][1] = success_stat
+ # final eval
+ if soft_acc:
+ success = float(success)/len(real_requestables)
+ else:
+ if success >= len(real_requestables):
+ success = 1
+ else:
+ success = 0
+
+ success -= sum([stats[domain][3] for domain in ['restaurant', 'hotel', 'attraction', 'train']])
+
+ # print(requests, 'DIFF', requests_real, 'SUCC', success)
+ return success, match, stats
diff --git a/latent_dialog/main.py b/latent_dialog/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7debd8f9ff08db390eca0df618c696c2409fbf77
--- /dev/null
+++ b/latent_dialog/main.py
@@ -0,0 +1,1485 @@
+import os
+import sys
+import numpy as np
+import torch as th
+from torch.utils.tensorboard import SummaryWriter
+from torch import nn
+from tqdm import trange
+from collections import defaultdict, Counter
+from latent_dialog.enc2dec.base_modules import summary
+from latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN
+from datetime import datetime
+from latent_dialog.utils import get_detokenize, LONG, FLOAT
+from latent_dialog.corpora import EOS, PAD
+from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE
+from latent_dialog import evaluators
+from latent_dialog.record import record, record_task, UniquenessSentMetric, UniquenessWordMetric
+from sklearn.metrics import mean_squared_error as mse
+import logging
+import pdb
+import json
+from scipy.stats import entropy
+import time
+
+logger = logging.getLogger()
+
+class LossManager(object):
+ def __init__(self):
+ self.losses = defaultdict(list)
+ self.backward_losses = []
+
+ def add_loss(self, loss):
+ for key, val in loss.items():
+ if val is not None and type(val) is not bool:
+ self.losses[key].append(val.item())
+
+ def pprint(self, name, window=None, prefix=None):
+ str_losses = []
+ for key, loss in self.losses.items():
+ if loss is None:
+ continue
+ aver_loss = np.average(loss) if window is None else np.average(loss[-window:])
+ if 'nll' in key:
+ str_losses.append('{} PPL {:.3f}'.format(key, np.exp(aver_loss)))
+ else:
+ str_losses.append('{} {:.3f}'.format(key, aver_loss))
+
+ if prefix:
+ return '{}: {} {}'.format(prefix, name, ' '.join(str_losses))
+ else:
+ return '{} {}'.format(name, ' '.join(str_losses))
+
+ def clear(self):
+ self.losses = defaultdict(list)
+ self.backward_losses = []
+
+ def add_backward_loss(self, loss):
+ self.backward_losses.append(loss.item())
+
+ def avg_loss(self):
+ return np.mean(self.backward_losses)
+
+class OfflineTaskReinforce(object):
+ def __init__(self, agent, corpus, sv_config, sys_model, rl_config, generate_func):
+ self.agent = agent
+ self.corpus = corpus
+ self.sv_config = sv_config
+ self.sys_model = sys_model
+ self.rl_config = rl_config
+ # training func for supervised learning
+ self.train_func = task_train_single_batch
+ self.record_func = record_task
+ self.validate_func = validate
+
+ # prepare data loaders
+ train_dial, val_dial, test_dial = self.corpus.get_corpus()
+ self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+ self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+ self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config)
+ self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config)
+
+ # create log files
+ if self.rl_config.record_freq > 0:
+ self.learning_exp_file = open(os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w')
+ self.ppl_val_file = open(os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w')
+ self.rl_val_file = open(os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w')
+ self.ppl_test_file = open(os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w')
+ self.rl_test_file = open(os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w')
+ # evaluation
+ if "fail_info_penalty" in rl_config and rl_config.fail_info_penalty:
+ self.evaluator = evaluators.MultiWozEvaluatorwPenalty('SYS_WOZ', rl_config)
+ else:
+ self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', rl_config)
+ self.generate_func = generate_func
+
+ def run(self):
+ n = 0
+ best_valid_loss = np.inf
+ best_rewards = -1 * np.inf
+
+ # BEFORE RUN, RECORD INITIAL PERFORMANCE
+ test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True)
+ t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config,
+ self.evaluator, None, verbose=False) #, temperature=self.rl_config.temperature)
+
+ self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1))
+ self.ppl_test_file.flush()
+ self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match))
+ self.rl_test_file.flush()
+
+ self.sys_model.train()
+ try:
+ for epoch_id in range(self.rl_config.nepoch):
+ self.train_data.epoch_init(self.sv_config, shuffle=True, verbose=epoch_id == 0, fix_batch=True)
+ while True:
+ if n % self.rl_config.episode_repeat == 0:
+ batch = self.train_data.next_batch()
+
+ if batch is None:
+ break
+
+ n += 1
+ if n % 50 == 0:
+ print("Reinforcement Learning {}/{} episode".format(n, self.train_data.num_batch*self.rl_config.nepoch))
+ self.learning_exp_file.write(
+ '{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-50:])))
+ self.learning_exp_file.flush()
+
+ # reinforcement learning
+ # make sure all turns belong to the same dialogue
+ assert len(set(batch['keys'])) == 1
+ task_report, success, match = self.agent.run(batch, self.evaluator, max_words=self.rl_config.max_words, temp=self.rl_config.temperature)
+ reward = float(success) # + float(match)
+ stats = {'Match': match, 'Success': success}
+ self.agent.update(reward, stats)
+
+ # supervised learning
+ if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0:
+ self.train_func(self.sys_model, self.sl_train_data, self.sv_config)
+
+ # record model performance in terms of several evaluation metrics
+ if self.rl_config.record_freq > 0 and n % self.rl_config.record_freq == 0:
+ self.agent.print_dialog(self.agent.dlg_history, reward, stats)
+ print('-'*15, 'Recording start', '-'*15)
+ # save train reward
+ self.learning_exp_file.write('{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-self.rl_config.record_freq:])))
+ self.learning_exp_file.flush()
+
+ # PPL & reward on validation
+ valid_loss = self.validate_func(self.sys_model, self.val_data, self.sv_config, use_py=True)
+ v_success, v_match, v_bleu, v_f1 = self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, None, verbose=False) #, temperature=self.rl_config.temperature)
+ self.ppl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(valid_loss), v_bleu, v_f1))
+
self.ppl_val_file.flush() + self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, (v_success + v_match), v_success, v_match)) + self.rl_val_file.flush() + + test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True) + t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, None, verbose=False)#, temperature=self.rl_config.temperature) + self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1)) + self.ppl_test_file.flush() + self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match)) + self.rl_test_file.flush() + + # save model is needed + if v_success+v_match > best_rewards: + print("Model saved with success {} match {}".format(v_success, v_match)) + th.save(self.sys_model.state_dict(), self.rl_config.reward_best_model_path) + best_rewards = v_success+v_match + + + self.sys_model.train() + print('-'*15, 'Recording end', '-'*15) + except KeyboardInterrupt: + print("RL training stopped from keyboard") + + print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path)) + self.sv_config.batch_size = 32 + self.sys_model.load_state_dict(th.load(self.rl_config.reward_best_model_path)) + + validate(self.sys_model, self.val_data, self.sv_config, use_py=True) + validate(self.sys_model, self.test_data, self.sv_config, use_py=True) + + with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f: + self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f)#, temperature=self.rl_config.temperature) + + with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f: + self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f)#, temperature=self.rl_config.temperature) + +class OfflinePLAS(object): + def __init__(self, agent, corpus, sv_config, rl_config, generate_func, name="", vae_gen=None): + self.tb = SummaryWriter(comment=name) + logger.info(f"Run will be logged in tensorboard as {name}") + + self.agent = agent + self.corpus = corpus + self.sv_config = sv_config + # self.sys_model = agent.model + self.rl_config = rl_config + # training func for supervised learning + self.train_func = task_train_single_batch + # self.record_func = record_task + self.validate_func = validate_offlinerl #policy evaluation + self.all_rewards = [] + self.is_gauss = agent.is_gauss + self.train_vae = rl_config.train_vae + + + # prepare data loader + train_dial, val_dial, test_dial = self.corpus.get_corpus() + self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config) + self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config) + self.ae_val_data = BeliefDbDataLoadersAE('Val', val_dial, self.sv_config) + + # create log files + if self.rl_config.record_freq > 0: + self.learning_exp_file = open(os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w') + self.ppl_val_file = open(os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w') + self.rl_val_file = open(os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w') + self.ppl_test_file = open(os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w') + self.rl_test_file = open(os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w') + # evaluation + self.evaluator = 
evaluators.MultiWozEvaluator('SYS_WOZ', rl_config) + self.generate_func = generate_func + self.vae_generate_func = vae_gen + + if "validate_with_critic" in rl_config: + self.validate_with_critic = rl_config.validate_with_critic + else: + self.validate_with_critic = False + + def extract(self, data_feed, replay_buffer): + """ + extract experiences from corpus + """ + data_feed.epoch_init(self.sv_config, shuffle=False, verbose=False, fix_batch=True) + # self.sys_model.eval() + self.keys = [] + delex_dialogues = self.evaluator.delex_dialogues + belief_state = th.load("/gpfs/project/lubis/LAVA_code/LAVA_dev/data/MultiWOZ_2.1_SetSUMBT_EnD2/setsumbt_belief_states.bin") + + corpus_successes = 0 + corpus_matches = 0 + total = 0 + + all_actions = [] + pred_strs = [] + true_strs = [] + + n = 0 + + while True: + batch = data_feed.next_batch() + + n += 1 + + if n % 500 == 0: + print("Processing batch {}/{}".format(n, data_feed.num_batch)) + + if batch is None: + break + + batch_size = len(batch['keys']) + key = batch['keys'][0] + self.keys.append(key) + assert len(set(batch['keys'])) == 1 + + rewards = np.zeros(batch_size) + + states = [] + actions = [] + next_states = [] + next_actions = [] + dones = [] + + + # actions, pred_str, true_str= self._get_actions(batch) + # actions, _ = self._sample_actions(batch) + # get states, next state, action, and done from data + for turn_id, turn in enumerate(batch['contexts']): + # belief_state[key][slot_name][turn_id * 2] (the .bin file contains double the amount of turn, but turns 0 and 1 are identical, so are 2 and 3, 4 and 5, etc.) + state = {} + state['contexts'] = turn + state['bs'] = batch['bs'][turn_id] + state['db'] = batch['db'][turn_id] + state['context_lens'] = batch['context_lens'][turn_id] + state['keys'] = batch['keys'][turn_id] + state['goals'] = np.concatenate([batch['goals_list'][d][turn_id] for d in range(7)]) + + action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id].tolist(), do_pad=True)) + + try: + next_state = {} + next_state['contexts'] = batch['contexts'][turn_id + 1] + next_state['bs'] = batch['bs'][turn_id + 1] + next_state['db'] = batch['db'][turn_id + 1] + next_state['context_lens'] = batch['context_lens'][turn_id + 1] + next_state['keys'] = batch['keys'][turn_id + 1] + next_state['goals'] = np.concatenate([batch['goals_list'][d][turn_id + 1] for d in range(7)]) + + next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id + 1].tolist(), do_pad=True)) + # next_state['outputs'] = batch['outputs'][turn_id + 1] + done = 0 + except: + next_state = {} + next_state['contexts'] = np.zeros(batch['contexts'][turn_id].shape) + next_state['bs'] = [0] * self.corpus.bs_size + next_state['db'] = [0] * self.corpus.db_size + next_state['context_lens'] = batch['context_lens'][turn_id] + next_state['keys'] = batch['keys'][turn_id] + next_state['goals'] = [0] * self.corpus.goal_size + + next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, [0], do_pad=True)) + # next_state['outputs'] = np.zeros(self.sv_config.max_utt_len) + done = 1 + + actions.append(action) + states.append(state) + next_states.append(next_state) + next_actions.append(next_action) + dones.append(done) + + # compute reward of dialogue from corpus here + data = delex_dialogues[key] + goal, success, match, requestables, _ = self.evaluator._evaluateRealDialogue(data, key, soft=False) + corpus_successes += success + corpus_matches += match + total += 1 + + rewards[-1] = float(success) + g = 
self.agent.cvae.np2var(np.array([rewards[-1]]), FLOAT).view(1, 1) + + # self.all_rewards.append(reward) + # r = (rewards - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards)) + + # compute accumulated discounted reward + returns = [] + for _ in rewards: + returns.insert(0, float(g)) + g = g * self.rl_config.gamma + + + # add tuples to replay buffer + # assert(len(states) == len(rewards) and len(states) == len(actions) and len(states) == len(dones) and len(states)== len(next_states), len(next_actions)==len(next_states)) + replay_buffer.add(states, actions, rewards, next_states, next_actions, dones, returns) + # for i in range(len(states)): + # replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], next_action[i], dones[i], returns[i]) + + def _infoGain(self, P, Q): + M = [p + q for p, q in zip(P, Q)] + return 0.5 * (entropy(P, M, base=2) + entropy(Q, M, base=2)) + + def run(self): + # get exps, step, learn, act, validate + n = 0 + best_valid_loss = np.inf + best_rewards = -1 * np.inf + vae_epoch_id = 0 + + # BEFORE RUN, RECORD INITIAL PERFORMANCE + # print("Recoding initial performance") + self.agent.actor.eval() + self.agent.critic.eval() + # test_rewards, test_match = self.validate_func(self.agent, self.evaluator, self.test_data, self.rl_config, mode="test", use_py=True) + # t_success, t_match, t_bleu, t_f1 = self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, None, verbose=False) + + # self.ppl_test_file.write('{}\t{}\t{}\t{}\t{}\n'.format(n, np.mean(test_rewards), np.mean(test_match), t_bleu, t_f1)) + # self.ppl_test_file.flush() + # self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match)) + # self.rl_test_file.flush() + + try: + for epoch_id in trange(self.rl_config.nepoch, desc="PLAS_epoch"): + for n in trange(self.rl_config.nepisode, desc="PLAS_batch"): + # VAE training + if self.train_vae and n % self.rl_config.train_vae_freq == 0: + if vae_epoch_id == 0: + train_vae_nepisode = self.rl_config.train_vae_nepisode_init + offset_idx = 0 + else: + train_vae_nepisode = self.rl_config.train_vae_nepisode + offset_idx = self.rl_config.train_vae_nepisode_init + self.rl_config.train_vae_nepisode * (vae_epoch_id - 1) + + logger.info("=== VAE Training ==") + + logger.info("validating starting performance") + vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True) + # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False) + self.tb.add_scalar("train_vae_success", vae_success, offset_idx + 1) + self.tb.add_scalar("train_vae__match", vae_match, offset_idx + 1) + self.tb.add_scalar("train_vae_bleu", vae_bleu, offset_idx + 1) + self.tb.add_scalar("train_vae_f1", vae_f1, offset_idx + 1) + + self.agent.actor.eval() + self.agent.critic.eval() + self.agent.cvae.train() + for i in trange(train_vae_nepisode, desc="VAE training"): + vae_loss = self.agent.train_vae_model(i) + if i % 2500 == 0: + # logger.info(f"vae training batch {i}/{train_vae_nepisode}, vae_loss:{vae_loss}") + + # logger.info("validating ending performance") + vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True) + # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, 
self.evaluator, None, verbose=False) + self.tb.add_scalar("train_vae_success", vae_success, i + offset_idx) + self.tb.add_scalar("train_vae__match", vae_match, i + offset_idx) + self.tb.add_scalar("train_vae_bleu", vae_bleu, i + offset_idx) + self.tb.add_scalar("train_vae_f1", vae_f1, i + offset_idx) + + self.tb.add_scalar("vae_loss", vae_loss, (i+1) + (self.rl_config.nepisode/self.rl_config.train_vae_freq) * (epoch_id)) + + vae_epoch_id += 1 + + # reinforcement learning + self.agent.actor.train() + self.agent.critic.train() + self.agent.cvae.eval() + # tic = time.perf_counter() + plas_report, losses = self.agent.train(verbose=n%10==0, + max_words=self.rl_config.max_words, + temp=self.rl_config.temperature, + debug=n%500==0, + n=(n+1) + (self.rl_config.nepisode) * (epoch_id)) + # toc = time.perf_counter() + # print(f"One batch of forward pass plas training in {toc - tic:0.4f} seconds") + + + if n % 20 == 0: + for k, v in losses.items(): + self.tb.add_scalar(k, v, (self.rl_config.nepisode) * (epoch_id) + (n+1)) + + if n % 100 == 0 and n > 0: + logger.info("PLAS {}/{} episode, {}/{} epoch".format(n, self.rl_config.nepisode, epoch_id, self.rl_config.nepoch)) + self.learning_exp_file.write(plas_report + "\n") + self.learning_exp_file.flush() + + # supervised learning + if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0: + self.train_func(self.agent.actor, self.sl_train_data, self.sv_config) + + + if n > 1 and n % self.rl_config.record_freq == 0: + self.agent.actor.eval() + self.agent.critic.eval() + + print('-'*15, 'Recording start', '-'*15) + # save train reward + + # PPL & reward on validation + + print("Running validation on train set") + tr_success, tr_match, tr_bleu, tr_f1, tr_Q = self.generate_func(self.agent.actor, self.agent.cvae, self.train_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic) + self.tb.add_scalar("train_Q", tr_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + + + print("Running validation on valid set") + # valid_rewards, valid_match = self.validate_func(self.agent, self.evaluator, self.val_data, self.rl_config, use_py=True) + # tic = time.perf_counter() + v_success, v_match, v_bleu, v_f1, v_Q = self.generate_func(self.agent.actor, self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic) + # toc = time.perf_counter() + # print(f"One batch of validation pass of plas training in {toc - tic:0.4f} seconds") + val_critic_error = v_Q - v_success + # self.ppl_val_file.write('{}\t{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, np.mean(valid_rewards), np.mean(valid_match), v_bleu, v_f1)) + # self.ppl_val_file.flush() + # print(v_success, v_match, v_bleu, v_f1, v_Q, val_critic_error) + self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, (v_success + v_match), v_success, v_match)) + self.rl_val_file.flush() + self.tb.add_scalar("validation_success", v_success, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_match", v_match, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_bleu", v_bleu, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_f1", v_f1, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_critic_mse", val_critic_error, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + + print("Running 
validation on test set") + # test_rewards, test_match = self.validate_func(self.agent, self.evaluator, self.test_data, self.rl_config, mode="test", use_py=True) + t_success, t_match, t_bleu, t_f1, t_Q = self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic) + test_critic_error = t_Q - t_success + # self.ppl_test_file.write('{}\t{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, np.mean(test_rewards), np.mean(test_match), t_bleu, t_f1)) + # self.ppl_test_file.flush() + # print(t_success, t_match, t_bleu, t_f1, t_Q, test_critic_error) + self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, (t_success + t_match), t_success, t_match)) + self.rl_test_file.flush() + self.tb.add_scalar("test_success", t_success, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_match", t_match, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_bleu", t_bleu, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_f1", t_f1, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_critic_mse", test_critic_error, (n+1) + (self.rl_config.nepisode) * (epoch_id)) + + + # save model is needed + if self.validate_with_critic: + if v_Q > best_rewards: + print("Model saved with estimated value of {}".format(np.mean(v_Q))) + print("Corpus success {} and match {}".format(np.mean(v_success), np.mean(v_match))) + th.save(self.agent.cvae.state_dict(), self.rl_config.reward_best_model_path) + th.save(self.agent.actor.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".actor")) + th.save(self.agent.critic.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".critic")) + best_rewards = v_Q + + else: + if v_success+v_match > best_rewards: + # if np.mean(v_success) > best_rewards: + print("Model saved with success {} and match {}".format(np.mean(v_success), np.mean(v_match))) + th.save(self.agent.cvae.state_dict(), self.rl_config.reward_best_model_path) + th.save(self.agent.actor.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".actor")) + th.save(self.agent.critic.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".critic")) + # best_rewards = np.mean(v_success) + best_rewards = v_success + v_match + + # for name, weight in self.agent.named_parameters(): + # tb.add_histogram(name, weight, n + 1 ) + # tb.add_histogram(f'{name}.grad', weight.grad, n + 1) + + + # self.sys_model.train() + print('-'*15, 'Recording end', '-'*15) + + except KeyboardInterrupt: + print("RL training stopped from keyboard") + self.tb.close() + + print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path)) + self.sv_config.batch_size = 32 + self.agent.cvae.load_state_dict(th.load(self.rl_config.reward_best_model_path)) + self.agent.actor.load_state_dict(th.load(self.rl_config.reward_best_model_path.replace(".model", ".actor"))) + self.agent.critic.load_state_dict(th.load(self.rl_config.reward_best_model_path.replace(".model", ".critic"))) + + # validate(self.sys_model, self.val_data, self.sv_config, use_py=True) + # validate(self.sys_model, self.test_data, self.sv_config, use_py=True) + + with open(os.path.join(self.rl_config.record_path, 'train_file.txt'), 'w') as f: + self.generate_func(self.agent.actor, self.agent.cvae, self.train_data, self.sv_config, 
self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic) + + with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f: + self.generate_func(self.agent.actor, self.agent.cvae, self.val_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic) + + with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f: + self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic) + +class OfflineCritic(object): + def __init__(self, agent, corpus, sv_config, critic_config, generate_func, name="", vae_gen=None, forward_only=False): + self.tb = SummaryWriter(comment=name) + logger.info(f"Run will be logged in tensorboard as {name}") + + self.agent = agent + self.corpus = corpus + self.sv_config = sv_config + # self.sys_model = agent.model + self.critic_config = critic_config + # training func for supervised learning + self.train_func = task_train_single_batch + # self.record_func = record_task + self.validate_func = validate_offlinerl #policy evaluation + self.all_rewards = [] + self.is_gauss = agent.is_gauss + self.train_vae = critic_config.train_vae + + # prepare data loader + train_dial, val_dial, test_dial = self.corpus.get_corpus() + self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config) + self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config) + self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config) + self.ae_val_data = BeliefDbDataLoadersAE('Val', val_dial, self.sv_config) + + # create log files + log_mode = "r" if forward_only else "w" + if self.critic_config.record_freq > 0: + # self.learning_exp_file = open(os.path.join(self.critic_config.record_path, 'offline-learning.tsv'), 'w') + # self.ppl_val_file = open(os.path.join(self.critic_config.record_path, 'val-ppl.tsv'), 'w') + self.rl_val_file = open(os.path.join(self.critic_config.record_path, 'val-critic.tsv'), log_mode) + # self.ppl_test_file = open(os.path.join(self.critic_config.record_path, 'test-ppl.tsv'), 'w') + self.rl_test_file = open(os.path.join(self.critic_config.record_path, 'test-critic.tsv'), log_mode) + # evaluation + self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', critic_config) + self.generate_func = generate_func + self.vae_generate_func = vae_gen + + def extract(self, data_feed, replay_buffer): + """ + extract experiences from corpus + """ + data_feed.epoch_init(self.sv_config, shuffle=False, verbose=False, fix_batch=True) + # self.sys_model.eval() + self.keys = [] + delex_dialogues = self.evaluator.delex_dialogues + belief_state = th.load("/gpfs/project/lubis/LAVA_code/LAVA_dev/data/MultiWOZ_2.1_SetSUMBT_EnD2/setsumbt_belief_states.bin") + + corpus_successes = 0 + corpus_matches = 0 + total = 0 + + all_actions = [] + pred_strs = [] + true_strs = [] + + n = 0 + + while True: + batch = data_feed.next_batch() + + n += 1 + + if n % 500 == 0: + print("Processing batch {}/{}".format(n, data_feed.num_batch)) + + if batch is None: + break + + batch_size = len(batch['keys']) + key = batch['keys'][0] + self.keys.append(key) + assert len(set(batch['keys'])) == 1 + + rewards = np.zeros(batch_size) + + states = [] + actions = [] + next_states = [] + next_actions = [] + dones = [] + + + # actions, pred_str, true_str= self._get_actions(batch) + # actions, _ = self._sample_actions(batch) + # get states, next state, action, and 
done from data
+            for turn_id, turn in enumerate(batch['contexts']):
+                # for each turn, compute info gain wrt the previous turn (if it exists).
+                # belief_state[key][slot_name][turn_id * 2] (the .bin file contains twice the
+                # number of turns, but turns 0 and 1 are identical, so are 2 and 3, 4 and 5, etc.)
+                state = {}
+                state['contexts'] = turn
+                state['bs'] = batch['bs'][turn_id]
+                state['db'] = batch['db'][turn_id]
+                state['context_lens'] = batch['context_lens'][turn_id]
+                state['keys'] = batch['keys'][turn_id]
+                state['goals'] = np.concatenate([batch['goals_list'][d][turn_id] for d in range(7)])
+
+                action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id].tolist(), do_pad=True))
+
+                try:
+                    next_state = {}
+                    next_state['contexts'] = batch['contexts'][turn_id + 1]
+                    next_state['bs'] = batch['bs'][turn_id + 1]
+                    next_state['db'] = batch['db'][turn_id + 1]
+                    next_state['context_lens'] = batch['context_lens'][turn_id + 1]
+                    next_state['keys'] = batch['keys'][turn_id + 1]
+                    next_state['goals'] = np.concatenate([batch['goals_list'][d][turn_id + 1] for d in range(7)])
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id + 1].tolist(), do_pad=True))
+                    # next_state['outputs'] = batch['outputs'][turn_id + 1]
+                    done = 0
+                except IndexError:
+                    # last turn has no successor: pad the next state and mark the transition terminal
+                    next_state = {}
+                    next_state['contexts'] = np.zeros(batch['contexts'][turn_id].shape)
+                    next_state['bs'] = [0] * self.corpus.bs_size
+                    next_state['db'] = [0] * self.corpus.db_size
+                    next_state['context_lens'] = batch['context_lens'][turn_id]
+                    next_state['keys'] = batch['keys'][turn_id]
+                    next_state['goals'] = [0] * self.corpus.goal_size
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, [0], do_pad=True))
+                    # next_state['outputs'] = np.zeros(self.sv_config.max_utt_len)
+                    done = 1
+
+                actions.append(action)
+                states.append(state)
+                next_states.append(next_state)
+                next_actions.append(next_action)
+                dones.append(done)
+
+            # compute the corpus-based reward of the dialogue
+            data = delex_dialogues[key]
+            goal, success, match, requestables, _ = self.evaluator._evaluateRealDialogue(data, key, soft=False)
+            corpus_successes += success
+            corpus_matches += match
+            total += 1
+
+            rewards[-1] = float(success)
+            g = self.agent.cvae.np2var(np.array([rewards[-1]]), FLOAT).view(1, 1)
+
+            # compute the accumulated discounted reward: only the final turn carries
+            # a reward, so returns[t] = gamma^(T-1-t) * success
+            returns = []
+            for _ in rewards:
+                returns.insert(0, float(g))
+                g = g * self.critic_config.gamma
+
+            # add tuples to the replay buffer
+            # assert(len(states) == len(rewards) and len(states) == len(actions) and len(states) == len(dones) and len(states)== len(next_states), len(next_actions)==len(next_states))
+            replay_buffer.add(states, actions, rewards, next_states, next_actions, dones, returns)
+            # for i in range(len(states)):
+            #     replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], next_action[i], dones[i], returns[i])
+
+    def _infoGain(self, P, Q):
+        M = [p + q for p, q in zip(P, Q)]
+        return 0.5 * (entropy(P, M, base=2) + entropy(Q, M, base=2))
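+    # _infoGain is the Jensen-Shannon divergence between two belief
+    # distributions, in bits, assuming entropy here is scipy.stats.entropy
+    # (which normalizes its inputs, so M = P + Q acts as the mixture
+    # (P + Q) / 2). For example, _infoGain([1., 0.], [0., 1.]) == 1.0
+    # and _infoGain(P, P) == 0.0.
+
+    def run(self):
+        # get exps, step, learn, act, validate
+        n = 0
+        best_valid_loss = np.inf
+        vae_epoch_id = 0
+
+        try:
+            for epoch_id in trange(self.critic_config.nepoch, desc="Critic_epoch"):
+                for n in trange(self.critic_config.nepisode, desc="Critic_batch"):
+                    # VAE training
+                    if self.train_vae and n % self.critic_config.train_vae_freq == 0:
+                        if vae_epoch_id == 0:
+                            train_vae_nepisode = self.critic_config.train_vae_nepisode_init
+                            offset_idx = 0
+                        else:
+                            train_vae_nepisode = 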
self.critic_config.train_vae_nepisode + offset_idx = self.critic_config.train_vae_nepisode_init + self.critic_config.train_vae_nepisode * (vae_epoch_id - 1) + + logger.info("=== VAE Training ==") + + logger.info("validating starting performance") + vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True) + # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False) + self.tb.add_scalar("train_vae_success", vae_success, offset_idx + 1) + self.tb.add_scalar("train_vae_match", vae_match, offset_idx + 1) + self.tb.add_scalar("train_vae_bleu", vae_bleu, offset_idx + 1) + self.tb.add_scalar("train_vae_f1", vae_f1, offset_idx + 1) + + self.agent.critic.eval() + self.agent.cvae.train() + for i in trange(train_vae_nepisode, desc="VAE training"): + vae_loss = self.agent.train_vae_model(i) + if i % 2500 == 0: + vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True) + # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False) + self.tb.add_scalar("train_vae_success", vae_success, i + offset_idx) + self.tb.add_scalar("train_vae_match", vae_match, i + offset_idx) + self.tb.add_scalar("train_vae_bleu", vae_bleu, i + offset_idx) + self.tb.add_scalar("train_vae_f1", vae_f1, i + offset_idx) + + self.tb.add_scalar("vae_loss", vae_loss, (i+1) + (self.critic_config.nepisode/self.critic_config.train_vae_freq) * (epoch_id)) + + vae_epoch_id += 1 + + # critic_training + self.agent.critic.train() + self.agent.cvae.eval() + # tic = time.perf_counter() + plas_report, losses = self.agent.train_critic(verbose=n%10==0, + max_words=self.critic_config.max_words, + temp=self.critic_config.temperature, + debug=n%500==0, + n=(n+1) + (self.critic_config.nepisode) * (epoch_id)) + # toc = time.perf_counter() + # print(f"One batch forward pass in {toc - tic:0.4f} seconds") + + + + if n % 20 == 0: + for k, v in losses.items(): + self.tb.add_scalar(k, v, (self.critic_config.nepisode) * (epoch_id) + (n+1)) + + + if n % self.critic_config.record_freq == 0: + self.agent.critic.eval() + + print('-'*15, 'Recording start', '-'*15) + # save train reward + + # PPL & reward on validation + print("Running validation on valid set") + # tic = time.perf_counter() + if not self.agent.raw_response: + v_success, v_match, v_bleu, v_f1, v_Q = self.generate_func(self.agent.cvae, self.val_data, self.sv_config,self.critic_config, self.evaluator, None, verbose=False, critic=self.agent.critic, actor = self.agent.actor) + val_critic_error = v_Q - v_success + # print(v_success, v_match, v_bleu, v_f1, v_Q, val_critic_error) + self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_success, v_match, v_Q)) + self.tb.add_scalar("validation_critic_mse", val_critic_error, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + else: + v_Q = self.generate_func(self.val_data, self.agent, None) + self.rl_val_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_Q)) + # toc = time.perf_counter() + # print(f"One validation pass in {toc - tic:0.4f} seconds") + self.rl_val_file.flush() + + # self.tb.add_scalar("validation_success", v_success, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # 
self.tb.add_scalar("validation_match", v_match, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # self.tb.add_scalar("validation_bleu", v_bleu, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # self.tb.add_scalar("validation_f1", v_f1, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + + print("Running validation on test set") + if not self.agent.raw_response: + t_success, t_match, t_bleu, t_f1, t_Q = self.generate_func(self.agent.cvae, self.test_data, self.sv_config, self.critic_config, self.evaluator, None, verbose=False, critic=self.agent.critic, actor = self.agent.actor) + test_critic_error = t_Q - t_success + # print(t_success, t_match, t_bleu, t_f1, t_Q, test_critic_error) + self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_success, t_match, t_Q)) + self.tb.add_scalar("test_critic_mse", test_critic_error, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + else: + t_Q = self.generate_func(self.test_data, self.agent, None) + self.rl_test_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_Q)) + + self.rl_test_file.flush() + # self.tb.add_scalar("test_success", t_success, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # self.tb.add_scalar("test_match", t_match, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # self.tb.add_scalar("test_bleu", t_bleu, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + # self.tb.add_scalar("test_f1", t_f1, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + + # always save critic as we don't have a good validation metric + print("Critic saved after {} batches".format(epoch_id * self.critic_config.nepisode + n)) + th.save(self.agent.critic.state_dict(), self.critic_config.critic_model_path) + # best_valid_loss = test_critic_error + + + # for name, weight in self.agent.named_parameters(): + # tb.add_histogram(name, weight, n + 1 ) + # tb.add_histogram(f'{name}.grad', weight.grad, n + 1) + + self.agent.critic.train() + print('-'*15, 'Recording end', '-'*15) + + except KeyboardInterrupt: + print("Critic training stopped from keyboard") + self.tb.close() + + print("$$$ Load {}-model".format(self.critic_config.critic_model_path)) + self.sv_config.batch_size = 32 + self.agent.critic.load_state_dict(th.load(self.critic_config.critic_model_path)) + + if not self.agent.raw_response: + + with open(os.path.join(self.critic_config.record_path, 'train_file.txt'), 'w') as f: + self.generate_func(self.agent.cvae, self.train_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor = self.agent.actor) + + with open(os.path.join(self.critic_config.record_path, 'valid_file.txt'), 'w') as f: + self.generate_func(self.agent.cvae, self.val_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor = self.agent.actor) + + with open(os.path.join(self.critic_config.record_path, 'test_file.txt'), 'w') as f: + self.generate_func(self.agent.cvae, self.test_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor=self.agent.actor) + + def run_behavior_policy(self): + # get exps, step, learn, act, validate + n = 0 + best_valid_loss = np.inf + vae_epoch_id = 0 + + try: + for epoch_id in trange(self.critic_config.nepoch, 
desc="Critic_epoch"): + for n in trange(self.critic_config.nepisode, desc="Critic_batch"): + # critic_training + self.agent.critic.train() + self.agent.cvae.eval() + # tic = time.perf_counter() + plas_report, losses = self.agent.train_critic(verbose=n%10==0, + max_words=self.critic_config.max_words, + temp=self.critic_config.temperature, + sl=not self.agent.corpus_response, + debug=n%500==0, + n=(n+1) + (self.critic_config.nepisode) * (epoch_id)) + + if n % 20 == 0: + for k, v in losses.items(): + self.tb.add_scalar(k, v, (self.critic_config.nepisode) * (epoch_id) + (n+1)) + + if n % self.critic_config.record_freq == 0: + self.agent.critic.eval() + + print('-'*15, 'Recording start', '-'*15) + # save train reward + + # PPL & reward on validation + print("Running validation on valid set") + v_Q = self.generate_func(self.val_data, self.agent, None) + self.rl_val_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_Q)) + self.rl_val_file.flush() + self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + + print("Running validation on test set") + t_Q = self.generate_func(self.test_data, self.agent, None) + self.rl_test_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_Q)) + self.rl_test_file.flush() + self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id)) + + # always save critic as we don't have a good validation metric + print("Critic saved after {} batches".format(epoch_id * self.critic_config.nepisode + n)) + th.save(self.agent.critic.state_dict(), self.critic_config.critic_model_path) + + self.agent.critic.train() + print('-'*15, 'Recording end', '-'*15) + + except KeyboardInterrupt: + print("Critic training stopped from keyboard") + self.tb.close() + + print("$$$ Load {}-model".format(self.critic_config.critic_model_path)) + self.sv_config.batch_size = 32 + self.agent.critic.load_state_dict(th.load(self.critic_config.critic_model_path)) + + +def validate_rl(dialog_eval, ctx_gen, num_episode=200): + print("Validate on training goals for {} episode".format(num_episode)) + reward_list = [] + agree_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + for _ in range(num_episode): + ctxs = ctx_gen.sample() + conv, agree, rewards = dialog_eval.run(ctxs) + true_reward = rewards[0] if agree else 0 + reward_list.append(true_reward) + agree_list.append(float(agree if agree is not None else 0.0)) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + results = {'sys_rew': np.average(reward_list), + 'avg_agree': np.average(agree_list), + 'sys_sent_unique': sent_metric.value(), + 'sys_unique': word_metric.value()} + return results + +def train_single_batch(model, train_data, config): + batch_cnt = 0 + optimizer = model.get_optimizer(config, verbose=False) + model.train() + + # decoding CE + train_data.epoch_init(config, shuffle=True, verbose=False) + for i in range(16): + batch = train_data.next_batch() + if batch is None: + train_data.epoch_init(config, shuffle=True, verbose=False) + batch = train_data.next_batch() + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + +def task_train_single_batch(model, train_data, config): + batch_cnt = 0 + optimizer = model.get_optimizer(config, verbose=False) + model.train() + + # decoding CE + train_data.epoch_init(config, 
shuffle=True, verbose=False) + for i in range(16): + batch = train_data.next_batch() + if batch is None: + train_data.epoch_init(config, shuffle=True, verbose=False) + batch = train_data.next_batch() + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + +def train(model, train_data, val_data, test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + num_noised_tokens = {} + + logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('***** Epoch 0/{} *****'.format(config.max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + + if batch is None: + break + + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + train_loss.add_loss(loss) + + if batch_cnt % config.print_step == 0: + # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(train_loss.pprint('Train', + window=config.print_step, + prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w))) + sys.stdout.flush() + + if batch_cnt % config.ckpt_step == 0: + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('==== Evaluating Model ====') + logger.info(train_loss.pprint('Train')) + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + # generation + if gen is not None: + gen(model, train_data, config, evaluator, num_batch=config.preview_batch_num) + + # validation + valid_loss = validate(model, val_data, config, batch_cnt) + _ = validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if train_data.noise_type is not None: + num_noised_tokens[done_epoch-1] = (train_data.num_noised_tokens, + train_data.num_noised_tokens / train_data.num_examples, + train_data.num_examples, + train_data.noised_tokens_dist) + train_data.num_noised_tokens = 0 + train_data.noised_tokens_dist = defaultdict(int) + train_data.num_examples = 0 + + + if done_epoch >= config.max_epoch \ + or config.early_stop and 
patience <= done_epoch: + + file_path = os.path.join(config.saved_path, "tokens_noised_count.json") + with open(file_path, "w+") as json_file: + json.dump(num_noised_tokens, json_file, indent=4) + + if done_epoch < config.max_epoch: + logger.info('!!!!! Early stop due to run out of patience !!!!!') + logger.info('Best validation loss = %f' % (best_valid_loss, )) + return best_epoch + + # exit eval model + model.train() + train_loss.clear() + logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch)) + sys.stdout.flush() + + batch_cnt += 1 + +def mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + + logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('***** Epoch 0/{} *****'.format(config.max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + if batch is None: + break + + optimizer.zero_grad() + loss = model(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + batch_cnt += 1 + train_loss.add_loss(loss) + + if batch_cnt % config.print_step == 0: + # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(train_loss.pprint('Train', + window=config.print_step, + prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w))) + sys.stdout.flush() + + + + if batch_cnt % config.ckpt_step == 0: + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('==== Evaluating Model ====') + logger.info(train_loss.pprint('Train')) + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + # generation + if gen is not None: + gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num) + gen(model, aux_val_data, config, evaluator, num_batch=config.preview_batch_num, aux_mt=True) + + # validation + valid_loss = validate(model, val_data, config, batch_cnt) + _ = validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if not model.shared_train: + if done_epoch % config.aux_train_freq == 0: + model.train() + 
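+                        # shared_train == False: the auxiliary head is not trained
+                        # jointly, so interleave a separate auxiliary pass
+                        # (train_aux on the auxiliary data) every aux_train_freq
+                        # epochs before resuming the main task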
train_aux(model, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=gen) + + + if done_epoch >= config.max_epoch \ + or config.early_stop and patience <= done_epoch: + if done_epoch < config.max_epoch: + logger.info('!!!!! Early stop due to run out of patience !!!!!') + print('Best validation loss = %f' % (best_valid_loss, )) + return best_epoch + + + + # exit eval model + model.train() + train_loss.clear() + + logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch)) + sys.stdout.flush() + +def train_aux(model, train_data, val_data, test_data, config, evaluator, gen=None): + patience = 10 + valid_loss_threshold = np.inf + best_valid_loss = np.inf + batch_cnt = 0 + optimizer = model.get_optimizer(config) + done_epoch = 0 + best_epoch = 0 + train_loss = LossManager() + model.train() + # logger.info(summary(model, show_weights=False)) + saved_models = [] + last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5 + + logger.info('+++++ Aux Training Begins at {} +++++'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('+++++ Epoch 0/{} +++++'.format(config.aux_max_epoch)) + while True: + train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch) + while True: + batch = train_data.next_batch() + if batch is None: + break + + optimizer.zero_grad() + loss = model.forward_aux(batch, mode=TEACH_FORCE) + model.backward(loss, batch_cnt) + nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + optimizer.step() + batch_cnt += 1 + train_loss.add_loss(loss) + + if batch_cnt % config.ckpt_step == 0: + done_epoch += 1 + logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch)) + + logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info('++++ Evaluating Model ++++') + logger.info(train_loss.pprint('Aux train')) + + # generation + if gen is not None: + gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num, aux_mt=True) + + # validation + valid_loss = aux_validate(model, val_data, config, batch_cnt) + _ = aux_validate(model, test_data, config, batch_cnt) + + # update early stopping stats + if valid_loss < best_valid_loss: + if valid_loss <= valid_loss_threshold * config.improve_threshold: + patience = max(patience, done_epoch*config.patient_increase) + valid_loss_threshold = valid_loss + logger.info('Update patience to {}'.format(patience)) + + if config.save_model: + cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S") + logger.info('!!New best model with loss = {},at {}.'.format(valid_loss, cur_time)) + th.save(model.state_dict(), os.path.join(config.saved_path, 'aux-{}-model'.format(done_epoch))) + best_epoch = done_epoch + saved_models.append(done_epoch) + if len(saved_models) > last_n_model: + remove_model = saved_models[0] + saved_models = saved_models[-last_n_model:] + os.remove(os.path.join(config.saved_path, "aux-{}-model".format(remove_model))) + + best_valid_loss = valid_loss + + if done_epoch >= config.aux_max_epoch \ + or config.early_stop and patience <= done_epoch: + if done_epoch < config.aux_max_epoch: + logger.info('!!!!! 
Early stop due to running out of patience !!!!!')
+            print('Best validation loss = %f' % (best_valid_loss, ))
+            return best_epoch
+
+        # exit eval mode
+        model.train()
+        train_loss.clear()
+        # logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.aux_max_epoch))
+        sys.stdout.flush()
+
+def validate(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        if use_py is not None:
+            # loss = model(batch, mode=TEACH_FORCE, use_py=use_py)
+            loss = model(batch, mode=TEACH_FORCE)
+        else:
+            loss = model(batch, mode=TEACH_FORCE)
+
+        losses.add_loss(loss)
+        losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt))
+
+    valid_loss = losses.avg_loss()
+    # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info(losses.pprint(val_data.name))
+    logger.info('Total valid loss = {}'.format(valid_loss))
+    sys.stdout.flush()
+    return valid_loss
+
+def validate_offlinerl(agent, evaluator, val_data, config, mode="valid", batch_cnt=None, use_py=None):
+    """
+    collect the per-batch corpus success and match of the dataset
+    """
+    agent.model.eval()
+    agent.critic.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True)
+    criterion = th.nn.MSELoss()
+    valid_rewards = []
+    valid_match = []
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+
+        with th.no_grad():
+            actions, task_report, success, match = agent.run(batch, evaluator)
+
+        valid_rewards.append(success)
+        valid_match.append(match)
+
+    return valid_rewards, valid_match
+
+def debug(model, val_data, config, n_z=1, batch_cnt=None, use_py=None):
+    model.train()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+
+    de_tknize = get_detokenize()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+
+        batch_size = batch['bs'].shape[0]
+        all_preds = defaultdict(list)
+        true_str = []
+        zs, joint_logpz = model.sample_z(batch, n_z=n_z)
+        for i in range(n_z):
+            # true_labels = batch['outputs'].data.numpy() # (batch_size, output_seq_len)
+
+            logprobs, outputs = model.decode_z(zs[i], config.batch_size, max_words=config.max_dec_len)
+            # move from GPU to CPU
+
+            for b_id in range(len(outputs)):
+                pred_str = get_sent_list(model.vocab, de_tknize, outputs[b_id])
+                all_preds[b_id].append(pred_str)
+                if i == 0:
+                    true_str.append(get_sent(model.vocab, de_tknize, batch['outputs'], b_id))
+
+        for i in range(batch_size):
+            print('True: {}'.format(true_str[i], ))
+            for n in range(n_z):
+                print('Pred{}: {}'.format(n, all_preds[i][n]))
+            print('='*30)
+
+def validate_mt(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        loss = model(batch, mode=TEACH_FORCE)
+        losses.add_loss(loss)
+        losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt))
+
+    valid_loss = losses.avg_loss()
+    # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info(losses.pprint(val_data.name))
+    logger.info('Total valid loss = {}'.format(valid_loss))
+    sys.stdout.flush()
+    return valid_loss
+
+def aux_validate(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
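+    # mirrors validate() but scores the auxiliary head: accumulate
+    # forward_aux TEACH_FORCE losses over the whole validation split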
while True: + batch = val_data.next_batch() + if batch is None: + break + loss = model.forward_aux(batch, mode=TEACH_FORCE) + + losses.add_loss(loss) + losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt)) + + valid_loss = losses.avg_loss() + # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S"))) + logger.info(losses.pprint(val_data.name)) + logger.info('Total aux valid loss = {}'.format(valid_loss)) + sys.stdout.flush() + return valid_loss + +def generate(model, data, config, evaluator, num_batch, dest_f=None): + + def write(msg): + if msg is None or msg == '': + return + if dest_f is None: + print(msg) + else: + dest_f.write(msg + '\n') + + model.eval() + de_tknize = get_detokenize() + data.epoch_init(config, shuffle=num_batch is not None, verbose=False) + evaluator.initialize() + logger.info('Generation: {} batches'.format(data.num_batch + if num_batch is None + else num_batch)) + batch_cnt = 0 + print_cnt = 0 + while True: + batch_cnt += 1 + batch = data.next_batch() + if batch is None or (num_batch is not None and data.ptr > num_batch): + break + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.data.numpy() # (batch_size, output_seq_len) + + # get attention if possible + if config.dec_use_attn: + pred_attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]] + pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(0, 1) # (batch_size, max_dec_len, max_ctx_len) + else: + pred_attns = None + # get context + ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len) + ctx_len = batch.get('context_lens') # (batch_size, ) + + for b_id in range(pred_labels.shape[0]): + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + prev_ctx = '' + if ctx is not None: + ctx_str = [] + for t_id in range(ctx_len[b_id]): + temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False) + # print('temp_str = %s' % (temp_str, )) + # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], )) + ctx_str.append(temp_str) + ctx_str = '|'.join(ctx_str)[-200::] + prev_ctx = 'Source context: {}'.format(ctx_str) + + evaluator.add_example(true_str, pred_str) + + if num_batch is None or batch_cnt < 2: + print_cnt += 1 + write('prev_ctx = %s' % (prev_ctx, )) + write('True: {}'.format(true_str, )) + write('Pred: {}'.format(pred_str, )) + write('='*30) + if num_batch is not None and print_cnt > 10: + break + + write(evaluator.get_report()) + # write(evaluator.get_groundtruth_report()) + write('Generation Done') + +def get_sent(vocab, de_tknize, data, b_id, stop_eos=True, stop_pad=True): + ws = [] + for t_id in range(data.shape[1]): + w = vocab[int(data[b_id, t_id])] + if (stop_eos and w == EOS) or (stop_pad and w == PAD): + break + if w != PAD: + ws.append(w) + + return de_tknize(ws) + +def get_sent_list(vocab, de_tknize, data, stop_eos=True, stop_pad=True): + ws = [] + for t_id in range(len(data)): + w = vocab[data[t_id]] + if (stop_eos and w == EOS) or (stop_pad and w == PAD): + break + if w != PAD: + ws.append(w) + + return de_tknize(ws) + +def most_frequent(List): + occ_count = Counter(List) + return occ_count.most_common(1)[0][0] + +def generate_with_name(model, data, 
config): + model.eval() + de_tknize = get_detokenize() + data.epoch_init(config, shuffle=False, verbose=False) + logger.info('Generation With Name: {} batches.'.format(data.num_batch)) + + from collections import defaultdict + res = defaultdict(dict) + while True: + batch = data.next_batch() + if batch is None: + break + keys, outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + true_labels = labels.cpu().data.numpy() # (batch_size, output_seq_len) + + for b_id in range(pred_labels.shape[0]): + pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) + true_str = get_sent(model.vocab, de_tknize, true_labels, b_id) + dlg_name, dlg_turn = keys[b_id] + res[dlg_name][dlg_turn] = {'pred': pred_str, 'true': true_str} + + return res diff --git a/latent_dialog/metric.py b/latent_dialog/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..433c19e369fd7c8274cd1629a053137278f7079b --- /dev/null +++ b/latent_dialog/metric.py @@ -0,0 +1,151 @@ +import time +from collections import OrderedDict + + +class NumericMetric(object): + """Base class for a numeric metric.""" + def __init__(self): + self.k = 0 + self.n = 0 + + def reset(self): + pass + + def record(self, k, n=1): + self.k += k + self.n += n + + def value(self): + self.n = max(1, self.n) + return 1.0 * self.k / self.n + + +class AverageMetric(NumericMetric): + """Average.""" + def show(self): + return '%.2f' % (1. * self.value()) + + +class PercentageMetric(NumericMetric): + """Percentage.""" + def show(self): + return '%2.2f%%' % (100. * self.value()) + + +class TimeMetric(object): + """Time based metric.""" + def __init__(self): + self.t = 0 + self.n = 0 + + def reset(self): + self.last_t = time.time() + + def record(self, n=1): + self.t += time.time() - self.last_t + self.n += 1 + + def value(self): + self.n = max(1, self.n) + return 1.0 * self.t / self.n + + def show(self): + return '%.3fs' % (1. * self.value()) + + +class UniquenessMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + + def reset(self): + pass + + def record(self, sen): + self.seen.add(' '.join(sen)) + + def value(self): + return len(self.seen) + + def show(self): + return str(self.value()) + + +class TextMetric(object): + """Text based metric.""" + def __init__(self, text): + self.text = text + self.k = 0 + self.n = 0 + + def reset(self): + pass + + def value(self): + self.n = max(1, self.n) + return 1. * self.k / self.n + + def show(self): + return '%.2f' % (1. 
* self.value())
+
+
+class NGramMetric(TextMetric):
+    """Metric that evaluates n-grams."""
+    def __init__(self, text, ngram=-1):
+        super(NGramMetric, self).__init__(text)
+        self.ngram = ngram
+
+    def record(self, sen):
+        n = len(sen) if self.ngram == -1 else self.ngram
+        for i in range(len(sen) - n + 1):
+            self.n += 1
+            target = ' '.join(sen[i:i + n])
+            if self.text.find(target) != -1:
+                self.k += 1
+
+
+class MetricsContainer(object):
+    """A container that stores and updates several metrics."""
+    def __init__(self):
+        self.metrics = OrderedDict()
+
+    def _register(self, name, ty, *args, **kwargs):
+        name = name.lower()
+        assert name not in self.metrics
+        self.metrics[name] = ty(*args, **kwargs)
+
+    def register_average(self, name, *args, **kwargs):
+        self._register(name, AverageMetric, *args, **kwargs)
+
+    def register_time(self, name, *args, **kwargs):
+        self._register(name, TimeMetric, *args, **kwargs)
+
+    def register_percentage(self, name, *args, **kwargs):
+        self._register(name, PercentageMetric, *args, **kwargs)
+
+    def register_ngram(self, name, *args, **kwargs):
+        self._register(name, NGramMetric, *args, **kwargs)
+
+    def register_uniqueness(self, name, *args, **kwargs):
+        self._register(name, UniquenessMetric, *args, **kwargs)
+
+    def record(self, name, *args, **kwargs):
+        name = name.lower()
+        assert name in self.metrics
+        self.metrics[name].record(*args, **kwargs)
+
+    def reset(self):
+        for m in self.metrics.values():
+            m.reset()
+
+    def value(self, name):
+        return self.metrics[name].value()
+
+    def show(self):
+        return ' '.join(['%s=%s' % (k, v.show()) for k, v in self.metrics.items()])
+
+    def dict(self):
+        d = OrderedDict()
+        for k, v in self.metrics.items():
+            d[k] = v.show()
+        return d
diff --git a/latent_dialog/models_task.py b/latent_dialog/models_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08c7e1bdd8f597a7c7409d24351e16f0dfaf983
--- /dev/null
+++ b/latent_dialog/models_task.py
@@ -0,0 +1,3382 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from latent_dialog.base_models import BaseModel, frange_cycle_linear
+from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS
+from latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type
+from latent_dialog.enc2dec.encoders import RnnUttEncoder
+from latent_dialog.enc2dec.decoders import DecoderRNN, GEN, TEACH_FORCE
+from latent_dialog.criterions import NLLEntropy, CatKLLoss, Entropy, NormKLLoss, GaussianEntropy
+from latent_dialog import nn_lib
+import numpy as np
+import pdb
+import json
+
+
+class SysPerfectBD2Word(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysPerfectBD2Word, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
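+        # the word-level baseline has no latent variable: the policy is a
+        # deterministic map from [belief state; DB pointer; encoded context]
+        # to the decoder's initial hidden state, used in forward() as
+        #   dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0)
+        self.policy = 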
nn.Sequential(nn.Linear(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.dec_cell_size), nn.Tanh(), nn.Dropout(config.dropout)) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=self.utt_encoder.output_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # pack attention context + if self.config.dec_use_attn: + attn_context = enc_outs + else: + attn_context = None + + # create decoder initial states + dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) + + # decode + if self.config.dec_rnn_cell == 'lstm': + # h_dec_init_state = utt_summary.squeeze(1).unsqueeze(0) + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + return ret_dict, labels + if return_latent: + return Pack(nll=self.nll(dec_outputs, labels), + latent_action=dec_init_state) + else: + return Pack(nll=self.nll(dec_outputs, labels)) + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # pack attention context + if self.config.dec_use_attn: + attn_context = enc_outs + else: + attn_context = None + + # create decoder initial states + dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + 
temp=temp) + return logprobs, outs + +class SysPerfectBD2Cat(BaseModel): + def __init__(self, corpus, config): + super(SysPerfectBD2Cat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if "policy_dropout" in config and config.policy_dropout: + if "policy_dropout_rate" in config: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + + else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + if not self.simple_posterior: + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + if "state_for_decoding" not in self.config: + self.state_for_decoding = False + else: + self.state_for_decoding = self.config.state_for_decoding + + if self.state_for_decoding: + dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size + else: + dec_hidden_size = config.dec_cell_size + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=dec_hidden_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. 
for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + if "kl_annealing" in self.config and config.kl_annealing=="cyclical": + self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) + else: + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if isinstance(self.beta, float): + beta = self.beta + else: + if batch_cnt == None: + beta = self.beta[-1] + else: + beta = self.beta[int(batch_cnt)] + + + if self.simple_posterior or "kl_annealing" in self.config: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) # p(z|c) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=True) + else: + sample_y = self.gumbel_connector(logits_qy, hard=False) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, 
enc_last.unsqueeze(0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + + # if np.random.rand() < epsilon: # greedy exploration + # print("randomly sampling latent") + # idx = th.multinomial(th.cuda.FloatTensor(qy.shape).uniform_(), 1) + # else: # normal latent sampling + idx = th.multinomial(qy, 1).detach() + + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + if 
self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+        return logprobs, outs, joint_logpz, sample_y
+
+    def sample_z(self, data_feed, n_z=1, temp=0.1):
+        ctx_lens = data_feed['context_lens'] # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, bs_size)
+        db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, db_size)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # the same prior head is used whether or not the posterior is simple
+        logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1) # (batch_size * y_size, k_size)
+        log_qy = F.log_softmax(logits_py, dim=1) # (batch_size * y_size, k_size)
+
+        zs = []
+        logpzs = []
+        for i in range(n_z):
+            idx = th.multinomial(qy, 1).detach()
+            logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+            joint_logpz = th.sum(logprob_sample_z, dim=1)
+            sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+            sample_y.scatter_(1, idx, 1.0)
+
+            zs.append(sample_y)
+            logpzs.append(joint_logpz)
+
+        return th.stack(zs), th.stack(logpzs)
+
+    def categorical_logprob(self, logits_py, sample_z, temp=0.1):
+        # log-probability of a given one-hot latent sample under the
+        # (temperature-scaled) categorical distribution
+        log_py = F.log_softmax(logits_py / temp, dim=1) # (batch_size * y_size, k_size)
+        idx = th.argmax(sample_z.view(-1, self.k_size), dim=1, keepdim=True)
+        logprob_sample_z = log_py.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+        return joint_logpz
+
+    def encode_state(self, data_feed):
+        ctx_lens = data_feed['context_lens'] # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, bs_size)
+        db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, db_size)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        return enc_last
+
+    def get_z_via_rg(self, data_feed, hard=False):
+        enc_last = self.encode_state(data_feed)
+        logits_qy, log_qy = self.c2z(enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+
+        return aux_sample_z, log_qy, logits_qy
+
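+    # usage sketch (cf. debug() in this patch): draw several latent actions and
+    # realize one of them, e.g.
+    #   zs, logpz = model.sample_z(batch, n_z=5)
+    #   logprobs, outs = model.decode_z(zs[0], config.batch_size, max_words=config.max_dec_len)
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+
+        if data_feed:
+            ctx_lens = data_feed['context_lens'] # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, bs_size)
+            db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, db_size)
+
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+            z_embeddings = 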
th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.state_for_decoding: + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + + #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0) + #attn_context = self.np2var(attn_context, FLOAT) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # has to be forward_rl because we don't have the golden target + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=temp) + return logprobs, outs + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + # print("cutting off, ", tokens) + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + +class SysEncodedBD2Cat(BaseModel): + def __init__(self, corpus, config): + super(SysEncodedBD2Cat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.k_size = config.k_size + self.y_size = config.y_size + self.config = config + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if config.use_metadata_for_decoding: + self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=int(config.embed_size / 2), + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=int(config.dec_cell_size / 2), + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if "policy_dropout" in config and config.policy_dropout: + if "policy_dropout_rate" in config: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + + else: + self.c2z = 
nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + if not self.simple_posterior: + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + if "state_for_decoding" not in self.config: + self.state_for_decoding = False + else: + self.state_for_decoding = self.config.state_for_decoding + + dec_hidden_size = config.dec_cell_size + if config.use_metadata_for_decoding: + if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat": + dec_hidden_size += self.metadata_encoder.output_size + if self.state_for_decoding: + dec_hidden_size += self.utt_encoder.output_size + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=dec_hidden_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. 
for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + + if "kl_annealing" in self.config and config.kl_annealing=="cyclical": + self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) + else: + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + + + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if isinstance(self.beta, float): + beta = self.beta + else: + if batch_cnt == None: + beta = self.beta[-1] + else: + beta = self.beta[int(batch_cnt % self.config.n_iter)] + + if self.simple_posterior or "kl_annealing" in self.config: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def extract_short_ctx(self, data_feed): + utts = [] + ctx_lens = data_feed['context_lens'] # (batch_size, ) + context = data_feed['contexts'] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(context)): + utt = [] + for t_id in range(ctx_lens[b_id]): + utt.extend(context[b_id][t_id]) + try: + utt.extend(bs[b_id] + db[b_id]) + except: + pdb.set_trace() + utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True)) + return np.array(utts) + + def extract_metadata(self, data_feed): + utts = [] + bs = data_feed['bs'] + db = data_feed['db'] + if not isinstance(bs, list): + bs = data_feed['bs'].tolist() + db = data_feed['db'].tolist() + + for b_id in range(len(bs)): + utt = [] + if "metadata_db_only" in self.config and self.config.metadata_db_only: + utt.extend(db[b_id]) + else: + utt.extend(bs[b_id] + db[b_id]) + utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True)) + return np.array(utts) + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + # print("cutting off, ", tokens) + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use 
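# ---------------------------------------------------------------------------
# [illustrative sketch] A minimal frange_cycle_linear, assuming the signature
# used in the kl_annealing branch above: it returns n_iter beta values that
# ramp linearly from `start` to `stop` n_cycle times; valid_loss() then
# indexes the schedule with batch_cnt % n_iter. This mirrors the cyclical
# annealing schedule of Fu et al. (2019); the repo's helper may differ.
import numpy as np

def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=10, ratio=0.5):
    betas = np.ones(n_iter) * stop
    period = n_iter / n_cycle
    step = (stop - start) / (period * ratio)  # ramp occupies `ratio` of a cycle
    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and int(i + c * period) < n_iter:
            betas[int(i + c * period)] = v
            v += step
            i += 1
    return betas

beta_schedule = frange_cycle_linear(1000, start=0.0, stop=1.0, n_cycle=10)
# ---------------------------------------------------------------------------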
posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last.squeeze(), x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=True) + else: + sample_y = self.gumbel_connector(logits_qy, hard=False) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + if "metadata_to_decoder" in self.config: + if self.config.metadata_to_decoder == "add": + dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) + elif self.config.metadata_to_decoder == "avg": + dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) + else: + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + else: + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) 
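# ---------------------------------------------------------------------------
# [illustrative sketch] The three metadata_to_decoder modes used above. "add"
# and "avg" require the metadata summary to match dec_init_state's width;
# "concat" (the default) widens the decoder hidden state instead, which is why
# dec_hidden_size is enlarged in __init__ for that case. Toy tensors assumed.
import torch as th

batch, dec_cell = 2, 8
dec_init_state = th.randn(1, batch, dec_cell)
metadata_summary = th.randn(1, batch, dec_cell)

mode = "concat"  # or "add" / "avg"
if mode == "add":
    out = dec_init_state + metadata_summary
elif mode == "avg":
    out = th.mean(th.stack((dec_init_state, metadata_summary)), dim=0)
else:
    out = th.cat((dec_init_state, metadata_summary), dim=2)  # (1, batch, 2 * dec_cell)
# ---------------------------------------------------------------------------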
# contains bs and db + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + enc_last = utt_summary.unsqueeze(1) + # create decoder initial states + logits_py, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + if self.config.use_metadata_for_decoding: + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1)) + if "metadata_to_decoder" in self.config: + if self.config.metadata_to_decoder == "add": + dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1) + elif self.config.metadata_to_decoder == "avg": + dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0) + else: + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + else: + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + + def sample_z(self, data_feed, n_z=1, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db + # metadata = self.np2var(self.extract_metadata(data_feed), LONG) + # out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) + + + # create decoder initial states + enc_last = utt_summary.unsqueeze(1) + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + + zs = [] + logpzs = [] + for i in range(n_z): + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, 
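# ---------------------------------------------------------------------------
# [illustrative sketch] Why forward_rl samples from softmax(logits / temp)
# with temp=0.1 but scores with log_softmax(logits): a low temperature
# sharpens the sampling distribution toward the argmax, while the REINFORCE
# log-probability is still measured under the un-tempered policy.
import torch as th
import torch.nn.functional as F

logits = th.tensor([[2.0, 1.0, 0.5]])
print(F.softmax(logits, dim=1))        # ~[0.63, 0.23, 0.14]
print(F.softmax(logits / 0.1, dim=1))  # ~[1.00, 0.00, 0.00]
# ---------------------------------------------------------------------------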
self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + zs.append(sample_y) + logpzs.append(joint_logpz) + + + return th.stack(zs), th.stack(logpzs) + + def decode_z(self, sample_y, batch_size, max_words=None, temp=1.0, gen_type='greedy'): + """ + generate response from latent var + """ + # pack attention context + metadata = self.np2var(self.extract_metadata(data_feed), LONG) + metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1)) + + if isinstance(sample_y, np.ndarray): + sample_y = self.np2var(sample_y, FLOAT) + + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2) + + if self.config.use_metadata_for_decoding: + raise NotImplementedError + + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # has to be forward_rl because we don't have the golden target + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=temp) + return logprobs, outs + +class SysAECat(BaseModel): + def __init__(self, corpus, config): + super(SysAECat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = False # does not use context cause AE task + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + if "ae_zero_padding" in self.config and self.config.ae_zero_padding: + # self.use_metadata = self.config.use_metadata + self.ae_zero_padding = self.config.ae_zero_padding + c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size + else: + # self.use_metadata = False + self.ae_zero_padding = False + c2z_input_size = self.utt_encoder.output_size + + + if "policy_dropout" in config and config.policy_dropout: + self.c2z = nn_lib.Hidden2DiscretewDropout(c2z_input_size, + config.y_size, config.k_size, is_lstm=False, 
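# ---------------------------------------------------------------------------
# [illustrative sketch] What ae_zero_padding buys in SysAECat: the belief-state
# and DB slots of the c2z input are present but zeroed, so a c2z trained on the
# autoencoding task keeps the same input width as the full policy and can later
# be fed real bs/db vectors. Toy sizes assumed.
import torch as th

batch, utt_dim, bs_size, db_size = 2, 8, 5, 3
utt_summary = th.randn(batch, utt_dim)
bs_label = th.randn(batch, bs_size)
db_label = th.randn(batch, db_size)

ae_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary], dim=1)
full_enc_last = th.cat([bs_label, db_label, utt_summary], dim=1)
assert ae_enc_last.shape == full_enc_last.shape  # same c2z input width
# ---------------------------------------------------------------------------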
p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + else: + self.c2z = nn_lib.Hidden2Discrete(c2z_input_size, + config.y_size, config.k_size, is_lstm=False) + + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + # if not self.simple_posterior: #q(z|x,c) + # if self.contextual_posterior: + # # x, c, BS, and DB + # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + # config.y_size, config.k_size, is_lstm=False) + # else: + # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + if self.ae_zero_padding: + enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) + else: + enc_last = utt_summary.squeeze(1) + + + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = 
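# ---------------------------------------------------------------------------
# [illustrative sketch] The pi_kl term added in valid_loss() above:
# KL(q(z|x) || uniform) over y_size categorical variables, with
# log_uniform_y being the scalar log(1/k) broadcast over classes. CatKLLoss's
# exact averaging may differ; one reasonable normalization is shown.
import torch as th
import torch.nn.functional as F

batch, y_size, k_size = 2, 3, 4
log_qy = F.log_softmax(th.randn(batch * y_size, k_size), dim=1)
log_uniform = th.log(th.ones(1) / k_size)        # log p(z) = log(1/k)

qy = th.exp(log_qy)
kl = th.sum(qy * (log_qy - log_uniform), dim=1)  # KL per latent variable
pi_kl = kl.view(batch, y_size).sum(dim=1).mean() # sum over variables, mean over batch
# valid_loss() then takes total = nll + beta * pi_kl when use_pr > 0
# ---------------------------------------------------------------------------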
self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + # else: + # logits_py, log_py = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or (use_py is not None and use_py is True): + # sample_y = self.gumbel_connector(logits_py, hard=False) + # else: + # sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def get_z_via_vae(self, data_feed, hard=False): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.utt_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) + + logits_qy, log_qy = self.c2z(aux_enc_last) + aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) + + return aux_sample_z, logits_qy, log_qy + +class SysMTCat(BaseModel): + def __init__(self, corpus, config): + super(SysMTCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size 
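# ---------------------------------------------------------------------------
# [illustrative sketch] What gumbel_connector(logits, hard=...) does, in the
# standard straight-through Gumbel-softmax formulation; the repo's
# GumbelConnector may differ in details. hard=False yields a soft relaxed
# sample (training); hard=True snaps to a one-hot while keeping the soft
# gradient, which is why generation uses hard=mode==GEN above.
import torch as th
import torch.nn.functional as F

def gumbel_softmax(logits, tau=1.0, hard=False):
    gumbel = -th.log(-th.log(th.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel) / tau, dim=-1)
    if not hard:
        return y_soft
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = th.zeros_like(logits).scatter_(-1, index, 1.0)
    return y_hard - y_soft.detach() + y_soft  # straight-through estimator

sample = gumbel_softmax(th.randn(2, 4), tau=1.0, hard=True)
# ---------------------------------------------------------------------------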
= corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + if "shared_train" in config: + self.shared_train = config.shared_train + else: + self.shared_train = False + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + self.aux_pi_beta = self.config.aux_pi_beta if hasattr(self.config, 'aux_pi_beta') else 1.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.shared_train: + if "selective_fine_tune" in self.config and self.config.selective_fine_tune: + total_loss = loss.nll + self.config.beta * loss.aux_pi_kl + else: + total_loss = loss.nll + loss.ae_nll + self.config.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl + else: + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += 
self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + + return total_loss + + def encode_state(self, data_feed): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + return enc_last + + def encode_action(self, data_feed): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = aux_utt_summary.squeeze(1) + + return aux_enc_last + + def get_z_via_vae(self, data_feed, hard=False): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) + + logits_qy, log_qy = self.c2z(aux_enc_last) + aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) + + return aux_sample_z + + def get_z_via_rg(self, data_feed, hard=False): + enc_last = self.encode_state(data_feed) + logits_qy, log_qy = self.c2z(enc_last) + aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) + + return aux_sample_z, log_qy, logits_qy + + + def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): + """ + generate response from latent var + """ + + # if data_feed: + # ctx_lens = data_feed['context_lens'] # (batch_size, ) + # short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + # bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + + # pack attention context + if isinstance(sample_y, np.ndarray): + sample_y = self.np2var(sample_y, FLOAT) + + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + # if self.state_for_decoding: + # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # # create decoder initial states + # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + + #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0) + #attn_context = self.np2var(attn_context, FLOAT) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # has to be forward_rl because we don't have the golden target + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + 
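# ---------------------------------------------------------------------------
# [illustrative sketch] Why dec_init_state is duplicated for LSTM cells in
# decode_z() above: a GRU takes a single hidden tensor h0, while an LSTM
# expects the pair (h0, c0), so the latent-derived state is reused for both.
# Toy dimensions assumed.
import torch as th
import torch.nn as nn

batch, embed, hidden = 2, 6, 8
x = th.randn(batch, 5, embed)
h0 = th.randn(1, batch, hidden)  # state derived from the latent z

gru = nn.GRU(embed, hidden, batch_first=True)
out, _ = gru(x, h0)              # GRU: single tensor

lstm = nn.LSTM(embed, hidden, batch_first=True)
out, _ = lstm(x, (h0, h0))       # LSTM: tuple (h0, c0), both set to h0
# ---------------------------------------------------------------------------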
dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=temp) + return logprobs, outs + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # how to use z, alone or in combination with bs and db + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=False) + sample_y_discrete = self.gumbel_connector(logits_qy, hard=True) + log_py = self.log_uniform_y + # else: + # logits_py, log_py = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or (use_py is not None and use_py is True): + # sample_y = self.gumbel_connector(logits_py, hard=False) + # else: + # sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = 
self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + if self.shared_train: + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + aux_sample_y = self.gumbel_connector(aux_logits_qy, hard=mode==GEN) + + log_py = self.log_uniform_y + else: + logits_py, log_py = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=False) + else: + sample_y = self.gumbel_connector(logits_qy, hard=True) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + + if self.shared_train: + if self.config.dec_use_attn: + aux_z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + aux_attn_context = [] + aux_temp_sample_y = aux_sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + aux_attn_context.append(th.mm(aux_temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + aux_attn_context = th.cat(aux_attn_context, dim=1) + aux_dec_init_state = th.sum(aux_attn_context, dim=1).unsqueeze(0) + else: + aux_dec_init_state = 
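# ---------------------------------------------------------------------------
# [illustrative sketch] The `diversity` regularizer computed above: with q_y
# reshaped to (batch, y_size, k_size), bmm(q_y, q_y^T) gives pairwise dot
# products between the y_size per-variable distributions; subtracting the
# identity and squaring pushes off-diagonal overlaps toward 0 (variables
# should encode different things) and self-products toward 1 (confident
# marginals). Toy shapes assumed.
import torch as th
import torch.nn.functional as F

batch, y_size, k_size = 2, 3, 4
log_qy = F.log_softmax(th.randn(batch * y_size, k_size), dim=-1)
q_y = th.exp(log_qy).view(-1, y_size, k_size)
eye = th.eye(y_size).unsqueeze(0)
diversity = th.mean(th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - eye, 2))
# ---------------------------------------------------------------------------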
self.z_embedding(aux_sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + aux_attn_context = None + if self.config.dec_rnn_cell == 'lstm': + aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state]) + + + + # decode + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + if self.shared_train: + ae_dec_outputs, ae_dec_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=aux_dec_init_state, # tuple: (h, c) + attn_context=aux_attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + result['ae_nll'] = self.nll(ae_dec_outputs, labels) + aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True) + aux_kl = self.cat_kl_loss(aux_log_qy, log_py, batch_size, unit_average=True) + result['aux_pi_kl'] = aux_pi_kl + result['aux_kl'] = aux_kl + + + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + return result + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + # print("cutting off, ", tokens) + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + +class SysActZCat(BaseModel): + def __init__(self, corpus, config): + super(SysActZCat, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.k_size = config.k_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context + self.contextual_posterior = config.contextual_posterior # does not use context cause AE task + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + if "use_aux_c2z" in config: + self.use_aux_c2z = config.use_aux_c2z + else: + self.use_aux_c2z = False + + + + self.embedding = None + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + 
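# ---------------------------------------------------------------------------
# [illustrative sketch] pad_to() semantics, as a standalone function:
# sequences longer than max_len are truncated but keep their final token
# (normally EOS), shorter ones are padded with 0 (the pad id) when do_pad=True.
def pad_to(max_len, tokens, do_pad=True):
    if len(tokens) >= max_len:
        return tokens[: max_len - 1] + [tokens[-1]]  # keep the trailing EOS
    return tokens + [0] * (max_len - len(tokens)) if do_pad else tokens

assert pad_to(4, [7, 8, 9, 10, 11, 2]) == [7, 8, 9, 2]  # EOS id 2 preserved
assert pad_to(6, [7, 8, 2]) == [7, 8, 2, 0, 0, 0]
# ---------------------------------------------------------------------------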
utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + # if "policy_dropout" in config and config.policy_dropout: + # self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size, + # config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval) + # else: + self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, config.k_size, is_lstm=False) + if self.use_aux_c2z: + self.aux_c2z = nn_lib.Hidden2Discrete(self.aux_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + + self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False) + self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu) + + if not self.simple_posterior: #q(z|x,c) + if self.contextual_posterior: + # x, c, BS, and DB + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, + config.y_size, config.k_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False) + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. 
for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + self.cat_kl_loss = CatKLLoss() + self.entropy_loss = Entropy() + self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size)) + self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0)) + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + if self.use_gpu: + self.log_uniform_y = self.log_uniform_y.cuda() + self.eye = self.eye.cuda() + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + if self.use_aux_c2z: + aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1)) + else: + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = aux_log_qy + else: + logits_py, log_py = self.c2z(enc_last) + aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + logits_qy, log_qy = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or (use_py is not None and use_py is True): + sample_y = self.gumbel_connector(logits_py, hard=True) + else: + sample_y = self.gumbel_connector(logits_qy, hard=False) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = 
None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + # (batch_size, response_size-1) + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + # (batch_size, max_ctx_len, ctx_cell_size) + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + logits_qy, log_qy = self.c2z(enc_last) + sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN) + log_py = self.log_uniform_y + # else: + # p_mu, p_logvar = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or use_py: + # sample_z = self.gauss_connector(p_mu, p_logvar) + # else: + # sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + 
dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_y + ret_dict['log_qy'] = log_qy + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + # regularization qy to be uniform + avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size)) + avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) + b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True) + mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True) + pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) + q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size) # b + p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2) + + result['pi_kl'] = pi_kl + result['diversity'] = th.mean(p) + result['nll'] = self.nll(dec_outputs, labels) + result['b_pr'] = b_pr + result['mi'] = mi + result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True) + return result + + def get_z_via_vae(self, data_feed, hard=False): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) + + logits_qy, log_qy = self.c2z(aux_enc_last) + aux_sample_z = self.gumbel_connector(logits_qy, hard=hard) + + + return aux_sample_z, logits_qy, log_qy + + def encode_state(self, data_feed): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + return enc_last + + def encode_action(self, data_feed): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = aux_utt_summary.squeeze(1) + + return aux_enc_last + + def categorical_logprob(self, logits_py, sample_z, temp=0.1): + py = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_py = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + + idx = th.multinomial(sample_z, 1).detach() + logprob_sample_z = log_py.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + + return joint_logpz + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), 
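# ---------------------------------------------------------------------------
# [illustrative sketch] The trick inside categorical_logprob() above: calling
# th.multinomial on a one-hot sample_z returns the hot index with probability
# 1, so it doubles as a one-hot -> index conversion before scoring the stored
# action under the policy's log-probabilities. Toy shapes assumed.
import torch as th
import torch.nn.functional as F

y_size, k_size = 2, 4
sample_z = th.zeros(y_size, k_size)
sample_z[0, 1] = 1.0
sample_z[1, 3] = 1.0

idx = th.multinomial(sample_z, 1)  # recovers [[1], [3]] deterministically
log_py = F.log_softmax(th.randn(y_size, k_size), dim=1)
joint_logpz = log_py.gather(1, idx).view(-1, y_size).sum(dim=1)  # (batch,)
# ---------------------------------------------------------------------------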
LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + # pack attention context + if self.config.dec_use_attn: + z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0) + attn_context = [] + temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size) + for z_id in range(self.y_size): + attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)) + attn_context = th.cat(attn_context, dim=1) + dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0) + else: + dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_y + + def sample_z(self, data_feed, n_z=1, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + if self.simple_posterior: + logits_py, log_qy = self.c2z(enc_last) + else: + logits_py, log_qy = self.c2z(enc_last) + + qy = F.softmax(logits_py / temp, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_py, dim=1) # (batch_size, vocab_size, ) + + zs = [] + logpzs = [] + for i in range(n_z): + idx = th.multinomial(qy, 1).detach() + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + joint_logpz = th.sum(logprob_sample_z, dim=1) + sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_y.scatter_(1, idx, 1.0) + + zs.append(sample_y) + logpzs.append(joint_logpz) + + + return th.stack(zs), th.stack(logpzs) + + def get_z_via_rg(self, data_feed, hard=False): + enc_last = self.encode_state(data_feed) + logits_qy, log_qy = self.c2z(enc_last) + sample_z = self.gumbel_connector(logits_qy, hard=hard) + + return sample_z, log_qy, logits_qy + + def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): + """ + 
+        generate response from latent var
+        """
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # has to be forward_rl because we don't have the golden target
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        # truncate to max_len (keeping the final token) or right-pad with 0s
+        if len(tokens) >= max_len:
+            return tokens[: max_len - 1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+class SysPerfectBD2Gauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysPerfectBD2Gauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        if "contextual_posterior" in config:
+            self.contextual_posterior = config.contextual_posterior
+        else:
+            self.contextual_posterior = True  # default value is true, i.e. 
q(z|x,c) + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + if "state_for_decoding" not in self.config: + self.state_for_decoding = False + else: + self.state_for_decoding = self.config.state_for_decoding + + if self.state_for_decoding: + dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size + else: + dec_hidden_size = config.dec_cell_size + + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=dec_hidden_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = 
self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def gaussian_logprob(self, mu, logvar, sample_z): + var = th.exp(logvar) + constant = float(-0.5 * np.log(2*np.pi)) + logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) + return logprob + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + p_mu, p_logvar = self.c2z(enc_last) + + sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() + logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) + joint_logpz = th.sum(logprob_sample_z, dim=1) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_z + + def get_z_via_rg(self, data_feed): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + q_mu, q_logvar = self.c2z(enc_last) + + 
sample_z = self.gauss_connector(q_mu, q_logvar) + + return sample_z, q_mu, q_logvar + + def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'): + """ + generate response from latent var + """ + + if data_feed: + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + + # pack attention context + if isinstance(sample_y, np.ndarray): + sample_y = self.np2var(sample_y, FLOAT) + + dec_init_state = self.z_embedding(sample_y.unsqueeze(0)) + if (dec_init_state != dec_init_state).any(): + pdb.set_trace() + attn_context = None + + # decode + if self.state_for_decoding: + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + + return logprobs, outs + + def pad_to(self, max_len, tokens, do_pad): + if len(tokens) >= max_len: + # print("cutting off, ", tokens) + return tokens[: max_len-1] + [tokens[-1]] + elif do_pad: + return tokens + [0] * (max_len - len(tokens)) + else: + return tokens + +class SysAEGauss(BaseModel): + def __init__(self, corpus, config): + super(SysAEGauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + # self.act_size = corpus.act_size + self.y_size = config.y_size + self.simple_posterior = True + self.contextual_posterior = False + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + # if "use_metadata" in self.config and self.config.use_metadata: + if "ae_zero_padding" in self.config and self.config.ae_zero_padding: + # self.use_metadata = self.config.use_metadata + self.ae_zero_padding = self.config.ae_zero_padding + c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size + else: + # self.use_metadata = False + self.ae_zero_padding = False + c2z_input_size = self.utt_encoder.output_size + + self.c2z = nn_lib.Hidden2Gaussian(c2z_input_size, + config.y_size, is_lstm=False) + + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = 
nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + + if "kl_annealing" in self.config and config.kl_annealing=="cyclical": + if "n_iter" not in self.config: + config['n_iter'] = config.ckpt_step * config.max_epoch + self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10) + else: + self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0 + + def valid_loss(self, loss, batch_cnt=None): + if isinstance(self.beta, float): + beta = self.beta + else: + if batch_cnt == None: + beta = self.beta[-1] + else: + beta = self.beta[int(batch_cnt)] + + + if self.simple_posterior or "kl_annealing" in self.config: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + # print(short_ctx_utts[0]) + # print(out_utts[0]) + + + # create decoder initial states + # if self.use_metadata: + if self.ae_zero_padding: + enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) + else: + enc_last = utt_summary.squeeze(1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + # else: + # p_mu, p_logvar = self.c2z(enc_last) + # # encode response and use posterior to find q(z|x, c) + # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + # if self.contextual_posterior: + # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + # else: + # q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + # # use prior at inference time, otherwise use posterior + # if mode == GEN or use_py: + # sample_z = self.gauss_connector(p_mu, p_logvar) + # else: + # sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + 
dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + # print(labels[0]) + # print("========") + # pdb.set_trace() + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + result['nll'] = self.nll(dec_outputs, labels) + return result + + def get_z_via_vae(self, responses): + batch_size = responses.shape[0] + aux_utt_summary, _, aux_enc_outs = self.utt_encoder(responses.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1) + + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar) + + return aux_sample_z, aux_q_mu, aux_q_logvar + + def gaussian_logprob(self, mu, logvar, sample_z): + var = th.exp(logvar) + constant = float(-0.5 * np.log(2*np.pi)) + logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var) + return logprob + +class SysMTGauss(BaseModel): + def __init__(self, corpus, config): + super(SysMTGauss, self).__init__(config) + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.y_size = config.y_size + self.simple_posterior = config.simple_posterior + self.contextual_posterior = config.contextual_posterior + if "shared_train" in config: + self.shared_train = config.shared_train + else: + self.shared_train = False + + if "use_aux_kl" in config: + self.use_aux_kl = config.use_aux_kl + else: + self.use_aux_kl = False + + + self.embedding = None + self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + # if self.shared_train: + # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, + # config.y_size, is_lstm=False) + + self.gauss_connector = 
nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + if "state_for_decoding" not in config: + self.state_for_decoding = False + else: + self.state_for_decoding = config.state_for_decoding + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + self.entropy_loss = GaussianEntropy() + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + self.aux_pi_beta = self.config.aux_pi_beta if hasattr(self.config, 'aux_pi_beta') else 1.0 + + def valid_loss(self, loss, batch_cnt=None): + if self.shared_train: + if "selective_fine_tune" in self.config and self.config.selective_fine_tune: + total_loss = loss.nll + self.config.beta * loss.aux_pi_kl + else: + total_loss = loss.nll + loss.ae_nll + self.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl + else: + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # aux_enc_last = aux_utt_summary.squeeze(1) + aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + # logprob_sample_z = self.gaussian_logprob(q_mu, self.zero, sample_z) + # joint_logpz = th.sum(logprob_sample_z, dim=1) + # pdb.set_trace() + + if self.shared_train: + # aux_q_mu, aux_q_logvar = self.aux_c2z(aux_enc_last) + aux_q_mu, aux_q_logvar = 
self.c2z(aux_enc_last) + aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar) + p_mu, p_logvar = self.zero, self.zero + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + if self.shared_train: + aux_dec_init_state = self.z_embedding(aux_sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + if self.shared_train: + aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + if self.shared_train: + ae_dec_outputs, ae_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=aux_dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + result['ae_nll'] = self.nll(ae_dec_outputs, labels) + aux_pi_kl = self.gauss_kl(q_mu, q_logvar, aux_q_mu, aux_q_logvar) + aux_kl = self.gauss_kl(aux_q_mu, aux_q_logvar, p_mu, p_logvar) + result['aux_pi_kl'] = aux_pi_kl + result['aux_kl'] = aux_kl + # result['aux_entropy'] = self.entropy_loss(aux_q_mu, aux_q_logvar) + + + pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + result['pi_kl'] = pi_kl + # result['pi_entropy'] = self.entropy_loss(q_mu, q_logvar) + result['nll'] = self.nll(dec_outputs, labels) + return result + + def encode_state(self, data_feed): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + return enc_last + + def encode_action(self, data_feed): + batch_size = data_feed.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = aux_utt_summary.squeeze(1) + + return aux_enc_last + + def get_z_via_vae(self, responses): + batch_size = responses.shape[0] + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1)) + + # create decoder initial states + aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), 
self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+        aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+
+        return aux_sample_z, aux_q_mu, aux_q_logvar
+
+    def get_z_via_rg(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        q_mu, q_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        return sample_z, q_mu, q_logvar
+
+    def gaussian_prob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        den = th.sqrt(2 * np.pi * var)
+        po = - (th.pow((sample_z - mu), 2) / (2 * var))
+        prob = 1 / den * th.exp(po)
+        return prob
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2 * np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu - sample_z), 2) / (2.0 * var)
+        return logprob
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        p_mu, p_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(p_mu, p_logvar)
+        logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+        return logprobs, outs, joint_logpz, sample_z
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        dec_init_state = self.z_embedding(sample_y.unsqueeze(0))
+        # guard against NaNs coming from the latent
+        if (dec_init_state != 
dec_init_state).any():
+            pdb.set_trace()
+        attn_context = None
+
+        # decode
+        # note: unlike forward_rl, state_for_decoding is not applied here
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        # truncate to max_len (keeping the final token) or right-pad with 0s
+        if len(tokens) >= max_len:
+            return tokens[: max_len - 1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+    def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False, sample_z=False):
+        # note: the sample_z argument is unused; the name is shadowed below
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states (belief/db slots zeroed out for the aux path)
+        enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            p_mu, p_logvar = self.zero, self.zero
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            return result
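+
+    # Illustrative sketch (assumptions: `model`/`batch` come from the usual
+    # training loop, and TEACH_FORCE is the teacher-forcing mode constant used
+    # elsewhere in this package). The multi-task model trains both paths so
+    # that context and response map into a shared latent space, e.g.
+    #   rg_loss = model(batch, mode=TEACH_FORCE)               # context -> z -> response
+    #   ae_loss = model.forward_aux(batch, mode=TEACH_FORCE)   # response -> z -> response
+
+class SysActZGauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysActZGauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        self.contextual_posterior = config.contextual_posterior
+
+        self.embedding = None
+        self.utt_encoder = 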
RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size, + embedding_dim=config.embed_size, + feat_size=0, + goal_nhid=0, + rnn_cell=config.utt_rnn_cell, + utt_cell_size=config.utt_cell_size, + num_layers=config.num_layers, + input_dropout_p=config.dropout, + output_dropout_p=config.dropout, + bidirectional=config.bi_utt_cell, + variable_lengths=False, + use_attn=config.enc_use_attn, + embedding=self.embedding) + + + self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, + # config.y_size, is_lstm=False) + + self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu) + self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size) + if not self.simple_posterior: + # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + # config.y_size, is_lstm=False) + if self.contextual_posterior: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size, + config.y_size, is_lstm=False) + else: + self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False) + + + self.decoder = DecoderRNN(input_dropout_p=config.dropout, + rnn_cell=config.dec_rnn_cell, + input_size=config.embed_size, + hidden_size=config.dec_cell_size, + num_layers=config.num_layers, + output_dropout_p=config.dropout, + bidirectional=False, + vocab_size=self.vocab_size, + use_attn=config.dec_use_attn, + ctx_cell_size=config.dec_cell_size, + attn_mode=config.dec_attn_mode, + sys_id=self.bos_id, + eos_id=self.eos_id, + use_gpu=config.use_gpu, + max_dec_len=config.max_dec_len, + embedding=self.embedding) + + self.nll = NLLEntropy(self.pad_id, config.avg_type) + if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty": + req_tokens = [] + for d in REQ_TOKENS.keys(): + req_tokens.extend(REQ_TOKENS[d]) + nll_weight = Variable(th.FloatTensor([10. if token in req_tokens else 1. 
for token in self.vocab])) + print("req tokens assigned with special weights") + if config.use_gpu: + nll_weight = nll_weight.cuda() + self.nll.set_weight(nll_weight) + + if "state_for_decoding" not in config: + self.state_for_decoding = False + else: + self.state_for_decoding = config.state_for_decoding + + + self.gauss_kl = NormKLLoss(unit_average=True) + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + def valid_loss(self, loss, batch_cnt=None): + if self.simple_posterior: + total_loss = loss.nll + if self.config.use_pr > 0.0: + total_loss += self.config.beta * loss.pi_kl + else: + total_loss = loss.nll + loss.pi_kl + + if self.config.use_mi: + total_loss += (loss.b_pr * self.beta) + + if self.config.use_diversity: + total_loss += loss.diversity + + if "match_z" in self.config and self.config.match_z: + total_loss += loss.z_mse + + return total_loss + + def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + short_target_utts = self.np2var(data_feed['outputs'], LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1) + # aux_enc_last = aux_utt_summary.squeeze(1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + # p_mu, p_logvar = self.aux_c2z(aux_enc_last) + p_mu, p_logvar = self.c2z(aux_enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + aux_sample_z = self.gauss_connector(p_mu, p_logvar) + else: + p_mu, p_logvar = self.c2z(enc_last) + # encode response and use posterior to find q(z|x, c) + x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1)) + if self.contextual_posterior: + q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1)) + else: + q_mu, q_logvar = self.xc2z(x_h.squeeze(1)) + + aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last) + + # use prior at inference time, otherwise use posterior + if mode == GEN or use_py: + sample_z = self.gauss_connector(p_mu, p_logvar) + else: + sample_z = self.gauss_connector(q_mu, q_logvar) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = 
self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar) + z_mse = F.mse_loss(aux_sample_z, sample_z) + result['pi_kl'] = pi_kl + result['z_mse'] = z_mse + # result['nll'] = self.nll(dec_outputs, labels) + return result + + def forward_rl(self, data_feed, max_words, temp=0.1): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + # create decoder initial states + p_mu, p_logvar = self.c2z(enc_last) + + # sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach() + sample_z = self.gauss_connector(p_mu, p_logvar) + logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z) + joint_logpz = th.sum(logprob_sample_z, dim=1) + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.state_for_decoding: + dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2) + + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + # decode + logprobs, outs = self.decoder.forward_rl(batch_size=batch_size, + dec_init_state=dec_init_state, + attn_context=attn_context, + vocab=self.vocab, + max_words=max_words, + temp=0.1) + return logprobs, outs, joint_logpz, sample_z + + def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False): + ctx_lens = data_feed['context_lens'] # (batch_size, ) + short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG) + out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) + bs_label = self.np2var(data_feed['bs'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + db_label = self.np2var(data_feed['db'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + # act_label = self.np2var(data_feed['act'], FLOAT) # (batch_size, max_ctx_len, max_utt_len) + batch_size = len(ctx_lens) + + utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1)) + + # get decoder inputs + dec_inputs = out_utts[:, :-1] + labels = out_utts[:, 1:].contiguous() + + # create decoder initial states + enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1) + + # create decoder initial states + if self.simple_posterior: + q_mu, q_logvar = self.c2z(enc_last) + sample_z = self.gauss_connector(q_mu, q_logvar) + p_mu, p_logvar = self.zero, self.zero + + # pack attention context + dec_init_state = self.z_embedding(sample_z.unsqueeze(0)) + attn_context = None + + # decode + if self.config.dec_rnn_cell == 'lstm': + dec_init_state = tuple([dec_init_state, dec_init_state]) + + dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size, + dec_inputs=dec_inputs, + dec_init_state=dec_init_state, # tuple: (h, c) + attn_context=attn_context, + mode=mode, + gen_type=gen_type, + beam_size=self.config.beam_size) # (batch_size, goal_nhid) + if mode == GEN: + ret_dict['sample_z'] = sample_z + ret_dict['q_mu'] = q_mu + ret_dict['q_logvar'] = q_logvar + return ret_dict, labels + + else: + result = Pack(nll=self.nll(dec_outputs, labels)) + pi_kl = 
self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            return result
+
+    def get_z_via_vae(self, responses):
+        batch_size = responses.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1))
+
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+        aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+
+        return aux_sample_z, aux_q_mu, aux_q_logvar
+
+    def get_z_via_rg(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        q_mu, q_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        return sample_z, q_mu, q_logvar
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        dec_init_state = self.z_embedding(sample_y.unsqueeze(0))
+        # guard against NaNs coming from the latent
+        if (dec_init_state != dec_init_state).any():
+            pdb.set_trace()
+        attn_context = None
+
+        # decode
+        # note: unlike forward_rl, state_for_decoding is not applied here
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+
+        return logprobs, outs
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2 * np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu - sample_z), 2) / (2.0 * var)
+        return logprob
+
+    def pad_to(self, max_len, tokens, do_pad):
+        # truncate to max_len (keeping the final token) or right-pad with 0s
+        if len(tokens) >= max_len:
+            return tokens[: max_len - 1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
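+
+# Quick numeric sanity check for the closed-form diagonal-Gaussian log-density
+# used by gaussian_logprob above (a sketch added for illustration; assumes
+# torch.distributions is available in this torch version):
+if __name__ == '__main__':
+    import torch.distributions as D
+    mu, logvar = th.randn(2, 3), th.randn(2, 3)
+    z = th.randn(2, 3)
+    closed_form = float(-0.5 * np.log(2 * np.pi)) - 0.5 * logvar - th.pow(mu - z, 2) / (2.0 * th.exp(logvar))
+    reference = D.Normal(mu, th.exp(0.5 * logvar)).log_prob(z)
+    print((closed_form - reference).abs().max())  # ~0
diff --git a/latent_dialog/nn_lib.py b/latent_dialog/nn_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20d21760b3606e8d7aeaa78381458fc2f2b0883
--- /dev/null
+++ b/latent_dialog/nn_lib.py
@@ -0,0 +1,225 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as 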
F
+import torch.optim as optim
+from torch.autograd import Variable
+from latent_dialog.utils import cast_type, FLOAT
+
+
+class IdentityConnector(nn.Module):
+    def __init__(self):
+        super(IdentityConnector, self).__init__()
+
+    def forward(self, hidden_state):
+        return hidden_state
+
+
+class Bi2UniConnector(nn.Module):
+    def __init__(self, rnn_cell, num_layer, hidden_size, output_size):
+        super(Bi2UniConnector, self).__init__()
+        if rnn_cell == 'lstm':
+            self.fch = nn.Linear(hidden_size*2*num_layer, output_size)
+            self.fcc = nn.Linear(hidden_size*2*num_layer, output_size)
+        else:
+            self.fc = nn.Linear(hidden_size*2*num_layer, output_size)
+
+        self.rnn_cell = rnn_cell
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+
+    def forward(self, hidden_state):
+        """
+        :param hidden_state: [num_layer, batch_size, feat_size]
+        :param inputs: [batch_size, feat_size]
+        :return:
+        """
+        if self.rnn_cell == 'lstm':
+            h, c = hidden_state
+            num_layer = h.size()[0]
+            flat_h = h.transpose(0, 1).contiguous()
+            flat_c = c.transpose(0, 1).contiguous()
+            new_h = self.fch(flat_h.view(-1, self.hidden_size*num_layer))
+            # the cell state gets its own projection layer
+            new_c = self.fcc(flat_c.view(-1, self.hidden_size*num_layer))
+            return (new_h.view(1, -1, self.output_size),
+                    new_c.view(1, -1, self.output_size))
+        else:
+            # FIXME fatal error here!
+            num_layer = hidden_state.size()[0]
+            new_s = self.fc(hidden_state.view(-1, self.hidden_size*num_layer))
+            new_s = new_s.view(1, -1, self.output_size)
+            return new_s
+
+
+class Hidden2Gaussian(nn.Module):
+    def __init__(self, input_size, output_size, is_lstm=False, has_bias=True):
+        super(Hidden2Gaussian, self).__init__()
+        if is_lstm:
+            self.mu_h = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar_h = nn.Linear(input_size, output_size, bias=has_bias)
+
+            self.mu_c = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar_c = nn.Linear(input_size, output_size, bias=has_bias)
+        else:
+            self.mu = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar = nn.Linear(input_size, output_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.is_lstm:
+            h, c = inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+
+            mu_h, mu_c = self.mu_h(h), self.mu_c(c)
+            logvar_h, logvar_c = self.logvar_h(h), self.logvar_c(c)
+            return mu_h+mu_c, logvar_h+logvar_c
+        else:
+            mu = self.mu(inputs)
+            logvar = self.logvar(inputs)
+            return mu, logvar
+
+
+class Hidden2Discrete(nn.Module):
+    def __init__(self, input_size, y_size, k_size, is_lstm=False, has_bias=True):
+        super(Hidden2Discrete, self).__init__()
+        self.y_size = y_size
+        self.k_size = k_size
+        latent_size = self.k_size*self.y_size
+        if is_lstm:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+            self.p_c = nn.Linear(input_size, latent_size, bias=has_bias)
+        else:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.is_lstm:
+            h, c = inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+            logits = self.p_h(h) + self.p_c(c)
+        else:
+            logits = self.p_h(inputs)
+        logits = logits.view(-1, self.k_size)
+        log_qy = F.log_softmax(logits, dim=1)
+        return logits, log_qy
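+
+# Shape sketch for Hidden2Discrete (illustrative; the sizes below are example
+# values, not defaults from this repo):
+#   logits, log_qy = Hidden2Discrete(64, 10, 20)(th.randn(16, 64))
+#   logits.size()  -> (16 * 10, 20): one k-way categorical per latent variable
+#   log_qy.size()  -> (16 * 10, 20): log-softmax over each categorical
+
+
+class Hidden2DiscretewDropout(nn.Module):
+    def __init__(self, input_size, y_size, k_size, is_lstm=False, has_bias=True, p_dropout=0.1, 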
dropout_on_eval=False):
+        super(Hidden2DiscretewDropout, self).__init__()
+        self.y_size = y_size
+        self.k_size = k_size
+        latent_size = self.k_size*self.y_size
+        self.dropout_on_eval = dropout_on_eval
+        self.p_dropout = p_dropout
+        if not dropout_on_eval:
+            self.dropout = nn.Dropout(p_dropout)
+        if is_lstm:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+            self.p_c = nn.Linear(input_size, latent_size, bias=has_bias)
+        else:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.dropout_on_eval:
+            # a freshly constructed module is always in train mode, so this
+            # keeps dropout active even when the surrounding model is in eval mode
+            drop = nn.Dropout(self.p_dropout)
+        else:
+            drop = self.dropout
+        if self.is_lstm:
+            h, c = inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+            logits = drop(self.p_h(h)) + drop(self.p_c(c))
+        else:
+            logits = drop(self.p_h(inputs))
+        logits = logits.view(-1, self.k_size)
+        log_qy = F.log_softmax(logits, dim=1)
+        return logits, log_qy
+
+
+class GaussianConnector(nn.Module):
+    def __init__(self, use_gpu):
+        super(GaussianConnector, self).__init__()
+        self.use_gpu = use_gpu
+
+    def forward(self, mu, logvar):
+        """
+        Sample from a multivariate Gaussian distribution with a diagonal
+        covariance matrix using the reparametrization trick.
+        TODO: this would be better as an instance method in a Gaussian class.
+        :param mu: a tensor of size [batch_size, variable_dim]. Batch_size can be None to support dynamic batching
+        :param logvar: a tensor of size [batch_size, variable_dim]. Batch_size can be None.
+        :return:
+        """
+        epsilon = th.randn(logvar.size())
+        epsilon = cast_type(Variable(epsilon), FLOAT, self.use_gpu)
+        std = th.exp(0.5 * logvar)
+        z = mu + std * epsilon
+        return z
+
+
+class GumbelConnector(nn.Module):
+    def __init__(self, use_gpu):
+        super(GumbelConnector, self).__init__()
+        self.use_gpu = use_gpu
+
+    def sample_gumbel(self, logits, use_gpu, eps=1e-20):
+        u = th.rand(logits.size())
+        sample = Variable(-th.log(-th.log(u + eps) + eps))
+        sample = cast_type(sample, FLOAT, use_gpu)
+        return sample
+
+    def gumbel_softmax_sample(self, logits, temperature, use_gpu):
+        """ Draw a sample from the Gumbel-Softmax distribution"""
+        eps = self.sample_gumbel(logits, use_gpu)
+        y = logits + eps
+        return F.softmax(y / temperature, dim=y.dim()-1)
+
+    def forward(self, logits, temperature=1.0, hard=False,
+                return_max_id=False):
+        """
+        :param logits: [batch_size, n_class] unnormalized log-prob
+        :param temperature: non-negative scalar
+        :param hard: if True take argmax
+        :param return_max_id
+        :return: [batch_size, n_class] sample from gumbel softmax
+        """
+        y = self.gumbel_softmax_sample(logits, temperature, self.use_gpu)
+        _, y_hard = th.max(y, dim=1, keepdim=True)
+        if hard:
+            y_onehot = cast_type(Variable(th.zeros(y.size())), FLOAT, self.use_gpu)
+            y_onehot.scatter_(1, y_hard, 1.0)
+            y = y_onehot
+        if return_max_id:
+            return y, y_hard
+        else:
+            return y
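+
+# Minimal CPU-only smoke test for the two connectors (a sketch added for
+# illustration; not part of the training pipeline):
+if __name__ == '__main__':
+    mu, logvar = th.zeros(4, 3), th.zeros(4, 3)
+    z = GaussianConnector(use_gpu=False)(mu, logvar)
+    print(z.size())  # (4, 3): z = mu + exp(0.5*logvar) * eps
+
+    logits = th.randn(4, 5)
+    y = GumbelConnector(use_gpu=False)(logits, temperature=1.0, hard=True)
+    print(y.sum(dim=1))  # each row is one-hot, so the sums are all 1
diff --git a/latent_dialog/normalizer/__init__.py b/latent_dialog/normalizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc b/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a89295ad2294f7cd75fb4b3d46cb68e167d6985b
Binary files /dev/null and b/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc differ
diff --git 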
a/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc b/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eff08c2b9522413e90ddfd5120f0bbd5389e298 Binary files /dev/null and b/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc differ diff --git a/latent_dialog/normalizer/delexicalize.py b/latent_dialog/normalizer/delexicalize.py new file mode 100644 index 0000000000000000000000000000000000000000..df4636b24d0083bffb9b6d1303dc05d447720970 --- /dev/null +++ b/latent_dialog/normalizer/delexicalize.py @@ -0,0 +1,283 @@ +import re +import os +# import simplejson as json +import json + + +digitpat = re.compile('\d+') +timepat = re.compile("\d{1,2}[:]\d{1,2}") +pricepat = re.compile("\d{1,3}[.]\d{1,2}") + +CUR_PATH = os.path.join(os.path.dirname(__file__)) +fin = open(os.path.join(CUR_PATH, 'mapping.pair'), 'r') +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + +# FORMAT +# domain_value +# restaurant_postcode +# restaurant_address +# taxi_car8 +# taxi_number +# train_id etc.. + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + + +def normalize(text): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\":\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... 
etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+    for fromx, tox in replacements:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # concatenate numbers
+    tmp = text
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
+
+
+def prepareSlotValuesIndependent():
+    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
+    requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+    dic = []
+    dic_area = []
+    dic_food = []
+    dic_price = []
+
+    # read databases
+    for domain in domains:
+        try:
+            fin = open(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + domain + '_db.json'))
+            db_json = json.load(fin)
+            fin.close()
+
+            for ent in db_json:
+                for key, val in ent.items():
+                    if val == '?' or val == 'free':
+                        pass
+                    elif key == 'address':
+                        dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        if "road" in val:
+                            val = val.replace("road", "rd")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "rd" in val:
+                            val = val.replace("rd", "road")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "st" in val:
+                            val = val.replace("st", "street")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "street" in val:
+                            val = val.replace("street", "st")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                    elif key == 'name':
+                        dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        if "b & b" in val:
+                            val = val.replace("b & b", "bed and breakfast")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "bed and breakfast" in val:
+                            val = val.replace("bed and breakfast", "b & b")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "hotel" in val and 'gonville' not in val:
+                            val = val.replace("hotel", "")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "restaurant" in val:
+                            val = val.replace("restaurant", "")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                    elif key == 'postcode':
+                        dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']'))
+                    elif key == 'phone':
+                        dic.append((val, '[' + domain + '_' + 'phone' + ']'))
+                    elif key == 'trainID':
+                        dic.append((normalize(val), '[' + domain + '_' + 'id' + ']'))
+                    elif key == 'department':
+                        dic.append((normalize(val), '[' + domain + '_' + 'department' + ']'))
+
+                    # NORMAL DELEX
+                    elif key == 'area':
+                        dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']'))
+                    elif key == 'food':
+                        dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']'))
+                    elif key == 'pricerange':
+                        dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']'))
+                    else:
+                        pass
+                        # TODO car type?
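+            # a domain whose DB file is missing or unreadable is skipped silently here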
+ except: + pass + + if domain == 'hospital': + dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']')) + + elif domain == 'police': + dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']')) + + # add at the end places from trains + fin = open(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + 'train' + '_db.json')) + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if key == 'departure' or key == 'destination': + dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']')) + + # add specific values: + for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']: + dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']')) + + # more general values add at the end + dic.extend(dic_area) + dic.extend(dic_food) + dic.extend(dic_price) + + return dic + + +def delexicalise(utt, dictionary): + for key, val in dictionary: + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? 
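+    # the slice strips the sentinel spaces added above; padding both the keys and the
+    # utterance with spaces restricts replacement to whole-token matches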
+ + return utt + + +def delexicaliseReferenceNumber(sent, metadata): + """Based on the belief state, we can find reference number that + during data gathering was created randomly.""" + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'] # , 'police'] + if metadata: + for domain in domains: + if metadata[domain]['book']['booked']: + for slot in metadata[domain]['book']['booked'][0]: + if slot == 'reference': + val = '[' + domain + '_' + slot + ']' + else: + val = '[' + domain + '_' + slot + ']' + key = normalize(metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with hashtag + key = normalize("#" + metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with ref# + key = normalize("ref#" + metadata[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + return sent + + +def delexicalse_num(sent): + # changes to numbers only here + digitpat = re.compile('\d+') + sent = re.sub(digitpat, '[value_count]', sent) + return sent + + +def e2e_delecalise(utt, dictionary, metadata): + utt = normalize(utt) + utt = delexicalise(utt, dictionary) + utt = delexicaliseReferenceNumber(utt, metadata) + return delexicalse_num(utt) + + +if __name__ == '__main__': + prepareSlotValuesIndependent() diff --git a/latent_dialog/normalizer/mapping.pair b/latent_dialog/normalizer/mapping.pair new file mode 100644 index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2 --- /dev/null +++ b/latent_dialog/normalizer/mapping.pair @@ -0,0 +1,83 @@ +it's it is +don't do not +doesn't does not +didn't did not +you'd you would +you're you are +you'll you will +i'm i am +they're they are +that's that is +what's what is +couldn't could not +i've i have +we've we have +can't cannot +i'd i would +i'd i would +aren't are not +isn't is not +wasn't was not +weren't were not +won't will not +there's there is +there're there are +. . . +restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/latent_dialog/offlinerl_utils.py b/latent_dialog/offlinerl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de121dd23145faac8230dd52db96528a5934d217 --- /dev/null +++ b/latent_dialog/offlinerl_utils.py @@ -0,0 +1,652 @@ +#! 
/usr/bin/env python +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# +# Copyright © 2021 lubis <lubis@hilbert50> +# +# Distributed under terms of the MIT license. + +""" + +""" + +import torch as th +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import numpy as np +import pdb +import random +import copy +from collections import namedtuple, deque +from torch.autograd import Variable +from latent_dialog.enc2dec.encoders import RnnUttEncoder +from latent_dialog.utils import get_detokenize, cast_type, extract_short_ctx, np2var, LONG, FLOAT +from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS +import dill +import time + +class Actor(nn.Module): + def __init__(self, model, corpus, config): + super(Actor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bos_id = self.vocab_dict[BOS] + self.eos_id = self.vocab_dict[EOS] + self.pad_id = self.vocab_dict[PAD] + self.config = config + + self.use_gpu = config.use_gpu + + self.embedding = None + self.is_stochastic = config.is_stochastic + self.y_size = config.y_size + if 'k_size' in config: + self.k_size = config.k_size + self.is_gauss = False + else: + self.max_action = config.max_action if "max_action" in config else None + self.is_gauss = True + + self.utt_encoder = copy.deepcopy(model.utt_encoder) + self.c2z = copy.deepcopy(model.c2z) + if not self.is_gauss: + self.gumbel_connector = copy.deepcopy(model.gumbel_connector) + else: + self.gauss_connector = copy.deepcopy(model.gauss_connector) + self.gaussian_logprob = model.gaussian_logprob + self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu) + + + def forward(self, data_feed, hard=False): + short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1)) + # create decoder initial states + enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1) + + if self.is_gauss: + q_mu, q_logvar = self.c2z(enc_last) + if self.is_stochastic: + sample_z = self.gauss_connector(q_mu, q_logvar) + else: + sample_z = q_mu + logprob_sample_z = self.gaussian_logprob(q_mu, q_logvar, sample_z) + else: + logits_qy, log_qy = self.c2z(enc_last) + qy = F.softmax(logits_qy / 1.0, dim=1) # (batch_size, vocab_size, ) + log_qy = F.log_softmax(logits_qy, dim=1) # (batch_size, vocab_size, ) + + if self.is_stochastic: + idx = th.multinomial(qy, 1).detach() + soft_z = self.gumbel_connector(logits_qy, hard=False) + else: + idx = th.argmax(th.exp(log_qy), dim=1, keepdim=True) + soft_z = th.exp(log_qy) + sample_z = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu) + sample_z.scatter_(1, idx, 1.0) + logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size) + + joint_logpz = th.sum(logprob_sample_z, dim=1) + return joint_logpz, sample_z + +class DeterministicGaussianActor(nn.Module): + def __init__(self, model, corpus, config): + super(DeterministicGaussianActor, self).__init__() + self.vocab = corpus.vocab + self.vocab_dict = corpus.vocab_dict + self.vocab_size = len(self.vocab) + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.bos_id = self.vocab_dict[BOS] + self.eos_id = 
self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.config = config
+
+        self.use_gpu = config.use_gpu
+
+        self.embedding = None
+        self.y_size = config.y_size
+        self.max_action = config.max_action if "max_action" in config else None
+        self.is_gauss = True
+
+        self.utt_encoder = copy.deepcopy(model.utt_encoder)
+        self.policy = copy.deepcopy(model.c2z)
+
+    def forward(self, data_feed):
+        short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        mu, logvar = self.policy(enc_last)
+        z = mu
+        if self.max_action is not None:
+            z = self.max_action * th.tanh(z)
+
+        return z, mu, logvar
+
+class StochasticGaussianActor(nn.Module):
+    def __init__(self, model, corpus, config):
+        super(StochasticGaussianActor, self).__init__()
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.config = config
+
+        self.use_gpu = config.use_gpu
+
+        self.embedding = None
+        self.y_size = config.y_size
+        self.max_action = config.max_action if "max_action" in config else None
+        self.is_gauss = True
+
+        self.utt_encoder = copy.deepcopy(model.utt_encoder)
+        self.policy = copy.deepcopy(model.c2z)
+        self.gauss_connector = copy.deepcopy(model.gauss_connector)
+
+    def forward(self, data_feed, n_z=1):
+        short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        q_mu, q_logvar = self.policy(enc_last)
+        if n_z > 1:
+            z = [self.gauss_connector(q_mu, q_logvar) for _ in range(n_z)]
+        else:
+            z = self.gauss_connector(q_mu, q_logvar)
+
+        return z, q_mu, q_logvar
+
+class RecurrentCritic(nn.Module):
+    def __init__(self, cvae, corpus, config, args):
+        super(RecurrentCritic, self).__init__()
+
+        self.embedding = None
+        self.word_plas = args.word_plas
+        self.state_dim = cvae.utt_encoder.output_size
+        if self.word_plas:
+            self.action_dim = cvae.aux_encoder.output_size
+        else:
+            self.action_dim = config.y_size  # TODO adjust for categorical
+
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim
+        self.goal_to_critic = args.goal_to_critic
+        if self.goal_to_critic:
+            raise NotImplementedError
+
+        self.use_gpu = config.use_gpu
+
+        self.state_encoder = copy.deepcopy(cvae.utt_encoder)
+        if self.word_plas:
+            self.action_encoder = copy.deepcopy(cvae.aux_encoder)
+        else:
+            self.action_encoder = None
+
+        # twin Q-value heads over the same state-action input
+        self.q11 = nn.Linear(self.input_dim, 500)
+        self.q12 = nn.Linear(500, 300)
+        self.q13 = nn.Linear(300, 100)
+        self.q14 = nn.Linear(100, 20)
+        self.q15 = nn.Linear(20, 1)
+
+        self.q21 = nn.Linear(self.input_dim, 500)
+        self.q22 = nn.Linear(500, 300)
+        self.q23 = nn.Linear(300, 100)
+        self.q24 = nn.Linear(100, 20)
+        self.q25 = nn.Linear(20, 1)
+
+    def forward(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1))
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+
+        q1 = self.q11(sa)
+        q1 = F.relu(self.q12(q1))
+        q1 = F.relu(self.q13(q1))
+        q1 = F.relu(self.q14(q1))
+        q1 = self.q15(q1)
+
+        q2 = self.q21(sa)
+        q2 = F.relu(self.q22(q2))
+        q2 = F.relu(self.q23(q2))
+        q2 = F.relu(self.q24(q2))
+        q2 = self.q25(q2)
+
+        return q1, q2
+
+    def q1(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1))
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+
+        q1 = self.q11(sa)
+        q1 = F.relu(self.q12(q1))
+        q1 = F.relu(self.q13(q1))
+        q1 = F.relu(self.q14(q1))
+        q1 = self.q15(q1)
+
+        return q1
+
+class SingleHierarchicalRecurrentCritic(nn.Module):
+    def __init__(self, cvae, corpus, config, args):
+        super(SingleHierarchicalRecurrentCritic, self).__init__()
+
+        self.hidden_size = 500
+        self.args = args
+        if "model_path" in args:
+            _path = args.model_path
+        else:
+            _path = args.sv_model_path
+
+        if "gauss" in _path:
+            self.is_gauss = True
+        else:
+            self.is_gauss = False
+        self.embedding = None
+        self.word_plas = args.word_plas
+        self.state_dim = cvae.utt_encoder.output_size
+        if self.word_plas:
+            try:
+                self.action_dim = cvae.aux_encoder.output_size
+            except:
+                self.action_dim = cvae.utt_encoder.output_size
+        else:
+            if self.is_gauss:
+                self.action_dim = config.y_size
+            else:
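+                # categorical latent: the critic's action input is either the decoder
+                # embedding of z (dec_cell_size) or the flat one-hot vector (y_size * k_size)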
+ if args.embed_z_for_critic: + self.embed_z = True + self.action_dim = config.dec_cell_size + else: + self.embed_z = False + self.action_dim = config.y_size * config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + + self.goal_to_critic = args.goal_to_critic + self.add_goal = args.add_goal + if self.goal_to_critic: + self.goal_size = corpus.goal_size + if self.add_goal == "early": + self.input_dim += self.goal_size + + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + try: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + except: + self.action_encoder = copy.deepcopy(cvae.utt_encoder) + else: + self.action_encoder = None + + self.dialogue_encoder = nn.LSTM( + input_size = self.input_dim, + hidden_size = self.hidden_size, + dropout=0.1 + ) + + if self.add_goal=="late": + self.q11 = nn.Linear(self.hidden_size + self.goal_size, 1) + else: + self.q11 = nn.Linear(self.hidden_size, 1) + self.activation_function = args.critic_actf if "critic_actf" in args else "none" + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False) + else: + self.d = th.nn.Dropout(p=0.0, inplace=False) + + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) # takes about 0.006seconds + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(len(ctx_summary)): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + + #OPTION 1 add goal to encoder for each time step + if self.goal_to_critic and self.add_goal=="early": + sa = th.cat([sa, goals], dim = 1) + + output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1))) + + #OPTION 2 add goal combined with hidden state to predict final score + if self.goal_to_critic and self.add_goal=="late": + output = th.cat([output, goals.unsqueeze(1)], dim = 2) + + q1 = self.q11(output.squeeze(1)) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = th.sigmoid(q1) + elif self.activation_function == "tanh": + q1 = F.tanh(q1) + + return q1 + + def forward_target(self, data_feed, act, corpus_act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + q1s =[] + for i in range(bs_label.shape[0]): + if self.word_plas: + if i > 0: + corpus_resp_summary, _, _ = 
self.action_encoder(corpus_act[:i].unsqueeze(1)) + else: + corpus_resp_summary = th.Tensor([]) + if self.args.use_gpu: + corpus_resp_summary = corpus_resp_summary.cuda() + actor_resp_summary, _, _ = self.action_encoder(act[i].unsqueeze(1)) + sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_resp_summary[:i], actor_resp_summary], dim=0).squeeze(1)], dim=1) + else: + sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_act[:i], act[i].unsqueeze(0)], dim=0)], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'][:i+1], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(i+1): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + + #OPTION 1 add goal to encoder for each time step + if self.goal_to_critic and self.add_goal=="early": + sa = th.cat([sa, goals], dim = 1) + + output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1))) + + #OPTION 2 add goal combined with hidden state to predict final score + if self.goal_to_critic and self.add_goal=="late": + output = th.cat([output, goals.unsqueeze(1)], dim = 2) + + q1 = self.q11(output.squeeze(1)) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = th.sigmoid(q1) + elif self.activation_function == "tanh": + q1 = F.tanh(q1) + + q1s.append(q1[-1]) + + return th.cat(q1s, dim=0).unsqueeze(1) + +class SingleRecurrentCritic(nn.Module): + def __init__(self, cvae, corpus, config, args): + super(SingleRecurrentCritic, self).__init__() + + if "gauss" in args.saved_path: + self.is_gauss = True + else: + self.is_gauss = False + self.embedding = None + self.word_plas = args.word_plas + self.state_dim = cvae.utt_encoder.output_size + if self.word_plas: + self.action_dim = cvae.aux_encoder.output_size + else: + if self.is_gauss: + self.action_dim = config.y_size + else: + if args.embed_z_for_critic: + self.action_dim = config.dec_cell_size + else: + self.action_dim = config.y_size * config.k_size + + self.bs_size = corpus.bs_size + self.db_size = corpus.db_size + self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim + + self.goal_to_critic = args.goal_to_critic + if self.goal_to_critic: + self.goal_size = corpus.goal_size + self.input_dim += self.goal_size + + + self.use_gpu = config.use_gpu + + self.state_encoder = copy.deepcopy(cvae.utt_encoder) + if self.word_plas: + self.action_encoder = copy.deepcopy(cvae.aux_encoder) + else: + self.action_encoder = None + + self.q11 = nn.Linear(self.input_dim, 1) + self.activation_function = args.critic_actf if "critic_actf" in args else "none" + + self.critic_dropout = args.critic_dropout + if self.critic_dropout: + self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False) + else: + self.d = th.nn.Dropout(p=0.0, inplace=False) + + def forward(self, data_feed, act): + ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu) + bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + db_label = np2var(data_feed['db'], FLOAT, self.use_gpu) # (batch_size, max_ctx_len, max_utt_len) + + ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1)) + if self.word_plas: + resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) + sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1) + else: + sa = 
th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1) + + if self.goal_to_critic: + try: + goals = np2var(data_feed['goals'], FLOAT, self.use_gpu) + except KeyError: + goals = [] + for turn_id in range(len(ctx_summary)): + goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)])) + goals = np2var(np.asarray(goals), FLOAT, self.use_gpu) + sa = th.cat([sa, goals], dim = 1) + + q1 = self.q11(self.d(sa)) + + if self.activation_function == "relu": + q1 = F.relu(q1) + elif self.activation_function == "sigmoid": + q1 = th.sigmoid(q1) + + return q1 + +class ReplayBuffer(object): + """ + Buffer to store experiences, to be used in off-policy learning + """ + def __init__(self, config): + + self.batch_size = config.batch_size + self.fix_episode = config.fix_episode + + self.experiences = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "next_action", "done", "Return"]) + self.memory = deque() + self.seed = random.seed(config.random_seed) + + def add(self, states, actions, rewards, next_states, next_actions, dones, Returns): + if self.fix_episode: + self._add_episode(states, actions, rewards, next_states, next_actions, dones, Returns) + else: + for i in range(len(states)): + self._add(states[i], actions[i], rewards[i], next_states[i], next_actions[i], dones[i], Returns[i]) + + def _add(self, state, action, reward, next_state, next_action, done, Return): + e = self.experiences(state, action, reward, next_state, next_action, done, Return) + self.memory.append(e) + + def _add_episode(self, states, actions, rewards, next_states, next_actions, dones, Returns): + ep = [] + for s, a, r, s_, a_, d, R in zip(states, actions, rewards, next_states, next_actions, dones, Returns): + ep.append(self.experiences(s, a, r, s_, a_, d, R)) + self.memory.append(ep) + + + def sample(self): + if self.fix_episode: + return self._sample_episode() + else: + return self._sample() + + def _sample(self): + experiences = random.sample(self.memory, k = self.batch_size) + + states = {} + states['contexts'] = np.asarray([e.state['contexts'] for e in experiences]) + states['bs'] = np.asarray([e.state['bs'] for e in experiences]) + states['db'] = np.asarray([e.state['db'] for e in experiences]) + states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) + states['goals'] = np.asarray([e.state['goals'] for e in experiences]) + + actions = np.asarray([e.action for e in experiences if e is not None]) + rewards = np.asarray([e.reward for e in experiences if e is not None]) + + next_states = {} + next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences]) + next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences]) + next_states['db'] = np.asarray([e.next_state['db'] for e in experiences]) + next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) + next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences]) + + next_actions = np.asarray([e.next_action for e in experiences if e is not None]) + + dones = np.asarray([e.done for e in experiences if e is not None]) + returns = np.asarray([e.Return for e in experiences if e is not None]) + + return (states, actions, rewards, next_states, next_actions, dones, returns) + + def _sample_episode(self): + episodes = random.sample(self.memory, k = 1) + + for experiences in episodes: + states = {} + states['contexts'] = np.asarray([e.state['contexts'] for e in experiences]) + states['bs'] = 
np.asarray([e.state['bs'] for e in experiences]) + states['db'] = np.asarray([e.state['db'] for e in experiences]) + states['keys'] = [e.state['keys'] for e in experiences] + states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) + states['goals'] = np.asarray([e.state['goals'] for e in experiences]) + + actions = np.asarray([e.action for e in experiences if e is not None]) + rewards = np.asarray([e.reward for e in experiences if e is not None]) + + next_states = {} + next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences]) + next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences]) + next_states['db'] = np.asarray([e.next_state['db'] for e in experiences]) + next_states['keys'] = [e.next_state['keys'] for e in experiences] + next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) + next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences]) + + next_actions = np.asarray([e.next_action for e in experiences if e is not None]) + + dones = np.asarray([e.done for e in experiences if e is not None]) + returns = np.asarray([e.Return for e in experiences if e is not None]) + + return (states, actions, rewards, next_states, next_actions, dones, returns) + + def __len__(self): + return len(self.memory) + + def save(self, path): + with open(path, 'wb') as f: + dill.dump(self.memory, f) + + def load(self, path): + with open(path, 'rb') as f: + self.memory = dill.load(f) + + def load_add(self, path): + with open(path, 'rb') as f: + self.memory += dill.load(f) + diff --git a/latent_dialog/record.py b/latent_dialog/record.py new file mode 100644 index 0000000000000000000000000000000000000000..613d34ef78e3c0e7ffcc9977ecfa50bf2c347e21 --- /dev/null +++ b/latent_dialog/record.py @@ -0,0 +1,169 @@ +import numpy as np +from latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN, GEN_VALID +from collections import Counter + + +class UniquenessSentMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + self.all_sents = [] + + def record(self, sen): + self.seen.add(' '.join(sen)) + self.all_sents.append(' '.join(sen)) + + def value(self): + return len(self.seen) + + def top_n(self, n): + return Counter(self.all_sents).most_common(n) + + +class UniquenessWordMetric(object): + """Metric that evaluates the number of unique sentences.""" + def __init__(self): + self.seen = set() + + def record(self, word_list): + self.seen.update(word_list) + + def value(self): + return len(self.seen) + + +def record_task(n_epsd, model, val_data, config, ppl_f, dialog, ctx_gen_eval, rl_f): + record_ppl(n_epsd, model, val_data, config, ppl_f) + record_rl_task(n_epsd, dialog, ctx_gen_eval, rl_f) + + +def record(n_epsd, model, val_data, sv_config, lm_model, ppl_f, dialog, ctx_gen_eval, rl_f): + record_ppl_with_lm(n_epsd, model, val_data, sv_config, lm_model, ppl_f) + record_rl(n_epsd, dialog, ctx_gen_eval, rl_f) + + +def record_ppl_with_lm(n_epsd, model, data, config, lm_model, ppl_f): + model.eval() + loss_list = [] + data.epoch_init(config, shuffle=False, verbose=True) + while True: + batch = data.next_batch() + if batch is None: + break + for i in range(1): + loss = model(batch, mode=TEACH_FORCE, use_py=True) + loss_list.append(loss.nll.item()) + + # USE LM to test generation performance + data.epoch_init(config, shuffle=False, verbose=False) + gen_loss_list = [] + # first generate + while True: + batch = data.next_batch() + 
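+        # next_batch() yields None once the epoch's batches are exhausted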
if batch is None: + break + + outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type) + # move from GPU to CPU + labels = labels.cpu() + pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]] + pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len) + # clean up the pred labels + clean_pred_labels = np.zeros((pred_labels.shape[0], pred_labels.shape[1]+1)) + clean_pred_labels[:, 0] = model.sys_id + for b_id in range(pred_labels.shape[0]): + for t_id in range(pred_labels.shape[1]): + token = pred_labels[b_id, t_id] + clean_pred_labels[b_id, t_id + 1] = token + if token in [model.eos_id] or t_id == pred_labels.shape[1]-1: + break + + pred_out_lens = np.sum(np.sign(clean_pred_labels), axis=1) + max_pred_lens = np.max(pred_out_lens) + clean_pred_labels = clean_pred_labels[:, 0:int(max_pred_lens)] + batch['outputs'] = clean_pred_labels + batch['output_lens'] = pred_out_lens + loss = lm_model(batch, mode=TEACH_FORCE) + gen_loss_list.append(loss.nll.item()) + + avg_loss = np.average(loss_list) + avg_ppl = np.exp(avg_loss) + gen_avg_loss = np.average(gen_loss_list) + gen_avg_ppl = np.exp(gen_avg_loss) + + ppl_f.write('{}\t{}\t{}\n'.format(n_epsd, avg_ppl, gen_avg_ppl)) + ppl_f.flush() + model.train() + + +def record_ppl(n_epsd, model, val_data, config, ppl_f): + model.eval() + loss_list = [] + val_data.epoch_init(config, shuffle=False, verbose=True) + while True: + batch = val_data.next_batch() + if batch is None: + break + loss = model(batch, mode=TEACH_FORCE, use_py=True) + loss_list.append(loss.nll.item()) + aver_loss = np.average(loss_list) + aver_ppl = np.exp(aver_loss) + ppl_f.write('{}\t{}\n'.format(n_epsd, aver_ppl)) + ppl_f.flush() + model.train() + + +def record_rl(n_epsd, dialog, ctx_gen, rl_f): + conv_list = [] + reward_list = [] + agree_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + + for ctxs in ctx_gen.ctxs: + conv, agree, rewards = dialog.run(ctxs) + true_reward = rewards[0] if agree else 0 + reward_list.append(true_reward) + conv_list.append(conv) + agree_list.append(float(agree) if agree is not None else 0.0) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + + # json.dump(conv_list, text_f, indent=4) + aver_reward = np.average(reward_list) + aver_agree = np.average(agree_list) + unique_sent_num = sent_metric.value() + unique_word_num = word_metric.value() + print(sent_metric.top_n(10)) + + rl_f.write('{}\t{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, aver_agree, unique_sent_num, unique_word_num)) + rl_f.flush() + + +def record_rl_task(n_epsd, dialog, goal_gen, rl_f): + conv_list = [] + reward_list = [] + sent_metric = UniquenessSentMetric() + word_metric = UniquenessWordMetric() + print("Begin RL testing") + cnt = 0 + for g_key, goal in goal_gen.iter(1): + cnt += 1 + conv, success = dialog.run(g_key, goal) + true_reward = success + reward_list.append(true_reward) + conv_list.append(conv) + for turn in conv: + if turn[0] == 'Elder': + sent_metric.record(turn[1]) + word_metric.record(turn[1]) + + # json.dump(conv_list, text_f, indent=4) + aver_reward = np.average(reward_list) + unique_sent_num = sent_metric.value() + unique_word_num = word_metric.value() + rl_f.write('{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, unique_sent_num, unique_word_num)) + rl_f.flush() + print("End RL testing") \ No newline at end of file diff --git a/latent_dialog/utils.py b/latent_dialog/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..6244f87538e4399b70d497dd4f18c28b878c0b91
--- /dev/null
+++ b/latent_dialog/utils.py
@@ -0,0 +1,140 @@
+import os
+import numpy as np
+import random
+import torch as th
+from torch.autograd import Variable
+from nltk import RegexpTokenizer
+from nltk.tokenize.treebank import TreebankWordDetokenizer
+import logging
+import sys
+from collections import defaultdict
+
+INT = 0
+LONG = 1
+FLOAT = 2
+
+
+class Pack(dict):
+    def __getattr__(self, name):
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def add(self, **kwargs):
+        for k, v in kwargs.items():
+            self[k] = v
+
+    def copy(self):
+        pack = Pack()
+        for k, v in self.items():
+            if type(v) is list:
+                pack[k] = list(v)
+            else:
+                pack[k] = v
+        return pack
+
+    @staticmethod
+    def msg_from_dict(dictionary, tokenize, speaker2id, bos_id, eos_id, include_domain=False):
+        pack = Pack()
+        for k, v in dictionary.items():
+            pack[k] = v
+        pack['speaker'] = speaker2id[pack.speaker]
+        pack['conf'] = dictionary.get('conf', 1.0)
+        utt = pack['utt']
+        if 'QUERY' in utt or "RET" in utt:
+            utt = str(utt)
+            # utt = utt.translate(None, ''.join([':', '"', "{", "}", "]", "["]))
+            utt = utt.translate(str.maketrans('', '', ''.join([':', '"', "{", "}", "]", "["])))
+            utt = str(utt)
+        if include_domain:
+            pack['utt'] = [bos_id, pack['speaker'], pack['domain']] + tokenize(utt) + [eos_id]
+        else:
+            pack['utt'] = [bos_id, pack['speaker']] + tokenize(utt) + [eos_id]
+        return pack
+
+def get_tokenize():
+    return RegexpTokenizer(r'\w+|#\w+|<\w+>|%\w+|[^\w\s]+').tokenize
+
+def get_detokenize():
+    return lambda x: TreebankWordDetokenizer().detokenize(x)
+
+def cast_type(var, dtype, use_gpu):
+    if use_gpu:
+        if dtype == INT:
+            var = var.type(th.cuda.IntTensor)
+        elif dtype == LONG:
+            var = var.type(th.cuda.LongTensor)
+        elif dtype == FLOAT:
+            var = var.type(th.cuda.FloatTensor)
+        else:
+            raise ValueError('Unknown dtype')
+    else:
+        if dtype == INT:
+            var = var.type(th.IntTensor)
+        elif dtype == LONG:
+            var = var.type(th.LongTensor)
+        elif dtype == FLOAT:
+            var = var.type(th.FloatTensor)
+        else:
+            raise ValueError('Unknown dtype')
+    return var
+
+def read_lines(file_name):
+    """Reads all the lines from the file."""
+    assert os.path.exists(file_name), 'file does not exist: %s' % file_name
+    lines = []
+    with open(file_name, 'r') as f:
+        for line in f:
+            lines.append(line.strip())
+    return lines
+
+def set_seed(seed):
+    """Sets random seed everywhere."""
+    random.seed(seed)
+    np.random.seed(seed)
+    th.manual_seed(seed)
+    if th.cuda.is_available():
+        th.cuda.manual_seed(seed)
+    th.backends.cudnn.enabled = False
+    th.backends.cudnn.benchmark = False
+    th.backends.cudnn.deterministic = True
+
+def prepare_dirs_loggers(config, script=""):
+    logFormatter = logging.Formatter("%(message)s")
+    rootLogger = logging.getLogger()
+    rootLogger.setLevel(logging.DEBUG)
+
+    consoleHandler = logging.StreamHandler(sys.stdout)
+    consoleHandler.setLevel(logging.DEBUG)
+    consoleHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(consoleHandler)
+
+    # if hasattr(config, 'forward_only') and config.forward_only:
+    if 'forward_only' in config and config.forward_only:
+        return
+
+    fileHandler = logging.FileHandler(os.path.join(config.saved_path,'session.log'))
+    fileHandler.setLevel(logging.DEBUG)
+    fileHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(fileHandler)
+
+def get_chat_tokenize():
+    return RegexpTokenizer(r'\w+|<sil>|[^\w\s]+').tokenize
+
+class missingdict(defaultdict):
+    def __missing__(self, key):
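+        # unlike a plain defaultdict, the missing key is not inserted into the dict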
return self.default_factory() + +def extract_short_ctx(context, context_lens, backward_size=1): + utts = [] + for b_id in range(context.shape[0]): + utts.append(context[b_id, context_lens[b_id]-1]) + return np.array(utts) + +def np2var(inputs, dtype, use_gpu): + if inputs is None: + return None + return cast_type(Variable(th.from_numpy(inputs)), + dtype, + use_gpu) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4ea15340c736d5766f80318768080feffabfa68 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +dill==0.3.4 +matplotlib==3.5.1 +nltk==3.5 +numpy==1.19.5 +requests==2.28.1 +scikit_learn==1.1.2 +scipy==1.9.1 +tabulate==0.8.9 +torch==1.8.2+cu111 +tqdm==4.57.0
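Two usage sketches follow for orientation; they are not part of the diff. First, a minimal example of how the delexicalization helpers in latent_dialog/normalizer/delexicalize.py compose end to end, mirroring e2e_delecalise; the sample utterance is illustrative, and it assumes the unpacked data/norm-multi-woz DB files sit on the relative path the module expects:

    from latent_dialog.normalizer.delexicalize import (
        normalize, delexicalise, delexicaliseReferenceNumber,
        delexicalse_num, prepareSlotValuesIndependent)

    dic = prepareSlotValuesIndependent()          # (value, placeholder) pairs built from the domain DB files
    utt = "I want to book a table for 2 people at 19:30 ."
    utt = normalize(utt)                          # lowercasing, regex cleanup, [value_time] / [value_price]
    utt = delexicalise(utt, dic)                  # DB values -> [domain_slot] placeholders
    utt = delexicaliseReferenceNumber(utt, None)  # booking refs from the belief state (None metadata: no-op)
    print(delexicalse_num(utt))                   # remaining digits -> [value_count]

Second, a sketch of the episodic replay buffer in latent_dialog/offlinerl_utils.py; the config values mirror the critic.py defaults, and the memory path is illustrative (it follows the "-ep.dill" naming used there):

    from latent_dialog.utils import Pack
    from latent_dialog.offlinerl_utils import ReplayBuffer

    config = Pack(batch_size=16, fix_episode=True, random_seed=0)
    buffer = ReplayBuffer(config)
    buffer.load('data/norm-multi-woz/train_dials-ep.dill')  # pre-built episode memory
    # with fix_episode=True, sample() returns one whole dialogue as
    # (states, actions, rewards, next_states, next_actions, dones, returns)
    states, actions, rewards, next_states, next_actions, dones, returns = buffer.sample()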