diff --git a/data.zip b/data.zip
new file mode 100644
index 0000000000000000000000000000000000000000..a0d264ee2b4380d22755f4b05ff749d56b56687e
Binary files /dev/null and b/data.zip differ
diff --git a/experiments_woz/__init__.py b/experiments_woz/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b38317c87521b9439cf7fa52ab450c5f7b77a513
--- /dev/null
+++ b/experiments_woz/__init__.py
@@ -0,0 +1,2 @@
+# @Time    : 10/18/18 1:49 PM
+# @Author  : Tiancheng Zhao
\ No newline at end of file
diff --git a/experiments_woz/__pycache__/__init__.cpython-36.pyc b/experiments_woz/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1d79281e874bb19e4c5f22fe4b733796b45130d
Binary files /dev/null and b/experiments_woz/__pycache__/__init__.cpython-36.pyc differ
diff --git a/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc b/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2d25cfe50e02ef84e0673d40529caddbf430334
Binary files /dev/null and b/experiments_woz/__pycache__/dialog_utils.cpython-36.pyc differ
diff --git a/experiments_woz/critic.py b/experiments_woz/critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0438bea2aa59cd13a35ef79f8e9e874e58bcc8c0
--- /dev/null
+++ b/experiments_woz/critic.py
@@ -0,0 +1,262 @@
+import time
+import os
+import sys
+import random
+sys.path.append('../')
+import json
+import torch as th
+import pdb
+from tqdm import trange
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+from latent_dialog.corpora import NormMultiWozCorpus
+from latent_dialog.models_task import *
+from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
+from latent_dialog.main import OfflineCritic
+from latent_dialog.evaluators import MultiWozEvaluator
+from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic
+
+
+def main(seed, pretrained_folder, pretrained_model_id):
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[START]', start_time, '='*30)
+    # RL configuration
+    env = 'gpu'
+
+    exp_dir = os.path.join('sys_config_log_model', pretrained_folder, "critic-"+start_time)
+    if "rl" in pretrained_folder:
+        join_fmt = ".model"
+        config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json")
+    else:
+        join_fmt = "-model"
+        config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json")
+
+    # create exp folder
+    if not os.path.exists(exp_dir):
+        os.mkdir(exp_dir)
+
+    critic_config = Pack(
+        config_path = config_path,
+        model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}'.format(pretrained_model_id, join_fmt)),
+        critic_config_path = os.path.join(exp_dir, 'critic_config.json'),
+        critic_model_path = os.path.join(exp_dir, 'critic_model'),
+        saved_path = exp_dir,
+        record_path = exp_dir,
+        record_freq = 500,
+        sv_train_freq= 0,  # TODO pay attention to main.py, cuz it is also controlled there
+        use_gpu = env == 'gpu',
+        nepoch = 1,
+        nepisode = 10000,
+        word_plas=True,
+        fix_episode=True,
+        train_with_full_data=True,
+        reward_type="default", #default, turnPenalty, or infoGain
+        infoGain_threshhold = 0.2, # only if reward_type is infoGain
+        goal_to_critic=True,
+        add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
+        critic_kl_loss=False,
+        critic_kl_alpha=0.1,
+        critic_dropout=True, # use with regularize_critic=False
+        critic_dropout_rate=0.3, # only if dropout is true
+        critic_dropout_agg="min", #avg or min
+        critic_sample=False,
+        critic_transformer=False,
+        critic_actf="sigmoid", #relu or sigmoid or tanh or none
+        embed_z_for_critic=True, #for categorical action only, when word_plas=False
+        critic_loss="mse", #mse or huber
+        critic_rl_lr = 0.01,
+        decay_critic_lr=False,
+        train_vae=False,
+        train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
+        train_vae_nepisode=0, # nepisode for the rest of VAE training
+        train_vae_nepisode_init=10000, # nepisode for first VAE training
+        weighted_vae_nll=True,
+        rl_lr = 0.01,
+        max_words = 50,
+        temperature=1.0,
+        momentum = 0.0,
+        nesterov = False,
+        gamma = 0.99,
+        tau=0.005,
+        lmbda=0.5,
+        beta=0.001,
+        batch_size=16,
+        rl_clip = 1.0,
+        n_z=10,
+        random_seed = seed,
+        policy_dropout=False,
+        dropout_on_eval=False,
+        fail_info_penalty=False
+    )
+
+    if "plas" in pretrained_folder:
+        critic_config['actor_path'] = critic_config.model_path.replace(".model", ".actor")
+        critic_config['actor_config'] = critic_config.model_path.replace("reward_best.model", "rl_config.json")
+    else:
+        critic_config['actor_path'] = None
+        critic_config['actor_config'] = None
+
+    prepare_dirs_loggers(critic_config)
+
+    # list config keys that are being compared for tensorboard naming
+    tb_keys = ["critic_rl_lr", "reward_type", "word_plas", "train_with_full_data"]
+    tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys])
+
+    # load previous supervised learning configuration and corpus
+    config = Pack(json.load(open(critic_config.config_path)))
+    config['dropout'] = 0.0
+    config['use_gpu'] = critic_config.use_gpu
+    config['policy_dropout'] = critic_config.policy_dropout
+    config['dropout_on_eval'] = critic_config.dropout_on_eval
+    # assert config.train_path == critic_config.train_path
+
+    # set random seed
+    if critic_config.random_seed is None:
+        try:
+            critic_config.random_seed = config.seed
+        except:
+            critic_config.random_seed = config.random_seed
+    set_seed(critic_config.random_seed)
+
+
+    try:
+        corpus = NormMultiWozCorpus(config)
+    except FileNotFoundError:
+        config['train_path'] = config.train_path.replace("/home/lubis", "")
+        config['valid_path'] = config.valid_path.replace("/home/lubis", "")
+        config['test_path'] = config.test_path.replace("/home/lubis", "")
+        corpus = NormMultiWozCorpus(config)
+    except PermissionError:
+        config['train_path'] = config.train_path.replace("/root", "..")
+        config['valid_path'] = config.valid_path.replace("/root", "..")
+        config['test_path'] = config.test_path.replace("/root", "..")
+        corpus = NormMultiWozCorpus(config)
+
+    critic_config['train_path'] = config['train_path']
+    critic_config['valid_path'] = config['valid_path']
+    critic_config['test_path'] = config['test_path']
+    
+    critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill")
+    critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill")
+    critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill")
+
+    if critic_config.fix_episode:
+        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill")
+
+    
+    critic_config['y_size'] = config['y_size']
+
+
+    # save configuration
+    with open(critic_config.critic_config_path, 'w') as f:
+        json.dump(critic_config, f, indent=4)
+
+    if "rl" in pretrained_folder:
+        if "gauss" in pretrained_folder:
+            if "plas" in pretrained_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+        else:
+            sys_model = SysPerfectBD2Cat(corpus, config)
+    else:
+        if "actz" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysActZGauss(corpus, config)
+            else:
+                sys_model = SysActZCat(corpus, config)
+        elif "mt" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysMTCat(corpus, config)
+        else:
+            if "gauss" in pretrained_folder:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Cat(corpus, config)
+
+    if config.use_gpu:
+        sys_model.cuda()
+
+    mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
+    sys_model.load_state_dict(mt_model_dict)
+
+    sys_model.eval()
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    if critic_config.word_plas:
+        agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+    else:
+        agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+
+    main = OfflineCritic(agent, corpus, config, critic_config, task_generate_critic, name=tensorboard_name, vae_gen=task_generate)
+    # save sys model
+    # th.save(sys_model.state_dict(), critic_config.rl_model_path)
+
+    # initialize train buffer
+    if os.path.isfile(critic_config.train_memory_path):
+    # if False:
+        print("Loading replay buffer for training from {}".format(critic_config.train_memory_path))
+        agent.train_buffer.load(critic_config.train_memory_path)
+        print(len(agent.train_buffer))
+        if "train_with_full_data" in critic_config and critic_config.train_with_full_data:
+            print(f"adding buffer {critic_config.valid_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.valid_memory_path)
+            print(len(agent.train_buffer))
+            print(f"adding buffer {critic_config.test_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.test_memory_path)
+            print(len(agent.train_buffer))
+
+    else:
+        print("Extracting experiences from training data")
+        main.extract(main.train_data, main.agent.train_buffer)
+        print("Saving experiences to {}".format(critic_config.train_memory_path))
+        agent.train_buffer.save(critic_config.train_memory_path)
+
+    # initialize valid buffer
+    if os.path.isfile(critic_config.valid_memory_path):
+        print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.load(critic_config.valid_memory_path)
+    else:
+        print("Extracting experiences from valid data")
+        main.extract(main.val_data, main.agent.valid_buffer)
+        print("Saving experiences to {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.save(critic_config.valid_memory_path)
+    
+    # initialize test buffer
+    if os.path.isfile(critic_config.test_memory_path):
+        print("Loading replay buffer for test from {}".format(critic_config.test_memory_path))
+        agent.test_buffer.load(critic_config.test_memory_path)
+    else:
+        print("Extracting experiences from test data")
+        main.extract(main.test_data, main.agent.test_buffer)
+        print("Saving experiences to {}".format(critic_config.test_memory_path))
+        agent.test_buffer.save(critic_config.test_memory_path)
+
+    if critic_config.use_gpu:
+        agent.critic.cuda()
+        agent.critic_target.cuda()
+
+    # with open(exp_dir + "/test_performance_start.txt", "w") as f:
+        # task_run_critic(main.test_data, agent, None, evaluator=evaluator, outfile=f)
+
+    #train critic
+    main.run()
+ 
+    # with open(exp_dir + "/test_performance_end.txt", "w") as f:
+        # task_run_critic(main.test_data, agent, None, evaluator=evaluator, outfile=f)
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '='*30)
+
+
+if __name__ == '__main__' :
+    # fixed target actor, i.e. model that the critic wants to evaluate
+    # for non-LAVA based model, use critic_json.py
+    folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48"
+    id_ = "reward_best"
+    main(None, folder, id_)
+
+
diff --git a/experiments_woz/critic_json.py b/experiments_woz/critic_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca94e953065649b6cb302a8c21cf5f3561237714
--- /dev/null
+++ b/experiments_woz/critic_json.py
@@ -0,0 +1,305 @@
+import time
+import os
+import sys
+import random
+sys.path.append('../')
+import json
+import torch as th
+import pdb
+from tqdm import trange
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+from latent_dialog.corpora import NormMultiWozCorpus
+from latent_dialog.models_task import *
+from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
+from latent_dialog.main import OfflineCritic
+from latent_dialog.evaluators import MultiWozEvaluator
+from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE
+from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic
+from argparse import ArgumentParser
+
+
+def main(seed, pretrained_folder, pretrained_model_id, response_path):
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[START]', start_time, '='*30)
+    # RL configuration
+    env = 'gpu'
+
+    exp_dir = os.path.join('sys_config_log_model', "/".join(response_path.split("/")[-2:-1]).replace(".json", ""), "critic-"+start_time)
+
+    if "rl" in pretrained_folder:
+        join_fmt = "."
+        config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json")
+    else:
+        join_fmt = "-"
+        config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json")
+
+    # create exp folder
+    if not os.path.exists(exp_dir):
+        os.makedirs(exp_dir)
+
+    critic_config = Pack(
+        config_path = config_path,
+        model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic
+        vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False
+        vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model",
+        actor_path = None,
+        critic_config_path = os.path.join(exp_dir, 'critic_config.json'),
+        critic_model_path = os.path.join(exp_dir, 'critic_model'),
+        saved_path = exp_dir,
+        # ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'),
+        # reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'),
+        record_path = exp_dir,
+        record_freq = 500,
+        sv_train_freq= 0,  # TODO pay attention to main.py, cuz it is also controlled there
+        use_gpu = env == 'gpu',
+        nepoch = 1,
+        nepisode = 1000,
+        word_plas=True,
+        raw_response=True,
+        response_path=response_path,
+        fix_episode=True,
+        train_with_pseudotraj=False,
+        train_with_full_data=False,
+        reward_type="default", #default, turnPenalty, or infoGain
+        infoGain_threshhold = 0.2, # only if reward_type is infoGain
+        add_match_to_reward=False,
+        soft_success=False,
+        goal_to_critic=True,
+        add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
+        critic_kl_loss=False,
+        critic_kl_alpha=0.1,
+        critic_dropout=True, # use with regularize_critic=False
+        critic_dropout_rate=0.3, # only if dropout is true
+        critic_dropout_agg="min", #avg or min
+        critic_sample=False,
+        critic_transformer=False,
+        critic_actf="sigmoid", #relu or sigmoid or tanh or none
+        embed_z_for_critic=True, #for categorical action only, when word_plas=False
+        # critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default
+        critic_loss="mse", #mse or huber
+        critic_rl_lr = 0.01,
+        decay_critic_lr=False,
+        train_vae=False,
+        train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
+        train_vae_nepisode=0, # nepisode for the rest of VAE training
+        train_vae_nepisode_init=10000, # nepisode for first VAE training
+        weighted_vae_nll=True,
+        rl_lr = 0.01,
+        max_words = 50,
+        temperature=1.0,
+        momentum = 0.0,
+        nesterov = False,
+        gamma = 0.99,
+        tau=0.005,
+        lmbda=0.5,
+        beta=0.001,
+        batch_size=16,
+        rl_clip = 1.0,
+        n_z=10,
+        random_seed = seed,
+        policy_dropout=False,
+        dropout_on_eval=False,
+        fail_info_penalty=False
+    )
+
+    prepare_dirs_loggers(critic_config)
+
+    # list config keys that are being compared for tensorboard naming
+    tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"]
+    tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys])
+
+    # load previous supervised learning configuration and corpus
+    config = Pack(json.load(open(critic_config.config_path)))
+    config['dropout'] = 0.0
+    config['use_gpu'] = critic_config.use_gpu
+    config['policy_dropout'] = critic_config.policy_dropout
+    config['dropout_on_eval'] = critic_config.dropout_on_eval
+    # assert config.train_path == critic_config.train_path
+
+    # set random seed
+    if critic_config.random_seed is None:
+        try:
+            critic_config.random_seed = config.seed
+        except:
+            critic_config.random_seed = config.random_seed
+    set_seed(critic_config.random_seed)
+
+
+    try:
+        corpus = NormMultiWozCorpus(config)
+    except FileNotFoundError:
+        config['train_path'] = config.train_path.replace("/home/lubis", "")
+        config['valid_path'] = config.valid_path.replace("/home/lubis", "")
+        config['test_path'] = config.test_path.replace("/home/lubis", "")
+        corpus = NormMultiWozCorpus(config)
+
+    critic_config['train_path'] = config['train_path']
+    critic_config['valid_path'] = config['valid_path']
+    critic_config['test_path'] = config['test_path']
+    
+    if critic_config.reward_type == "default":
+        critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill")
+        critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill")
+        critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill")
+    else:
+        critic_config['train_memory_path'] = config['train_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")
+        critic_config['valid_memory_path'] = config['valid_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")
+        critic_config['test_memory_path'] = config['test_path'].replace(".json", f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill")
+
+    if critic_config.fix_episode:
+        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill")
+
+    if critic_config.soft_success:
+        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-soft.dill")
+        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-soft.dill")
+        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-soft.dill")
+
+    if critic_config.add_match_to_reward:
+        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-wMatch.dill")
+        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-wMatch.dill")
+        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-wMatch.dill")
+
+
+    critic_config['y_size'] = config['y_size']
+
+
+    # save configuration
+    with open(critic_config.critic_config_path, 'w') as f:
+        json.dump(critic_config, f, indent=4)
+
+    if "rl" in pretrained_folder:
+        if "gauss" in pretrained_folder:
+            sys_model = SysPerfectBD2Gauss(corpus, config)
+        else:
+            sys_model = SysPerfectBD2Cat(corpus, config)
+    else:
+        if "actz" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysActZGauss(corpus, config)
+            else:
+                sys_model = SysActZCat(corpus, config)
+        elif "mt" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysMTCat(corpus, config)
+        else:
+            if "gauss" in pretrained_folder:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Cat(corpus, config)
+
+    vae_config = Pack(json.load(open(critic_config.vae_config_path)))
+    if critic_config.raw_response and not critic_config.word_plas:
+        if "gauss" in critic_config.vae_model_path:
+            vae_model = SysAEGauss(corpus, vae_config)
+        else:
+            vae_model = SysAECat(corpus, vae_config)
+        vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage)
+        vae_model.load_state_dict(vae_model_dict)
+    else:
+        vae_model = None
+
+    if config.use_gpu:
+        sys_model.cuda()
+        if vae_model is not None:
+            vae_model.cuda()
+        
+
+    mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
+    sys_model.load_state_dict(mt_model_dict)
+
+    sys_model.eval()
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    if critic_config.word_plas:
+        agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+    else:
+        agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model)
+
+    main = OfflineCritic(agent, corpus, config, critic_config, task_run_critic, name=tensorboard_name, vae_gen=task_generate)
+    # save sys model
+    # th.save(sys_model.state_dict(), critic_config.rl_model_path)
+
+    # initialize train buffer
+    if os.path.isfile(critic_config.train_memory_path):
+        print("Loading replay buffer for training from {}".format(critic_config.train_memory_path))
+        agent.train_buffer.load(critic_config.train_memory_path)
+        print(len(agent.train_buffer))
+        if "train_with_full_data" in critic_config and critic_config.train_with_full_data:
+            print(f"adding buffer {critic_config.valid_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.valid_memory_path)
+            print(len(agent.train_buffer))
+            print(f"adding buffer {critic_config.test_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.test_memory_path)
+            print(len(agent.train_buffer))
+
+    else:
+        print("Extracting experiences from training data")
+        main.extract(main.train_data, main.agent.train_buffer)
+        print("Saving experiences to {}".format(critic_config.train_memory_path))
+        agent.train_buffer.save(critic_config.train_memory_path)
+
+    # initialize valid buffer
+    if os.path.isfile(critic_config.valid_memory_path):
+        print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.load(critic_config.valid_memory_path)
+    else:
+        print("Extracting experiences from valid data")
+        main.extract(main.val_data, main.agent.valid_buffer)
+        print("Saving experiences to {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.save(critic_config.valid_memory_path)
+    
+    # initialize test buffer
+    if os.path.isfile(critic_config.test_memory_path):
+        print("Loading replay buffer for test from {}".format(critic_config.test_memory_path))
+        agent.test_buffer.load(critic_config.test_memory_path)
+    else:
+        print("Extracting experiences from test data")
+        main.extract(main.test_data, main.agent.test_buffer)
+        print("Saving experiences to {}".format(critic_config.test_memory_path))
+        agent.test_buffer.save(critic_config.test_memory_path)
+
+    if critic_config.use_gpu:
+        agent.critic.cuda()
+        agent.critic_target.cuda()
+
+    #check system performance
+    train_dial, val_dial, test_dial = corpus.get_corpus()
+    train_data = BeliefDbDataLoaders('Train', train_dial, vae_config)
+    val_data = BeliefDbDataLoaders('Val', val_dial, vae_config)
+    test_data = BeliefDbDataLoaders('Test', test_dial, vae_config)
+
+    with open(exp_dir + "/test_performance_start.txt", "w") as f:
+        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)
+    #train critic
+    main.run()
+ 
+    with open(exp_dir + "/test_performance_end.txt", "w") as f:
+        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '='*30)
+
+
+if __name__ == '__main__' :
+    parser = ArgumentParser()
+
+    parser.add_argument("--infile", type=str, default="../data/augpt/test-predictions.json")
+    args = parser.parse_args()
+
+    # pick corresponding encoder
+    if "hdsa" in args.infile or "HDSA" in args.infile:
+        # MWOZ 2.0
+        folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48"
+        id_ = "reward_best"
+    elif "augpt" in args.infile:
+        #MWOZ 2.1
+        folder = "2021-11-25-11-52-47-mt_gauss" 
+        id_ = "29"
+
+    main(None, folder, id_, args.infile)
+
+
diff --git a/experiments_woz/critic_mwoz.py b/experiments_woz/critic_mwoz.py
new file mode 100644
index 0000000000000000000000000000000000000000..87bda4e117d93747e914a1d316fb5559b22ccd94
--- /dev/null
+++ b/experiments_woz/critic_mwoz.py
@@ -0,0 +1,272 @@
+import time
+import os
+import sys
+import random
+sys.path.append('../')
+import json
+import torch as th
+import pdb
+from tqdm import trange
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+from latent_dialog.corpora import NormMultiWozCorpus
+from latent_dialog.models_task import *
+from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
+from latent_dialog.main import OfflineCritic
+from latent_dialog.evaluators import MultiWozEvaluator
+from experiments_woz.dialog_utils import task_generate_critic, task_run_critic_on_behavior_policy
+
+
+def main(seed, pretrained_folder, pretrained_model_id):
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[START]', start_time, '='*30)
+    # RL configuration
+    env = 'gpu'
+
+    exp_dir = os.path.join('sys_config_log_model', "/".join(["mwoz", "critic-"+start_time]))
+    if "rl" in pretrained_folder:
+        join_fmt = "."
+        config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json")
+    else:
+        join_fmt = "-"
+        config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json")
+
+    # create exp folder
+    if not os.path.exists(exp_dir):
+        os.makedirs(exp_dir)
+
+    critic_config = Pack(
+        config_path = config_path,
+        model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic
+        vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False
+        vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model",
+        actor_path = None,
+
+        critic_config_path = os.path.join(exp_dir, 'critic_config.json'),
+        critic_model_path = os.path.join(exp_dir, 'critic_model'),
+        saved_path = exp_dir,
+
+        # ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'),
+        # reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'),
+        record_path = exp_dir,
+        record_freq = 500,
+        sv_train_freq= 0,  # TODO pay attention to main.py, cuz it is also controlled there
+        use_gpu = env == 'gpu',
+        nepoch = 10,
+        nepisode = 10000,
+        word_plas=True,
+        raw_response=False,
+        corpus_response=True,
+        # response_path=response_path,
+        fix_episode=True,
+        train_with_pseudotraj=False,
+        train_with_full_data=True,
+        reward_type="default", #default, turnPenalty, or infoGain
+        infoGain_threshhold = 0.2, # only if reward_type is infoGain
+        add_match_to_reward=False,
+        soft_success=False,
+        goal_to_critic=True,
+        add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
+        critic_kl_loss=False,
+        critic_kl_alpha=0.1,
+        critic_dropout=True, # use with regularize_critic=False
+        critic_dropout_rate=0.3, # only if dropout is true
+        critic_dropout_agg="min", #avg or min
+        critic_sample=False,
+        critic_transformer=False,
+        critic_actf="sigmoid", #relu or sigmoid or tanh or none
+        embed_z_for_critic=True, #for categorical action only, when word_plas=False
+        # critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default
+        critic_loss="mse", #mse or huber
+        critic_rl_lr = 0.01,
+        decay_critic_lr=False,
+        train_vae=False,
+        train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
+        train_vae_nepisode=0, # nepisode for the rest of VAE training
+        train_vae_nepisode_init=10000, # nepisode for first VAE training
+        weighted_vae_nll=True,
+        rl_lr = 0.01,
+        max_words = 50,
+        temperature=1.0,
+        momentum = 0.0,
+        nesterov = False,
+        gamma = 0.99,
+        tau=0.005,
+        lmbda=0.5,
+        beta=0.001,
+        batch_size=16,
+        rl_clip = 1.0,
+        n_z=10,
+        random_seed = seed,
+        policy_dropout=False,
+        dropout_on_eval=False,
+        fail_info_penalty=False
+    )
+
+    prepare_dirs_loggers(critic_config)
+
+    # list config keys that are being compared for tensorboard naming
+    tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"]
+    tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys])
+
+    # load previous supervised learning configuration and corpus
+    config = Pack(json.load(open(critic_config.config_path)))
+    config['dropout'] = 0.0
+    config['use_gpu'] = critic_config.use_gpu
+    config['policy_dropout'] = critic_config.policy_dropout
+    config['dropout_on_eval'] = critic_config.dropout_on_eval
+    # assert config.train_path == critic_config.train_path
+
+    # set random seed
+    if critic_config.random_seed is None:
+        try:
+            critic_config.random_seed = config.seed
+        except:
+            critic_config.random_seed = config.random_seed
+    set_seed(critic_config.random_seed)
+
+
+    try:
+        corpus = NormMultiWozCorpus(config)
+    except FileNotFoundError:
+        config['train_path'] = config.train_path.replace("/home/lubis", "")
+        config['valid_path'] = config.valid_path.replace("/home/lubis", "")
+        config['test_path'] = config.test_path.replace("/home/lubis", "")
+        corpus = NormMultiWozCorpus(config)
+
+    critic_config['train_path'] = config['train_path']
+    critic_config['valid_path'] = config['valid_path']
+    critic_config['test_path'] = config['test_path']
+    
+    critic_config['train_memory_path'] = config['train_path'].replace(".json", ".dill")
+    critic_config['valid_memory_path'] = config['valid_path'].replace(".json", ".dill")
+    critic_config['test_memory_path'] = config['test_path'].replace(".json", ".dill")
+
+    if critic_config.fix_episode:
+        critic_config['train_memory_path'] = critic_config['train_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['valid_memory_path'] = critic_config['valid_memory_path'].replace(".dill", "-ep.dill")
+        critic_config['test_memory_path'] = critic_config['test_memory_path'].replace(".dill", "-ep.dill")
+
+    critic_config['y_size'] = config['y_size']
+
+    # save configuration
+    with open(critic_config.critic_config_path, 'w') as f:
+        json.dump(critic_config, f, indent=4)
+
+    if "rl" in pretrained_folder:
+        if "gauss" in pretrained_folder:
+            sys_model = SysPerfectBD2Gauss(corpus, config)
+        else:
+            sys_model = SysPerfectBD2Cat(corpus, config)
+    else:
+        if "actz" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysActZGauss(corpus, config)
+            else:
+                sys_model = SysActZCat(corpus, config)
+        elif "mt" in pretrained_folder:
+            if "gauss" in pretrained_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysMTCat(corpus, config)
+        else:
+            if "gauss" in pretrained_folder:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Cat(corpus, config)
+
+    if critic_config.corpus_response and not critic_config.word_plas:
+        vae_config = Pack(json.load(open(critic_config.vae_config_path)))
+        if "gauss" in critic_config.vae_model_path:
+            vae_model = SysAEGauss(corpus, vae_config)
+        else:
+            vae_model = SysAECat(corpus, vae_config)
+        vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage)
+        vae_model.load_state_dict(vae_model_dict)
+    else:
+        vae_model = None
+
+    if config.use_gpu:
+        sys_model.cuda()
+        if vae_model is not None:
+            vae_model.cuda()
+        
+
+    mt_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
+    sys_model.load_state_dict(mt_model_dict)
+
+    sys_model.eval()
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    if critic_config.word_plas:
+        agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+    else:
+        agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model)
+
+    main = OfflineCritic(agent, corpus, config, critic_config, task_run_critic_on_behavior_policy, name=tensorboard_name, vae_gen=None)
+    # save sys model
+    # th.save(sys_model.state_dict(), critic_config.rl_model_path)
+
+    # initialize train buffer
+    if os.path.isfile(critic_config.train_memory_path):
+        print("Loading replay buffer for training from {}".format(critic_config.train_memory_path))
+        agent.train_buffer.load(critic_config.train_memory_path)
+        if "train_with_full_data" in critic_config and critic_config.train_with_full_data:
+            print(f"adding buffer {critic_config.valid_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.valid_memory_path)
+            print(len(agent.train_buffer))
+            print(f"adding buffer {critic_config.test_memory_path} to training buffer")
+            agent.train_buffer.load_add(critic_config.test_memory_path)
+            print(len(agent.train_buffer))
+
+    else:
+        print("Extracting experiences from training data")
+        main.extract(main.train_data, main.agent.train_buffer)
+        print("Saving experiences to {}".format(critic_config.train_memory_path))
+        agent.train_buffer.save(critic_config.train_memory_path)
+
+    # initialize valid buffer
+    if os.path.isfile(critic_config.valid_memory_path):
+        print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.load(critic_config.valid_memory_path)
+    else:
+        print("Extracting experiences from valid data")
+        main.extract(main.val_data, main.agent.valid_buffer)
+        print("Saving experiences to {}".format(critic_config.valid_memory_path))
+        agent.valid_buffer.save(critic_config.valid_memory_path)
+    
+    # initialize test buffer
+    if os.path.isfile(critic_config.test_memory_path):
+        print("Loading replay buffer for test from {}".format(critic_config.test_memory_path))
+        agent.test_buffer.load(critic_config.test_memory_path)
+    else:
+        print("Extracting experiences from test data")
+        main.extract(main.test_data, main.agent.test_buffer)
+        print("Saving experiences to {}".format(critic_config.test_memory_path))
+        agent.test_buffer.save(critic_config.test_memory_path)
+
+    if critic_config.use_gpu:
+        agent.critic.cuda()
+        agent.critic_target.cuda()
+
+    with open(exp_dir + "/test_performance_start.txt", "w") as f:
+        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)
+
+    #train critic
+    main.run_behavior_policy()
+ 
+    with open(exp_dir + "/test_performance_end.txt", "w") as f:
+        task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '='*30)
+
+
+if __name__ == '__main__' :
+
+    #pretrained encoder for critic
+    folder = "2021-11-25-11-52-47-mt_gauss" 
+    id_ = "29"
+
+    main(None, folder, id_)
+
+
diff --git a/experiments_woz/dialog_utils.py b/experiments_woz/dialog_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af63c856d9f67c86bdd5175d9ed50f0936b5ba2e
--- /dev/null
+++ b/experiments_woz/dialog_utils.py
@@ -0,0 +1,906 @@
+import numpy as np
+import scipy.stats as st
+import torch as th
+from latent_dialog.enc2dec.decoders import GEN, DecoderRNN
+from latent_dialog.main import get_sent
+from latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type
+from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS
+from collections import defaultdict
+import pdb
+import warnings
+
+def mean_confidence_interval(data, confidence=0.95):
+    a = 1.0 * np.array(data)
+    n = len(a)
+    m, se = np.mean(a), st.sem(a)
+    h = se * st.t.ppf((1 + confidence) / 2., n-1)
+    return m, m-h, m+h
+
+def task_generate(model, data, config, evaluator, num_batch, dest_f=None, verbose=True, aux_mt=False):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    model.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+        if aux_mt:
+            #TODO aux rl forward?
+            outputs, labels = model.forward_aux(batch, mode=GEN, gen_type=config.gen_type)
+        else:
+            outputs, labels = model.forward(batch, mode=GEN, gen_type=config.gen_type)
+
+        # move from GPU to CPU
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            # TODO attn
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            if verbose and (num_batch is None or batch_cnt < 2):
+                write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('-' * 40)
+
+    task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+    write(task_report_new)
+    task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1
+
+def task_generate_critic(model, data, config, critic_config, evaluator, num_batch, dest_f=None, verbose=True, critic=None, actor=None, outfile=None):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    model.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    Q = []
+    generated_dialogs = defaultdict(list)
+    cur_key = ""
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+
+        batch_size = len(batch['bs'])
+
+        # with intermediate latent step
+        if not critic_config.word_plas:
+            if critic_config.actor_path is not None:
+                raise NotImplementedError
+            else:
+                if "gauss" in critic_config.model_path:
+                    sample_z, _, _ = model.get_z_via_rg(batch)
+                    corpus_z, _, _ = model.get_z_via_vae(batch['outputs'])
+                else:
+                    sample_z, _, _ = model.get_z_via_rg(batch, hard=True)
+                _, pred_labels = model.decode_z(sample_z, batch_size, batch, critic_config.max_words, critic_config.temperature)
+                if type(pred_labels[0]) == int:
+                    pred_labels = [pred_labels]
+                pred_labels_gpu =  model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in pred_labels]), LONG)
+            # NOTE the above pass result in lower success rate for categorical models
+        else:
+            # with forward pass
+            # outputs, labels = model.forward(batch, mode=GEN, gen_type=config.gen_type)
+            # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+            # pred_labels_gpu = model.np2var(pred_labels, LONG)
+
+            # with forward rl pass
+            if "actor_path" in critic_config and critic_config.actor_path is not None:
+                sample_z, _, _ = actor(batch)
+                _, pred_labels = model.decode_z(sample_z, batch_size, batch, critic_config.max_words, critic_config.temperature)
+            else:
+
+                logprobs, pred_labels, _, _ = model.forward_rl(batch, max_words=critic_config.max_words, temp=critic_config.temperature)
+            if type(pred_labels[0]) == int:
+                pred_labels = [pred_labels]
+            pred_labels_gpu =  model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in pred_labels]), LONG)
+
+
+        true_labels = batch['outputs']
+        true_labels_gpu =  model.np2var(np.asarray([model.pad_to(critic_config.max_words, a, do_pad=True) for a in true_labels.tolist()]), LONG)
+
+
+        if critic is not None:
+            with th.no_grad():
+                if critic_config.word_plas:
+                    # since we are only taking the first Q, we do not need forward_target
+                    Q.append(critic(batch, pred_labels_gpu))
+                    Q_pred = critic.forward_target(batch, pred_labels_gpu.unsqueeze(1), true_labels_gpu)
+                else:
+                    if critic.is_gauss:
+                        Q.append(critic(batch, sample_z)[0])
+                        Q_pred = critic.forward_target(batch, sample_z, corpus_z)
+                    else:
+                        if critic_config.embed_z_for_critic:
+                            cat_action = model.z_embedding(sample_z.view(1, -1, model.y_size * model.k_size)).squeeze(0)
+                        else:
+                            cat_action = sample_z.view(-1, model.y_size * model.k_size)
+
+                        Q.append(critic(batch, cat_action)[0])
+
+
+        # move from GPU to CPU
+        # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        # pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        # true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        if ctx_len is None:
+            pdb.set_trace()
+        keys = batch['keys']
+        if cur_key == "":
+            cur_key = keys[0]
+
+        for b_id in range(len(pred_labels)):
+            if keys[b_id] != cur_key:
+                tmp_gen = defaultdict(list)
+                tmp_gen[cur_key] = generated_dialogs[cur_key]
+                task_report_new, _, _= evaluator.evaluateModel(tmp_gen, mode=data.name, new_version=True)
+                print(task_report_new)
+                cur_key = keys[b_id]
+
+
+            # TODO attn
+            if critic_config.word_plas:
+                pred_str = get_sent(model.vocab, de_tknize, pred_labels_gpu, b_id)
+            else:
+                pred_str = get_sent(model.vocab, de_tknize, np.asarray(pred_labels[b_id]).reshape(1,-1), 0)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            # if verbose and (num_batch is None or batch_cnt < 2):
+            write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+            write('True: {}'.format(true_str, ))
+            write('Pred: {}'.format(pred_str, ))
+            write('Q: {}'.format(Q_pred[b_id]))
+            write('-' * 40)
+
+    task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+    write(task_report_new)
+    task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    if critic is not None:
+
+        data_q = np.asarray(th.cat(Q).cpu()).squeeze()
+        mean = np.mean(data_q)
+        std = np.std(data_q)
+        lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q))
+        if outfile is not None:
+            outfile.write(f"{mean}, {std}, {lower}, {upper}")
+        print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}")
+
+        # average_critic_return = th.mean(th.cat(Q)).item()
+        # variance_critic_return = th.var(th.cat(Q)).item()
+        # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}")
+    else:
+        mean = 0.0
+
+    # write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1, mean
+
+def task_run_critic(data, agent, num_batch, evaluator=None, dest_f=None, verbose=True, f=None, outfile=None):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    agent.cvae.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(agent.cvae.config, shuffle=num_batch is not None, verbose=False, fix_batch=agent.cvae.config.fix_batch)
+    if evaluator is not None:
+        evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    Q = []
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+
+        batch_size = len(batch['bs'])
+        key = batch['keys'][0]
+
+        if key in agent.raw_responses.keys():
+            pred_labels_gpu =  agent.cvae.np2var(np.asarray([agent.cvae.pad_to(agent.critic.args.max_words, a, do_pad=True) for a in agent.raw_responses[key]]), LONG)
+        else:
+            print(key, "skipped")
+            continue
+        
+        true_labels = batch['outputs']
+
+        with th.no_grad():
+            if agent.critic.args.word_plas:
+                # since we are only taking the first Q, we do not need forward_target
+                Q.append(agent.critic(batch, pred_labels_gpu))
+            else:
+                z_t, _, _ = agent.vae.get_z_via_vae(pred_labels_gpu)
+                Q.append(agent.critic(batch, z_t))
+
+        # get context
+        if evaluator is not None:
+            ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+            ctx_len = batch.get('context_lens')  # (batch_size, )
+            if ctx_len is None:
+                pdb.set_trace()
+            keys = batch['keys']
+
+            for b_id in range(len(pred_labels_gpu.cpu())):
+                if agent.args.word_plas:
+                    pred_str = get_sent(agent.cvae.vocab, de_tknize, pred_labels_gpu, b_id)
+                else:
+                    pred_str = get_sent(agent.cvae.vocab, de_tknize, np.asarray(pred_labels_gpu[b_id]).reshape(1,-1), 0)
+                true_str = get_sent(agent.cvae.vocab, de_tknize, true_labels, b_id)
+                prev_ctx = ''
+                if ctx is not None:
+                    ctx_str = []
+                    for t_id in range(ctx_len[b_id]):
+                        temp_str = get_sent(agent.cvae.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                        # print('temp_str = %s' % (temp_str, ))
+                        # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                        ctx_str.append(temp_str)
+                    ctx_str = '|'.join(ctx_str)[-200::]
+                    prev_ctx = 'Source context: {}'.format(ctx_str)
+
+                generated_dialogs[keys[b_id]].append(pred_str)
+                evaluator.add_example(true_str, pred_str)
+
+            # if verbose and (num_batch is None or batch_cnt < 2):
+                # write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                # write('True: {}'.format(true_str, ))
+                # write('Pred: {}'.format(pred_str, ))
+                # write('-' * 40)
+
+
+    if evaluator is not None:
+        task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+        outfile.write(task_report_new)
+        task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+        resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+        print(resp_report)
+        outfile.write(resp_report)
+
+    data_q = np.asarray(th.cat(Q).cpu()).squeeze()
+    mean = np.mean(data_q)
+    std = np.std(data_q)
+    lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q))
+    if outfile is not None:
+        outfile.write(f"\ncritic mean, std, lower_bound, upper_bound\n{mean}, {std}, {lower}, {upper}")
+
+    print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}")
+
+    # average_critic_return = th.mean(th.cat(Q)).item()
+    # variance_critic_return = th.var(th.cat(Q)).item()
+    # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}")
+
+    # write(task_report)
+    # write(resp_report)
+    # write('Generation Done')
+    return mean
+
+def task_run_critic_on_behavior_policy(data, agent, num_batch, dest_f=None, verbose=True, outfile=None):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    agent.cvae.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(agent.cvae.config, shuffle=num_batch is not None, verbose=False, fix_batch=agent.cvae.config.fix_batch)
+    # evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    Q = []
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+
+        batch_size = len(batch['bs'])
+        key = batch['keys'][0]
+        true_labels = agent.cvae.np2var(batch['outputs'], LONG)
+        with th.no_grad():
+            if not agent.critic.args.word_plas:
+                z_t, _, _ = agent.vae.get_z_via_vae(pred_labels_gpu)
+                Q.append(agent.critic(batch, z_t))
+            else:
+                Q.append(agent.critic(batch, true_labels))
+
+    data_q = np.asarray(th.cat(Q).cpu()).squeeze()
+    mean = np.mean(data_q)
+    std = np.std(data_q)
+    lower, upper = st.t.interval(0.95, len(data_q)-1, loc=np.mean(data_q), scale=st.sem(data_q))
+    if outfile is not None:
+        outfile.write(f"{mean}, {std}, {lower}, {upper}")
+
+    print(f"mean_q: {mean}, 95% confidence bound: {std}, lower bound: {lower}, upper bound: {upper}")
+
+
+    # average_critic_return = th.mean(th.cat(Q)).item()
+    # variance_critic_return = th.var(th.cat(Q)).item()
+    # write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}")
+
+    # write(task_report)
+    # write(resp_report)
+    # write('Generation Done')
+    return mean
+
+
+def task_generate_augpt(model, data, config, evaluator, num_batch, dest_f=None, verbose=True):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    model.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    # evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    generated_dialogs = defaultdict(list)
+    responses = []
+    true_responses = []
+    beliefs = []
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+
+        # move from GPU to CPU
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            # TODO attn
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            responses.append(pred_str)
+            true_responses.append(true_str)
+            beliefs.append(batch['raw_bs'][b_id])
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, np.array(ctx[b_id][t_id]).reshape(1, -1), 0, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+                bs_str = get_sent(model.vocab, de_tknize, np.array(batch['bs'][b_id]).reshape(1, -1), 0, stop_eos=False)
+                db_str = get_sent(model.vocab, de_tknize, np.array(batch['db'][b_id]).reshape(1, -1), 0, stop_eos=False)
+                prev_metadata = 'BS: {}, DB: {}'.format(bs_str, db_str)
+
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            if verbose and (num_batch is None or batch_cnt < 2):
+                write(prev_metadata)
+                write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('-' * 40)
+
+    corpus_success, corpus_match, corpus_domain_result, corpus_task_report = evaluator.evaluate(beliefs, true_responses)
+    write('====Corpus Result====')
+    write(corpus_task_report)
+
+    success, match, domain_result, task_report = evaluator.evaluate(beliefs, responses)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    write('====System result per domain====')
+    for d, (m, s) in domain_result.items():
+        write("{}: match {:2.2f}% success {:2.2f}%".format(d, m*100, s*100))
+        
+    write('====System result====')
+    write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1
+
+def task_generate_plas(actor, cvae, data, config, evaluator, num_batch, dest_f=None, verbose=True, critic=None):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    actor.eval()
+    cvae.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    Q = []
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+        try:
+            # _, z = actor(batch)
+            action, _, _ = actor(batch)
+        except:
+            _, action, _, _ = actor(batch)
+            # z = actor(batch)
+        if critic is not None:
+            with th.no_grad():
+                if actor.is_gauss:
+                    Q.append(critic(batch, action)[0])
+                else:
+                    if actor.embed_z_for_critic:
+                        cat_action = actor.z_embedding(action.view(1, -1, actor.y_size * actor.k_size)).squeeze(0)
+                    else:
+                        cat_action = action.view(-1, actor.y_size * actor.k_size)
+                    Q.append(critic(batch, cat_action)[0])
+
+
+        try:
+            logprobs, outputs = cvae.decode_z(action, len(batch['context_lens']), batch, config.max_dec_len)
+        except:
+            pdb.set_trace()
+            logprobs, outputs = cvae.decode_z(action, len(batch['context_lens']), batch, config.max_dec_len)
+        if type(outputs[0]) == int:
+            outputs = [outputs]
+        pred_labels = np.asarray([cvae.pad_to(config.max_dec_len, a, do_pad=True) for a in outputs])
+        true_labels = batch['outputs']
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            # TODO attn
+            pred_str = get_sent(cvae.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(cvae.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(cvae.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            if verbose and (num_batch is None or batch_cnt < 2):
+                write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('-' * 40)
+
+    task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+    write(task_report_new)
+    task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    if critic is not None:
+        average_critic_return = th.mean(th.cat(Q)).item()
+        variance_critic_return = th.var(th.cat(Q)).item()
+        write(f"Q_avg: {average_critic_return}, Q_var: {variance_critic_return}")
+    else:
+        average_critic_return = 0.0
+    write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1, average_critic_return
+
+def task_generate_wSampling(model, data, config, evaluator, num_batch, dest_f=None, verbose=True, n_z=10):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    def is_masked_action(bs_label, db_label, response):
+        """
+        check if the generated response should be masked based on belief state and db result
+        a) inform when there is no db match
+        b) inform no option when there is a match
+        c) out of domain action
+        d) no offer with a particular slot?
+        e) inform/request time on domains other than train and restaurant
+        """
+        for domain, bs_idx, db_idx in zip(DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX):
+            if bs_label[bs_idx] == 0: # if domain is inactive
+                if any([p in response for p in REQ_TOKENS[domain]]): # but a token from that domain is present
+                    print(">> inactive domain {} is mentioned".format(domain))
+                    return True
+            else: # domain is active
+                if any([p in response for p in ["sorry", "no", "not" "cannot"]]): # system inform no offer
+                    if db_idx < 0: # domain has no db
+                        print(">> inform no offer on domain {} without DB".format(domain))
+                        return True
+                    elif  db_label[db_idx] != 1: # there are matches
+                        print(">> inform no offer when there are matches on domain {}".format(domain))
+                        return True
+                    # if "[value_" in response:
+                        # print(">> inform no offer mentioning criteria")
+                        # return True # always only inform "no match for your criteria" w/o mentioning them explicitly
+                elif any([p in response for p in REQ_TOKENS[domain]]) or "i have [value_count]" in response or "there are [value_count]" in response: # if requestable token is present
+                    # TODO also check for i have [value_count] match, not only the requestable tokens
+                    if db_idx >= 0 and int(db_label[db_idx]) == 1: # and domain has a DB to be queried and there are no matches
+                        print(">> inform match when there are no DB match on domain {}".format(domain))
+                        return True
+
+        return False
+
+    model.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+
+        outputs = []
+        for i in range(n_z):
+            output, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+            pred = [t.cpu().data.numpy() for t in output[DecoderRNN.KEY_SEQUENCE]]
+            pred = np.array(pred, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+            outputs.append(pred) #(n_z, batch_size, max_dec_len)
+
+        # move from GPU to CPU
+        labels = labels.cpu()
+        # pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        # pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(true_labels.shape[0]):
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            # TODO attn
+            for i in range(n_z):
+                flag = False
+                pred_labels = np.array(outputs[i], dtype=int)
+                pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+
+                if not is_masked_action(batch['bs'][b_id], batch['db'][b_id], pred_str):
+                    flag = True
+                    break
+                else:
+                    if i == 0:
+                        print("----------")
+                        print(true_str)
+                    print(pred_str)
+            if not flag:
+                warnings.warn("No sampled actions passed the masking test, taking the last sampled action")
+               
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            if verbose and (num_batch is None or batch_cnt < 2):
+                write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('-' * 40)
+
+    task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+    write(task_report_new)
+    task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1
+
+def task_generate_actz(model, data, config, evaluator, num_batch, dest_f=None, enc="utt", verbose=True):
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    model.eval()
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False, fix_batch=config.fix_batch)
+    evaluator.initialize()
+    print('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    generated_dialogs = defaultdict(list)
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+        if enc=="aux":
+            outputs, labels = model.forward_aez(batch, mode=GEN, gen_type=config.gen_type)
+        else:
+            outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+
+        # move from GPU to CPU
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            # TODO attn
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            generated_dialogs[keys[b_id]].append(pred_str)
+            evaluator.add_example(true_str, pred_str)
+
+            if verbose and (num_batch is None or batch_cnt < 2):
+                write('%s-prev_ctx = %s' % (keys[b_id], prev_ctx,))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('-' * 40)
+    
+    task_report_new, _, _= evaluator.evaluateModel(generated_dialogs, mode=data.name, new_version=True)
+    write(task_report_new)
+    
+    task_report, success, match = evaluator.evaluateModel(generated_dialogs, mode=data.name)
+    resp_report, bleu, prec, rec, f1 = evaluator.get_report()
+    write(task_report)
+    write(resp_report)
+    write('Generation Done')
+    return success, match, bleu, f1
+
+def dump_latent(model, data, config):
+    latent_results = defaultdict(list)
+    model.eval()
+    batch_cnt = 0
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True)
+
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None:
+            break
+
+        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+            
+        sample_y = outputs['sample_z'].cpu().data.numpy().reshape(-1, config.y_size, config.k_size)
+        log_qy = outputs['log_qy'].cpu().data.numpy().reshape(-1, config.y_size, config.k_size)
+
+        if config.dec_use_attn:
+            attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]]
+            attns = np.array(attns).squeeze(2).swapaxes(0, 1)
+        else:
+            attns = None
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    ctx_str.append(temp_str)
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            latent_results[keys[b_id]].append({'context': prev_ctx, 'gt_resp': true_str,
+                                               'pred_resp': pred_str, 'domain': batch['goals_list'],
+                                               'sample_y': sample_y[b_id], 'log_qy': log_qy[b_id],
+                                               'attns': attns[b_id] if attns is not None else None})
+    latent_results = dict(latent_results)
+    return latent_results
+
+def dump_latent_gauss(model, data, config):
+    latent_results = defaultdict(list)
+    model.eval()
+    batch_cnt = 0
+    de_tknize = lambda x: ' '.join(x)
+    data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True)
+
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None:
+            break
+
+        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)
+        
+        sample_y = outputs['sample_z'].cpu().data.numpy().reshape(-1, config.y_size)
+        #TODO qmu is not stored in outputs
+        q_mu = outputs['q_mu'].cpu().data.numpy().reshape(-1, config.y_size)
+        q_logvar = outputs['q_logvar'].cpu().data.numpy().reshape(-1, config.y_size)
+
+        if config.dec_use_attn:
+            attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]]
+            attns = np.array(attns).squeeze(2).swapaxes(0, 1)
+        else:
+            attns = None
+
+        # get context
+        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens')  # (batch_size, )
+        keys = batch['keys']
+
+        for b_id in range(pred_labels.shape[0]):
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    ctx_str.append(temp_str)
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            latent_results[keys[b_id]].append({'context': prev_ctx, 'gt_resp': true_str,
+                                               'pred_resp': pred_str, 'domain': batch['goals_list'],
+                                               'sample_y': sample_y[b_id], 'q_mu': q_mu[b_id], 'q_logvar' : q_logvar[b_id],
+                                               'attns': attns[b_id] if attns is not None else None})
+    latent_results = dict(latent_results)
+    return latent_results
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/experiments_woz/mt_gauss.py b/experiments_woz/mt_gauss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c4853a2bcee54310c8732e5e337a183e6b0ad15
--- /dev/null
+++ b/experiments_woz/mt_gauss.py
@@ -0,0 +1,255 @@
+import time
+import os
+import json
+import sys
+from pathlib import Path
+sys.path.append((Path(__file__).parent / '../').resolve().as_posix())
+import torch as th
+import pdb
+import logging
+import random
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+import latent_dialog.corpora as corpora
+from latent_dialog.data_loaders import BeliefDbDataLoadersAE, BeliefDbDataLoaders
+from latent_dialog.evaluators import MultiWozEvaluator
+from latent_dialog.models_task import SysMTGauss
+from latent_dialog.main import train, validate, mt_train
+import latent_dialog.domain as domain
+from experiments_woz.dialog_utils import task_generate
+
+# def main(seed):
+
+def main(seed, pretrained_folder, pretrained_model_id):
+    domain_name = 'object_division'
+    domain_info = domain.get_domain(domain_name)
+
+    if pretrained_folder is not None:
+        ae_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json')
+        ae_config = Pack(json.load(open(ae_config_path)))
+        ae_model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}-model'.format(pretrained_model_id))
+        train_path=ae_config.train_path
+        valid_path=ae_config.valid_path
+        test_path=ae_config.test_path
+
+    else:
+        ae_model_path = None
+        ae_config_path = None
+        train_path='../data/data_2.1/train_dials.json'
+        valid_path='../data/data_2.1/val_dials.json'
+        test_path='../data/data_2.1/test_dials.json'
+
+    if seed is None:
+        seed = ae_config.seed
+ 
+    base_path = Path(__file__).parent
+    config = Pack(
+        seed = seed,
+        ae_model_path=ae_model_path,
+        ae_config_path=ae_config_path,
+        train_path=train_path,
+        valid_path=valid_path,
+        test_path=test_path,
+        dact_path=(base_path / '../data/norm-multi-woz/damd_dialog_acts.json').resolve().as_posix(),
+        ae_zero_pad=True,
+        max_vocab_size=1000,
+        last_n_model=1,
+        max_utt_len=50,
+        max_dec_len=50,
+        backward_size=2,
+        batch_size=128,
+        use_gpu=True,
+        op='adam',
+        init_lr=0.0005,
+        l2_norm=1e-05,
+        momentum=0.0,
+        grad_clip=5.0,
+        dropout=0.5,
+        max_epoch=50,
+        # aux_train_freq=10,
+        # aux_max_epoch=1, # epoch per training, i.e every aux_train_freq, train aux_max_epoch time
+        shared_train=True,
+        embed_size=256,
+        num_layers=1,
+        utt_rnn_cell='gru',
+        utt_cell_size=300,
+        bi_utt_cell=True,
+        enc_use_attn=True,
+        dec_use_attn=False,
+        dec_rnn_cell='lstm',
+        dec_cell_size=300,
+        dec_attn_mode='cat',
+        y_size=200,
+        beta = 0.01,
+        # aux_pi_beta = 0.01, # default to 1.0 if not set
+        simple_posterior=True,
+        contextual_posterior=False,
+        use_mi =False,
+        use_pr =True,
+        use_diversity = False,
+        #
+        beam_size=20,
+        fix_batch = True,
+        fix_train_batch=False,
+        avg_type='word',
+        # avg_type='slot',
+        # slot_weight=10,
+        use_aux_kl=False,
+        selective_fine_tune=False,
+        print_step=300,
+        ckpt_step=1416,
+        improve_threshold=0.996,
+        patient_increase=2.0,
+        save_model=True,
+        early_stop=False,
+        gen_type='greedy',
+        preview_batch_num=50,
+        k=domain_info.input_length(),
+        init_range=0.1,
+        pretrain_folder='2021-11-25-16-47-37-mt_gauss',
+        forward_only=False
+    )
+    set_seed(config.seed)
+
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    stats_path = (base_path / 'sys_config_log_model').resolve().as_posix()
+    if config.forward_only:
+        saved_path = os.path.join(stats_path, config.pretrain_folder)
+        config = Pack(json.load(open(os.path.join(saved_path, 'config.json'))))
+        config['forward_only'] = True
+    else:
+        saved_path = os.path.join(stats_path, start_time+'-'+os.path.basename(__file__).split('.')[0])
+        if not os.path.exists(saved_path):
+            os.makedirs(saved_path)
+    config.saved_path = saved_path
+
+    if ae_config_path is not None:
+        config["max_dec_len"] = ae_config.max_dec_len
+        config["dec_use_attn"] = ae_config.dec_use_attn
+        config["dec_rnn_cell"] = ae_config.dec_rnn_cell
+        config["dec_cell_size"] = ae_config.dec_cell_size
+        config["dec_attn_mode"] = ae_config.dec_attn_mode
+        config["embed_size"] = ae_config.embed_size
+        config["y_size"] = ae_config.y_size
+        # Denoising setup
+        if "noise_type" in ae_config:
+            noise_type = ae_config.noise_type
+        else:
+            noise_type = None
+        if noise_type:
+            config["noise_type"] = ae_config.noise_type
+            config["noise_p"] = ae_config.noise_p
+            config["remove_tokens"] = ae_config.remove_tokens
+            config["no_special"] = ae_config.no_special
+    else:
+        noise_type = None
+        vocab_dict = None
+
+
+    prepare_dirs_loggers(config)
+    logger = logging.getLogger()
+    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))
+    config.saved_path = saved_path
+
+    # save configuration
+    with open(os.path.join(saved_path, 'config.json'), 'w') as f:
+        json.dump(config, f, indent=4)  # sort_keys=True
+
+    # data for AE training
+    aux_corpus = corpora.NormMultiWozCorpusAE(config)
+    vocab_dict = aux_corpus.vocab_dict if noise_type else None
+    aux_train_dial, aux_val_dial, aux_test_dial = aux_corpus.get_corpus()
+
+    aux_train_data = BeliefDbDataLoadersAE('Train', aux_train_dial, config, noise=noise_type, ind_voc=vocab_dict, logger=logger)
+    aux_val_data = BeliefDbDataLoadersAE('Val', aux_val_dial, config)
+    aux_test_data = BeliefDbDataLoadersAE('Test', aux_test_dial, config)
+
+    # data for RG training
+    corpus = corpora.NormMultiWozCorpus(config)
+    train_dial, val_dial, test_dial = corpus.get_corpus()
+
+    train_data = BeliefDbDataLoaders('Train', train_dial, config)
+    val_data = BeliefDbDataLoaders('Val', val_dial, config)
+    test_data = BeliefDbDataLoaders('Test', test_dial, config)
+
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    model = SysMTGauss(corpus, config)
+    model_dict = model.state_dict()
+
+    # load params from saved ae_model
+    if pretrained_folder is not None:
+        ae_model_dict = th.load(config.ae_model_path, map_location=lambda storage, location: storage)
+        tmp_k = [k for k in model_dict.keys() if k not in ae_model_dict.keys()]
+        aux_model_dict = {}
+        for k in tmp_k:
+            aux_model_dict[k] = ae_model_dict[k.replace("aux", "utt")] #utt_encoder param in ae model is now moved to aux_params in a new dict
+        model_dict.update(ae_model_dict) # all params except aux_encoder
+        model_dict.update(aux_model_dict) # aux encoder
+        model.load_state_dict(model_dict)
+
+    if config.use_gpu:
+        model.cuda()
+    
+    # only train utt_encoder
+    if config.ae_config_path is not None and config.selective_fine_tune:
+        for name, param in model.named_parameters():
+            if "utt_encoder" not in name:
+            # if "decoder" in name or "z_embedding" in name:
+                param.requires_grad = False
+
+    best_epoch = None
+    if not config.forward_only:
+        try:
+            best_epoch = mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=task_generate)
+        except KeyboardInterrupt:
+            print('Training stopped by keyboard.')
+    if best_epoch is None:
+        model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path) if 'model' in p and 'rl' not in p and 'aux' not in p])
+        best_epoch = model_ids[-1]
+
+    print("$$$ Load {}-model".format(best_epoch))
+    # config.batch_size = 32
+    model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))
+
+
+    logger.info("Forward Only Evaluation")
+
+    validate(model, val_data, config)
+    validate(model, test_data, config)
+
+    with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, val_data, config, evaluator, num_batch=None, dest_f=f)
+
+    with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f)
+
+    with open(os.path.join(saved_path, '{}_{}_valid_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, aux_val_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)
+     
+    with open(os.path.join(saved_path, '{}_{}_test_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, aux_test_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)
+
+
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '=' * 30)
+
+if __name__ == "__main__":
+    ### Train from scratch ###
+
+    for i in [1, 2, 3, 4]:
+        main(i, None, None)
+    
+
+    ### Load pre-trained AE ###
+    # multiwoz 2.0
+    # pretrained = {"2020-02-28-16-49-48-sl_gauss_ae":100}
+
+    # multiwoz 2.1
+    # pretrained = {"2020-10-15-17-11-59-sl_gauss_ae":92}
+    
+    # for p in pretrained.keys():
+        # folder = p
+        # id_ = pretrained[p]
+        # main(None, folder, id_)
+
diff --git a/experiments_woz/plas_gauss.py b/experiments_woz/plas_gauss.py
new file mode 100644
index 0000000000000000000000000000000000000000..65e8961270d78b64bc63fe6a614846d3c9714411
--- /dev/null
+++ b/experiments_woz/plas_gauss.py
@@ -0,0 +1,217 @@
+import time
+import os
+import sys
+import random
+sys.path.append('../')
+import json
+import torch as th
+import pdb
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+from latent_dialog.corpora import NormMultiWozCorpus
+from latent_dialog.models_task import SysMTGauss, SysActZGauss
+from latent_dialog.agent_task import LatentPLASAgent, PLASAgent
+from latent_dialog.main import OfflinePLAS
+from latent_dialog.evaluators import MultiWozEvaluator
+from experiments_woz.dialog_utils import task_generate_plas, task_generate
+
+
+def main(seed, pretrained_folder, pretrained_model_id):
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[START]', start_time, '='*30)
+    # RL configuration
+    env = 'gpu'
+
+    exp_dir = os.path.join('sys_config_log_model', pretrained_folder, 'plas_rl-'+start_time)
+    # create exp folder
+    if not os.path.exists(exp_dir):
+        os.mkdir(exp_dir)
+
+    rl_config = Pack(
+        sv_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json'),
+        sv_model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}-model'.format(pretrained_model_id)),
+        rl_config_path = os.path.join(exp_dir, 'rl_config.json'),
+        rl_model_path = os.path.join(exp_dir, 'rl_model'),
+        saved_path = exp_dir,
+        ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'),
+        reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'),
+        record_path = exp_dir,
+        record_freq = 100,
+        sv_train_freq= 0,  # TODO pay attention to main.py, cuz it is also controlled there
+        use_gpu = env == 'gpu',
+        nepoch = 1,
+        nepisode = 10000,
+        tune_pi_only=False,
+        is_stochastic=False,
+        word_plas=False,
+        z_loss=True, #set to False if word_plas == True
+        actor_rl_lr = 0.005,
+        importance_threshold=10,
+        fix_episode=True,
+        validate_with_critic=True,
+        goal_to_critic=True,
+        add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
+        critic_kl_loss=True,
+        critic_kl_alpha=0.1,
+        critic_dropout=True, # use with regularize_critic=False
+        critic_dropout_rate=0.3, # only if dropout is true
+        critic_dropout_agg="min", #avg or min
+        critic_sample=False,
+        critic_transformer=False,
+        critic_actf="none", #relu or sigmoid or tanh or none
+        critic_loss="mse", #mse or huber
+        critic_rl_lr = 0.01,
+        train_vae=True,
+        train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
+        train_vae_nepisode=0, # nepisode for the rest of VAE training
+        train_vae_nepisode_init=10000, # nepisode for first VAE training
+        weighted_vae_nll=True,
+        vae_beta=0.0, #during pre-training 0.01 or 0.1, no need to change distributions as we only tune the encoder
+        rl_lr = 0.01,
+        max_words = 50,
+        temperature=1.0,
+        momentum = 0.0,
+        nesterov = False,
+        gamma = 0.99,
+        tau=0.005,
+        lmbda=0.5,
+        beta=0.001,
+        batch_size=16,
+        rl_clip = 1.0,
+        n_z=10,
+        random_seed = seed,
+        policy_dropout=False,
+        dropout_on_eval=False,
+        fail_info_penalty=False
+    )
+
+    prepare_dirs_loggers(rl_config)
+
+    # list config keys that are being compared for tensorboard naming
+    tb_keys = ["validate_with_critic", "actor_rl_lr", "critic_rl_lr", "critic_actf"]
+    tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-" + "-".join([f"{k}={rl_config[k]}" for k in tb_keys])
+    # tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-" + "recurrent_critic_forward_target"
+
+    # load previous supervised learning configuration and corpus
+    sv_config = Pack(json.load(open(rl_config.sv_config_path)))
+    sv_config['dropout'] = 0.0
+    sv_config['use_gpu'] = rl_config.use_gpu
+    sv_config['policy_dropout'] = rl_config.policy_dropout
+    sv_config['dropout_on_eval'] = rl_config.dropout_on_eval
+    # assert sv_config.train_path == rl_config.train_path
+
+    # set random seed
+    if rl_config.random_seed is None:
+        rl_config.random_seed = sv_config.seed
+    set_seed(rl_config.random_seed)
+
+
+    try:
+        corpus = NormMultiWozCorpus(sv_config)
+    except FileNotFoundError:
+        sv_config['train_path'] = sv_config.train_path.replace("/home/lubis", "")
+        sv_config['valid_path'] = sv_config.valid_path.replace("/home/lubis", "")
+        sv_config['test_path'] = sv_config.test_path.replace("/home/lubis", "")
+        corpus = NormMultiWozCorpus(sv_config)
+
+    rl_config['train_path'] = sv_config['train_path']
+    rl_config['valid_path'] = sv_config['valid_path']
+    rl_config['test_path'] = sv_config['test_path']
+    
+    rl_config['train_memory_path'] = sv_config['train_path'].replace(".json", ".dill")
+    rl_config['valid_memory_path'] = sv_config['valid_path'].replace(".json", ".dill")
+    rl_config['test_memory_path'] = sv_config['test_path'].replace(".json", ".dill")
+
+    if rl_config.fix_episode:
+        rl_config['train_memory_path'] = rl_config['train_memory_path'].replace(".dill", "-ep.dill")
+        rl_config['valid_memory_path'] = rl_config['valid_memory_path'].replace(".dill", "-ep.dill")
+        rl_config['test_memory_path'] = rl_config['test_memory_path'].replace(".dill", "-ep.dill")
+
+    rl_config['y_size'] = sv_config['y_size']
+
+
+    # save configuration
+    with open(rl_config.rl_config_path, 'w') as f:
+        json.dump(rl_config, f, indent=4)
+
+    if "mt_" in pretrained_folder:
+        sys_model = SysMTGauss(corpus, sv_config)
+    else:
+        sys_model = SysActZGauss(corpus, sv_config)
+    if sv_config.use_gpu:
+        sys_model.cuda()
+
+    mt_model_dict = th.load(rl_config.sv_model_path, map_location=lambda storage, location: storage)
+    sys_model.load_state_dict(mt_model_dict)
+
+    sys_model.eval()
+    evaluator = MultiWozEvaluator('SysWoz', sv_config)
+
+    if rl_config.word_plas:
+        agent = PLASAgent(sys_model, corpus, rl_config, name='System', tune_pi_only=rl_config.tune_pi_only)
+    else:
+        agent = LatentPLASAgent(sys_model, corpus, rl_config, evaluator, name='System', tune_pi_only=rl_config.tune_pi_only)
+
+    plas = OfflinePLAS(agent, corpus, sv_config, rl_config, task_generate_plas, name=tensorboard_name, vae_gen=task_generate)
+    # save sys model
+    # th.save(sys_model.state_dict(), rl_config.rl_model_path)
+
+    # initialize train buffer
+    if os.path.isfile(rl_config.train_memory_path):
+    # if False:
+        print("Loading replay buffer for training from {}".format(rl_config.train_memory_path))
+        plas.agent.train_buffer.load(rl_config.train_memory_path)
+    else:
+        print("Extracting experiences from training data")
+        plas.extract(plas.train_data, plas.agent.train_buffer)
+        print("Saving experiences to {}".format(rl_config.train_memory_path))
+        plas.agent.train_buffer.save(rl_config.train_memory_path)
+
+    # initialize valid buffer
+    if os.path.isfile(rl_config.valid_memory_path):
+        print("Loading replay buffer for validation from {}".format(rl_config.valid_memory_path))
+        plas.agent.valid_buffer.load(rl_config.valid_memory_path)
+    else:
+        print("Extracting experiences from valid data")
+        plas.extract(plas.val_data, plas.agent.valid_buffer)
+        print("Saving experiences to {}".format(rl_config.valid_memory_path))
+        plas.agent.valid_buffer.save(rl_config.valid_memory_path)
+    
+    # initialize test buffer
+    if os.path.isfile(rl_config.test_memory_path):
+        print("Loading replay buffer for test from {}".format(rl_config.test_memory_path))
+        plas.agent.test_buffer.load(rl_config.test_memory_path)
+    else:
+        print("Extracting experiences from test data")
+        plas.extract(plas.test_data, plas.agent.test_buffer)
+        print("Saving experiences to {}".format(rl_config.test_memory_path))
+        plas.agent.test_buffer.save(rl_config.test_memory_path)
+
+    if sv_config.use_gpu:
+        agent.actor.cuda()
+        agent.actor_target.cuda()
+        agent.critic.cuda()
+        agent.critic_target.cuda()
+
+    # start RL
+    plas.run()
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '='*30)
+
+
+if __name__ == '__main__' :
+    pretrained = {}
+    f_in = "sys_config_log_model/tmp.lst"
+
+    with open(f_in, "r") as f:
+        lines = f.readlines()
+        for l in lines:
+            if ";" not in l and "cat" not in l and "rl" not in l:
+                pretrained["/".join(l.split("/")[:-1])] = int(l.split("/")[-1].split("_")[1])
+    
+    for p in pretrained.keys():
+        folder = p
+        id_ = pretrained[p]
+        main(None, folder, id_)
+
+
diff --git a/experiments_woz/run_critic.py b/experiments_woz/run_critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..b113b93713459737036f10003617781b98ae4086
--- /dev/null
+++ b/experiments_woz/run_critic.py
@@ -0,0 +1,135 @@
+import time
+import os
+import sys
+import random
+sys.path.append('../')
+import json
+import torch as th
+import pdb
+from tqdm import trange
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+from latent_dialog.corpora import NormMultiWozCorpus
+from latent_dialog.models_task import *
+from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
+from latent_dialog.main import OfflineCritic
+from latent_dialog.evaluators import MultiWozEvaluator
+from experiments_woz.dialog_utils import task_generate_critic, task_generate
+
+
+def main(pretrained_critic_folder):
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[START]', start_time, '='*30)
+    # RL configuration
+    env = 'gpu'
+
+    exp_dir = os.path.join('sys_config_log_model', "/".join(pretrained_critic_folder.split("/")[:-1]))
+    print(exp_dir)
+    critic_config_path = exp_dir + "/critic_config.json" 
+
+    critic_config = Pack(json.load(open(critic_config_path)))
+    config = Pack(json.load(open(critic_config.config_path)))
+    config['dropout'] = 0.0
+    config['use_gpu'] = critic_config.use_gpu
+    config['policy_dropout'] = critic_config.policy_dropout
+    config['dropout_on_eval'] = critic_config.dropout_on_eval
+
+    # set random seed
+    if critic_config.random_seed is None:
+        try:
+            critic_config.random_seed = config.seed
+        except:
+            critic_config.random_seed = config.random_seed
+    set_seed(critic_config.random_seed)
+
+
+    try:
+        corpus = NormMultiWozCorpus(config)
+    except FileNotFoundError:
+        config['train_path'] = config.train_path.replace("/home/lubis", "")
+        config['valid_path'] = config.valid_path.replace("/home/lubis", "")
+        config['test_path'] = config.test_path.replace("/home/lubis", "")
+        corpus = NormMultiWozCorpus(config)
+    except PermissionError:
+        pdb.set_trace()
+        config['train_path'] = config.train_path.replace("/root", "..")
+        config['valid_path'] = config.valid_path.replace("/root", "..")
+        config['test_path'] = config.test_path.replace("/root", "..")
+        corpus = NormMultiWozCorpus(config)
+
+    if "augpt" in pretrained_critic_folder or "mwoz" in pretrained_critic_folder or "HDSA" in pretrained_critic_folder:
+        pretrained_critic_folder = critic_config.model_path
+
+    if "rl" in pretrained_critic_folder:
+        if "gauss" in pretrained_critic_folder:
+            if "plas" in pretrained_critic_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+        else:
+            sys_model = SysPerfectBD2Cat(corpus, config)
+    else:
+        if "actz" in pretrained_critic_folder:
+            if "gauss" in pretrained_critic_folder:
+                sys_model = SysActZGauss(corpus, config)
+            else:
+                sys_model = SysActZCat(corpus, config)
+        elif "mt" in pretrained_critic_folder:
+            if "gauss" in pretrained_critic_folder:
+                sys_model = SysMTGauss(corpus, config)
+            else:
+                sys_model = SysMTCat(corpus, config)
+        else:
+            if "gauss" in pretrained_critic_folder:
+                sys_model = SysPerfectBD2Gauss(corpus, config)
+            else:
+                sys_model = SysPerfectBD2Cat(corpus, config)
+
+    if config.use_gpu:
+        sys_model.cuda()
+
+    model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
+    sys_model.load_state_dict(model_dict)
+
+    sys_model.eval()
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    if critic_config.word_plas or critic_config.raw_response:
+        agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+    else:
+        agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
+
+
+    agent_critic_dict = th.load(exp_dir + "/critic_model", map_location=lambda storage, location: storage)
+    agent.critic.load_state_dict(agent_critic_dict)
+    if "actor_path" in critic_config and critic_config.actor_path is not None:
+        agent_actor_dict = th.load(critic_config.actor_path, map_location=lambda storage, location: storage)
+        agent.actor.load_state_dict(agent_actor_dict)
+
+    main = OfflineCritic(agent, corpus, config, critic_config, task_generate_critic, name="", vae_gen=task_generate, forward_only=True)
+
+    if critic_config.use_gpu:
+        agent.critic.cuda()
+        agent.critic_target.cuda()
+
+    with open(exp_dir + "/conf_interval.txt", "w") as f:
+        t_success, t_match, t_bleu, t_f1, t_Q = main.generate_func(main.agent.cvae, main.test_data, main.sv_config, critic_config, main.evaluator, None, verbose=False, critic=agent.critic, actor = agent.actor, outfile=f)
+    
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '='*30)
+
+
+if __name__ == '__main__' :
+    pretrained = []
+    # list of test-critic.tsv files of critics to run
+    f_in = "sys_config_log_model/critic_report.lst"
+
+    with open(f_in, "r") as f:
+        lines = f.readlines()
+        for l in lines:
+            if ";" not in l:
+                pretrained.append(l)
+    
+    for p in pretrained:
+        main(p)
+
+
diff --git a/experiments_woz/sl_gauss_ae.py b/experiments_woz/sl_gauss_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..041d0488e2ea4876d78175750e3847c8df715c0a
--- /dev/null
+++ b/experiments_woz/sl_gauss_ae.py
@@ -0,0 +1,165 @@
+import time
+import os
+import sys
+sys.path.append('../')
+import json
+import pdb
+import torch as th
+import logging
+import random
+from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
+import latent_dialog.corpora as corpora
+from latent_dialog.data_loaders import BeliefDbDataLoadersAE
+from latent_dialog.evaluators import MultiWozEvaluator
+from latent_dialog.models_task import SysAEGauss
+from latent_dialog.main import train, validate
+import latent_dialog.domain as domain
+from dialog_utils import task_generate
+import pickle as pkl
+
+# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
+# os.environ["CUDA_VISIBLE_DEVICES"]="1"
+
+def main(seed):
+    domain_name = 'object_division'
+    domain_info = domain.get_domain(domain_name)
+    config = Pack(
+        seed=seed,
+        # train_path='../data/norm-multi-woz/train_dials.json',
+        # valid_path='../data/norm-multi-woz/val_dials.json',
+        # test_path='../data/norm-multi-woz/test_dials.json',
+        train_path='../data/data_2.1/train_dials.json',
+        valid_path='../data/data_2.1/val_dials.json',
+        test_path='../data/data_2.1/test_dials.json',
+        dact_path='../data/norm-multi-woz/damd_dialog_acts.json',
+        max_vocab_size=1000,
+        last_n_model=1,
+        max_utt_len=50,
+        max_dec_len=50,
+        backward_size=2,
+        batch_size=128,
+        use_gpu=True,
+        op='adam',
+        init_lr=0.001,
+        l2_norm=1e-05,
+        momentum=0.0,
+        grad_clip=5.0,
+        dropout=0.5,
+        max_epoch=100,
+        embed_size=256,
+        num_layers=1,
+        utt_rnn_cell='gru',
+        utt_cell_size=300,
+        bi_utt_cell=True,
+        enc_use_attn=True,
+        dec_use_attn=False,
+        dec_rnn_cell='lstm',
+        # must be same as ctx_cell_size due to the passed initial state
+        dec_cell_size=300,
+        # must be same as ctx_cell_size due to the passed initial state
+        dec_attn_mode='cat',
+        y_size=200,
+        beta = 0.01,
+        simple_posterior=True,
+        contextual_posterior=False,
+        use_pr = True,
+        use_metadata=False,
+        ae_zero_padding=True,
+        beam_size=20,
+        fix_batch = True,
+        fix_train_batch=False,
+        avg_type='word',
+        print_step=300,
+        ckpt_step=1416,
+        improve_threshold=0.996,
+        patient_increase=2.0,
+        save_model=True,
+        early_stop=False,
+        gen_type='greedy',
+        preview_batch_num=50,
+        k=domain_info.input_length(),
+        init_range=0.1,
+        pretrain_folder='2021-11-09-13-55-11-sl_gauss_ae',
+        ## denoising
+        # noise_type=None,
+        noise_type="tokens_removal",
+        noise_p=0.1, # Probability to add noise to a response
+        no_special=False, # True - do not allow tokens switching with special tokens
+        remove_tokens=False, # True -remove token completely, False - change it with UNK
+        forward_only=False
+    )
+    set_seed(config.seed)
+    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    stats_path = 'sys_config_log_model'
+    if config.forward_only:
+        saved_path = os.path.join(stats_path, config.pretrain_folder)
+        config = Pack(json.load(open(os.path.join(saved_path, 'config.json'))))
+        config['forward_only'] = True
+    else:
+        saved_path = os.path.join(stats_path, start_time+'-'+os.path.basename(__file__).split('.')[0])
+        if not os.path.exists(saved_path):
+            os.makedirs(saved_path)
+    config.saved_path = saved_path
+
+    prepare_dirs_loggers(config)
+    logger = logging.getLogger()
+    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))
+    config.saved_path = saved_path
+
+    # save configuration
+    with open(os.path.join(saved_path, 'config.json'), 'w') as f:
+        json.dump(config, f, indent=4)  # sort_keys=True
+
+    corpus = corpora.NormMultiWozCorpusAE(config)
+    train_dial, val_dial, test_dial = corpus.get_corpus()
+
+    # Denoising setup
+    noise_type = config.noise_type
+    vocab_dict = corpus.vocab_dict if noise_type else None
+
+    train_data = BeliefDbDataLoadersAE('Train', train_dial, config, noise=noise_type, ind_voc=vocab_dict, logger=logger)
+    val_data = BeliefDbDataLoadersAE('Val', val_dial, config)
+    test_data = BeliefDbDataLoadersAE('Test', test_dial, config)
+
+    config['n_iter'] = config.ckpt_step  * config.max_epoch
+
+    evaluator = MultiWozEvaluator('SysWoz', config)
+
+    model = SysAEGauss(corpus, config)
+
+    if config.use_gpu:
+        model.cuda()
+
+    best_epoch = None
+    if not config.forward_only:
+        try:
+            # best_epoch = train(model, train_data, val_data, test_data, config, evaluator, gen=task_generate)
+            best_epoch = train(model, train_data, val_data, test_data, config, evaluator, gen=None)
+        except KeyboardInterrupt:
+            print('Training stopped by keyboard.')
+    if best_epoch is None:
+        model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path) if '-model' in p])
+        best_epoch = model_ids[-1]
+
+    print("$$$ Load {}-model".format(best_epoch))
+    config.batch_size = 32
+    model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))
+
+    logger.info("Forward Only Evaluation")
+
+    validate(model, val_data, config)
+    validate(model, test_data, config)
+
+    with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, val_data, config, evaluator, num_batch=None, dest_f=f)
+
+    with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f:
+        task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f)
+
+    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
+    print('[END]', end_time, '=' * 30)
+
+if __name__ == "__main__":
+    for i in [5, 23, 42, 72, 112]:
+        main(i)
+
diff --git a/experiments_woz/sys_config_log_model/pretrained.zip b/experiments_woz/sys_config_log_model/pretrained.zip
new file mode 100644
index 0000000000000000000000000000000000000000..939ab34e46f57bf4a39dc3cfd1e17e2e8afd2bf4
Binary files /dev/null and b/experiments_woz/sys_config_log_model/pretrained.zip differ
diff --git a/latent_dialog/__init__.py b/latent_dialog/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cabc5d9c605965db3e69614595ab9bb891814c9
--- /dev/null
+++ b/latent_dialog/__init__.py
@@ -0,0 +1,2 @@
+# @Time    : 10/18/18 1:55 PM
+# @Author  : Tiancheng Zhao
\ No newline at end of file
diff --git a/latent_dialog/__pycache__/__init__.cpython-36.pyc b/latent_dialog/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a40514e24fef0191401e7d2520b656d034a58a24
Binary files /dev/null and b/latent_dialog/__pycache__/__init__.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/agent_task.cpython-36.pyc b/latent_dialog/__pycache__/agent_task.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b870eafe6a1870ae7775a1af813060a3121846a
Binary files /dev/null and b/latent_dialog/__pycache__/agent_task.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc b/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d1a90c26006d81f3c806cffc69fbd96eab587f4
Binary files /dev/null and b/latent_dialog/__pycache__/augpt_utils.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc b/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f4f94c4300597092fb0fce488727881b92e95d6
Binary files /dev/null and b/latent_dialog/__pycache__/base_data_loaders.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/base_models.cpython-36.pyc b/latent_dialog/__pycache__/base_models.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2d43761c04791f5a244f6ba42f32fbf5f5c238e
Binary files /dev/null and b/latent_dialog/__pycache__/base_models.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/corpora.cpython-36.pyc b/latent_dialog/__pycache__/corpora.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2083c5776fd671a1ecf5808f3db31f07d6665c01
Binary files /dev/null and b/latent_dialog/__pycache__/corpora.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/criterions.cpython-36.pyc b/latent_dialog/__pycache__/criterions.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4847cef4e410bb5de379ff1d1f3333bc1a3dcf71
Binary files /dev/null and b/latent_dialog/__pycache__/criterions.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/data_loaders.cpython-36.pyc b/latent_dialog/__pycache__/data_loaders.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e19daa6211efafd5ec322229d4baba7b4149cd01
Binary files /dev/null and b/latent_dialog/__pycache__/data_loaders.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/domain.cpython-36.pyc b/latent_dialog/__pycache__/domain.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fea5bacdc823e0b903ed2012e92f451da85ce27
Binary files /dev/null and b/latent_dialog/__pycache__/domain.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/evaluators.cpython-36.pyc b/latent_dialog/__pycache__/evaluators.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8221ef99bfc4b830075dede8395a49cba2f79b82
Binary files /dev/null and b/latent_dialog/__pycache__/evaluators.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/main.cpython-36.pyc b/latent_dialog/__pycache__/main.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3822d03171b95140a4382e57850dcebb93347388
Binary files /dev/null and b/latent_dialog/__pycache__/main.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/models_task.cpython-36.pyc b/latent_dialog/__pycache__/models_task.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c81368e21c74f01de607eea0f1b779ea20211a7
Binary files /dev/null and b/latent_dialog/__pycache__/models_task.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/nn_lib.cpython-36.pyc b/latent_dialog/__pycache__/nn_lib.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6460889d0df5dd58a4f10b237aa4ad776e492a3
Binary files /dev/null and b/latent_dialog/__pycache__/nn_lib.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc b/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1697d553dc333135918fbfe07c8cdfb9bbb9f519
Binary files /dev/null and b/latent_dialog/__pycache__/offlinerl_utils.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/record.cpython-36.pyc b/latent_dialog/__pycache__/record.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2745dbe54ae7796ce791d1a9fcabe06cc488ac9f
Binary files /dev/null and b/latent_dialog/__pycache__/record.cpython-36.pyc differ
diff --git a/latent_dialog/__pycache__/utils.cpython-36.pyc b/latent_dialog/__pycache__/utils.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3db9a1ff54198ce831a4511c878887141c3d5ba6
Binary files /dev/null and b/latent_dialog/__pycache__/utils.cpython-36.pyc differ
diff --git a/latent_dialog/agent_task.py b/latent_dialog/agent_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..414214799c982b6e032fe465836394e6a0fd3ca6
--- /dev/null
+++ b/latent_dialog/agent_task.py
@@ -0,0 +1,1746 @@
+import torch as th
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import numpy as np
+from latent_dialog.utils import LONG, FLOAT, Pack, get_detokenize, np2var
+from latent_dialog.main import get_sent, LossManager
+from latent_dialog.data_loaders import BeliefDbDataLoaders
+from latent_dialog.base_models import frange_cycle_linear
+from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS
+from latent_dialog.enc2dec.decoders import TEACH_FORCE
+from latent_dialog.criterions import NormKLLoss, CatKLLoss
+from latent_dialog.offlinerl_utils import *
+from latent_dialog.augpt_utils import augpt_normalize, replace_augpt_tokens, replace_hdsa_tokens
+from collections import deque, namedtuple, defaultdict
+import random
+import pdb
+import logging
+import dill
+import warnings
+import copy
+import json
+import time
+
+logger = logging.getLogger()
+
+class OfflineRlAgent(object):
+    def __init__(self, model, corpus, args, name, tune_pi_only):
+        self.model = model
+        self.corpus = corpus
+        self.args = args
+        self.name = name
+        self.raw_goal = None
+        self.vec_goals_list = None
+        self.logprobs = None
+        print("Do we only tune the policy: {}".format(tune_pi_only))
+        self.opt = optim.SGD(
+            [p for n, p in self.model.named_parameters() if 'c2z' in n or not tune_pi_only],
+            lr=self.args.rl_lr,
+            momentum=self.args.momentum,
+            nesterov=(self.args.nesterov and self.args.momentum > 0))
+        self.all_rewards = []
+        self.all_grads = []
+        self.model.train()
+
+    def print_dialog(self, dialog, reward, stats):
+        for t_id, turn in enumerate(dialog):
+            if t_id % 2 == 0:
+                print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>'])))
+            else:
+                print("Sys: {}".format(' '.join(turn)))
+        report = ['{}: {}'.format(k, v) for k, v in stats.items()]
+        print("Reward {}. {}".format(reward, report))
+
+    def run(self, batch, evaluator, max_words=None, temp=0.1):
+        self.logprobs = []
+        self.dlg_history =[]
+        batch_size = len(batch['keys'])
+        logprobs, outs = self.model.forward_rl(batch, max_words, temp)
+        if batch_size == 1:
+            logprobs = [logprobs]
+            outs = [outs]
+
+        key = batch['keys'][0]
+        sys_turns = []
+        # construct the dialog history for printing
+        for turn_id, turn in enumerate(batch['contexts']):
+            user_input = self.corpus.id2sent(turn[-1])
+            self.dlg_history.append(user_input)
+            sys_output = self.corpus.id2sent(outs[turn_id])
+            self.dlg_history.append(sys_output)
+            sys_turns.append(' '.join(sys_output))
+
+        for log_prob in logprobs:
+            self.logprobs.extend(log_prob)
+        generated_dialog = {key: sys_turns}
+        return evaluator.evaluateModel(generated_dialog, mode="offline_rl")
+
+    def update(self, reward, stats):
+        self.all_rewards.append(reward)
+        # standardize the reward
+        r = (reward - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards))
+        # compute accumulated discounted reward
+        g = self.model.np2var(np.array([r]), FLOAT).view(1, 1)
+        rewards = []
+        for _ in self.logprobs:
+            rewards.insert(0, g)
+            g = g * self.args.gamma
+
+        loss = 0
+        # estimate the loss using one MonteCarlo rollout
+        for lp, r in zip(self.logprobs, rewards):
+            loss -= lp * r
+        print(loss)
+        self.opt.zero_grad()
+        loss.backward()
+        nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip)
+        self.opt.step()
+
+class OfflineLatentRlAgent(OfflineRlAgent):
+    def run(self, batch, evaluator, max_words=None, temp=0.1):
+        self.logprobs = []
+        self.dlg_history =[]
+        batch_size = len(batch['keys'])
+        logprobs, outs, logprob_z, sample_z = self.model.forward_rl(batch, max_words, temp)
+        if batch_size == 1:
+            outs = [outs]
+        key = batch['keys'][0]
+        sys_turns = []
+        # construct the dialog history for printing
+        for turn_id, turn in enumerate(batch['contexts']):
+            user_input = self.corpus.id2sent(turn[-1])
+            # print("Usr: {}".format(' '.join([t for t in user_input if t != '<pad>'])))
+            self.dlg_history.append(user_input)
+            sys_output = self.corpus.id2sent(outs[turn_id])
+            self.dlg_history.append(sys_output)
+            # print("Sys: {}".format(' '.join([t for t in sys_output if t != '<pad>'])))
+            sys_turns.append(' '.join(sys_output))
+
+        for b_id in range(batch_size):
+            self.logprobs.append(logprob_z[b_id])
+        # compute reward here
+        generated_dialog = {key: sys_turns}
+        return evaluator.evaluateModel(generated_dialog, mode="offline_rl")
+
+class PLASAgent(object):
+    def __init__(self, cvae, corpus, args, name, tune_pi_only):
+        self.is_stochastic = args.is_stochastic
+        self.fix_episode = args.fix_episode
+
+        if "gauss" in args.sv_config_path:
+            self.is_gauss = True 
+        else:
+            self.is_gauss = False 
+            self.z_embedding = copy.deepcopy(cvae.z_embedding)
+            self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.actor = Actor(cvae, corpus, args)
+        self.actor_target = copy.deepcopy(self.actor)
+        self.opt = optim.SGD(self.actor.parameters(), lr = args.rl_lr) 
+        
+        self.cvae = cvae #plas cvae model
+        for n, p in self.cvae.named_parameters():
+                p.requires_grad = False
+
+        if "z_loss" not in args:
+            args['z_loss'] = False
+        else:
+            self.z_loss = args.z_loss
+            if self.is_gauss:
+                self.z_lossf = NormKLLoss(unit_average=True)
+            else:
+                self.z_lossf = nn.CrossEntropyLoss()
+
+        if args.critic_loss == "mse":
+            self.q_lossf = nn.MSELoss()
+        elif args.critic_loss == "huber":
+            self.q_lossf = nn.HuberLoss()
+        self.regf = nn.MSELoss()
+
+        self.args = args
+        self.discount = args.gamma
+        self.tau = args.tau
+        self.lmbda = args.lmbda
+        self.beta = args.beta
+
+        self.critic_dropout = args.critic_dropout
+        if args.critic_dropout:
+            if self.fix_episode:
+                self.critic = SingleHierarchicalRecurrentCritic(self.cvae, corpus, self.cvae.config, args)
+            else:
+                self.critic = SingleRecurrentCritic(self.cvae, corpus, self.cvae.config, args)
+
+        else:
+            self.critic = RecurrentCritic(self.cvae, corpus, self.cvae.config, args)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.rl_lr)
+
+        self.q_lossf = nn.MSELoss()
+        self.regf = nn.MSELoss()
+
+        self.train_buffer = ReplayBuffer(args)
+        self.valid_buffer = ReplayBuffer(args)
+        self.test_buffer = ReplayBuffer(args)
+        self.corpus = corpus
+        self.name = name
+        self.raw_goal = None
+        self.vec_goals_list = None
+        self.logprobs = None
+        self.n_z = args.n_z
+
+    def print_dialog(self, dialog, reward, stats):
+        for t_id, turn in enumerate(dialog):
+            if t_id % 2 == 0:
+                print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>'])))
+            else:
+                print("Sys: {}".format(' '.join(turn)))
+        # report = ['{}: {}'.format(k, v) for k, v in stats.items()]
+        # print("Reward {}. {}".format(reward, report))
+        print("Reward {}\n{}".format(reward, stats))
+    
+    def run(self, batch, evaluator):
+        """
+        run one dialogue and compute success rate with pseudo trajectory
+        """
+        self.logprobs = []
+        self.dlg_history =[]
+        batch_size = len(batch['keys'])
+        de_tknize = get_detokenize()
+
+        # ret_dict contains the z and the response 
+        z = self.actor(batch)
+        _, outputs, _, sample_y = self.cvae.decode_z(z, batch_size, batch, self.args.max_words, self.args.temperature) 
+        pred_labels = np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in outputs])
+
+        key = batch['keys'][0]
+
+        sys_turns = []
+        # construct the dialog history for printing and calculating reward
+        for turn_id, turn in enumerate(batch['contexts']):
+            # user_input = self.corpus.id2sent(turn[-1])
+            user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+            self.dlg_history.append(user_input)
+            sys_output = get_sent(self.cvae.vocab, de_tknize, pred_labels, turn_id)
+            # sys_output = self.corpus.id2sent(outs[turn_id])
+            self.dlg_history.append(sys_output)
+            sys_turns.append(sys_output)
+
+        # return the reward here
+        generated_dialog = {key: sys_turns}
+        task_report, success, match = evaluator.evaluateModel(generated_dialog, mode="offline_rl", verbose=False)
+
+        return sample_y,task_report, success, match
+
+    def train(self, verbose=False, max_words=None, temp=0.1, debug=False, n=0):
+        de_tknize = get_detokenize()
+        self.logprobs = []
+
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        ctx_lens = state['context_lens']
+        batch_size = len(state['context_lens'])
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+
+        # predict a_t
+        joint_log_pz_t, z_t = self.actor(state)
+        logprobs_t, a_prime_t = self.cvae.decode_z(z_t, batch_size, state, max_words, temp)
+        a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, use_gpu=self.args.use_gpu)
+
+        for log_prob in joint_log_pz_t: # use likelihood of latent instead of words
+            self.logprobs.append(log_prob)
+
+        with th.no_grad():
+            # predict a_t+1
+            _, z_t1 = self.actor_target(next_state, hard=False)
+            logprobs_t1, a_prime_t1  = self.cvae.decode_z(z_t1, batch_size, next_state, max_words, temp)
+            a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu)
+
+            # Critic Training
+            if not self.critic_dropout:
+                target_Q1, target_Q2 = self.critic_target(next_state, a_prime_t1)
+                # Soft Clipped Double Q-learning 
+                target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+            else:
+                Qs = [self.critic_target(next_state, a_prime_t1) for _ in range(5)]
+                target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+
+        if not self.critic_dropout:
+            current_Q1, current_Q2 = self.critic(state, out_utts)
+        else:
+            current_Q1 = self.critic(state, out_utts)
+
+
+        y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * target_Q
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+        critic_loss = q1_loss + q2_loss
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward(retain_graph=True)
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+
+        # Update through DPG
+        actor_loss = 0
+        if self.critic_dropout:
+            q_values = self.critic(state, a_prime_t)
+        else:
+            q_values = self.critic.q1(state, a_prime_t)
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                print("True_Sys: {}".format(true_output))
+                sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("Pred_Sys: {}".format(sys_output))
+                print(q_values[turn_id], self.logprobs[turn_id], Return[turn_id])
+
+        for lp, q in zip(self.logprobs, q_values):
+            actor_loss -= lp * q
+        actor_loss = th.mean(actor_loss)
+
+
+        self.opt.zero_grad()
+        actor_loss.backward()
+        nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip)
+        self.opt.step()
+
+
+        # Update Target Networks 
+        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, 
+                    "actor_loss": actor_loss.item()}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+    
+    def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False):
+        de_tknize = get_detokenize()
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+
+        if not self.critic_dropout:
+            current_Q1, current_Q2 = self.critic(state, out_utts)
+        else:
+            current_Q1 = self.critic(state, out_utts)
+
+        if sl:
+            # use return from data
+            y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1)
+        else:
+            with th.no_grad():
+                # predict a_t+1
+                _, z_t1 = self.actor_target(next_state, hard=False)
+                logprobs_t1, a_prime_t1  = self.cvae.decode_z(z_t1, batch_size, next_state, max_words, temp)
+                a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu)
+
+                # Critic Training
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, a_prime_t1)
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    target_Q = self.critic(state, out_utts)
+
+            y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done,LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * target_Q
+
+        q1_loss = F.mse_loss(current_Q1, y)
+        q2_loss = F.mse_loss(current_Q2, y) if not self.critic_dropout else 0.0
+        critic_loss = q1_loss + q2_loss
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                print("True_Sys: {}".format(true_output))
+                print(current_Q1[turn_id], y[turn_id])
+
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks 
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0} 
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+class LatentPLASAgent_weird(object):
+    def __init__(self, cvae, corpus, args, evaluator, name, tune_pi_only):
+        if "gauss" in args.sv_config_path:
+            self.is_gauss = True 
+            self.is_stochastic = args.is_stochastic
+            if not self.is_stochastic:
+                self.actor = DeterministicGaussianActor(cvae, corpus, args)
+            else:
+                self.actor = StochasticGaussianActor(cvae, corpus, args)
+        else:
+            self.is_gauss = False 
+            self.is_stochastic = args.is_stochastic
+            self.actor = CatActor(cvae, corpus, args)
+            self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.actor_target = copy.deepcopy(self.actor)
+        self.opt = optim.SGD(self.actor.parameters(), lr = args.actor_rl_lr, weight_decay=0.01) 
+        self.importance_threshold = args.importance_threshold
+        self.evaluator = evaluator
+        
+        if "z_loss" not in args:
+            args['z_loss'] = False
+        else:
+            self.z_loss = args.z_loss
+            if self.is_gauss:
+                if not self.is_stochastic:
+                    self.z_lossf = nn.MSELoss()
+                else:
+                    self.z_lossf = NormKLLoss(unit_average=True)
+            else:
+                if not self.is_stochastic:
+                    self.z_lossf = nn.CrossEntropyLoss()
+                else:
+                    self.z_lossf = CatKLLoss()
+
+        if "critic_kl_loss" not in args:
+            args['critic_kl_loss'] = False
+        else:
+            if not self.is_stochastic:
+                z_loss = self.z_lossf(th.exp(log_pz_t), th.argmax(corpus_z_t, dim=1))
+            else:
+                z_loss = self.z_lossf(log_pz_t, corpus_log_pz_t, batch_size, unit_average=True)
+            actor_loss += z_loss
+
+
+        self.opt.zero_grad()
+        actor_loss.backward()
+        actor_total_norm = 0
+        parameters = [p for p in self.actor.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            actor_total_norm += param_norm.item() ** 2
+            actor_total_norm = actor_total_norm ** 0.5
+        nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip)
+        self.opt.step()
+
+        # Update Target Networks 
+        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, 
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm,
+                    "actor_loss": actor_loss.item(), 
+                    "actor_grad_norm":actor_total_norm,
+                    "z_loss":z_loss.item() if self.z_loss else 0.0}
+
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+    def train_vae_model(self, batch_cnt):
+        de_tknize = get_detokenize()
+
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+
+        vae_batch = copy.deepcopy(state)
+        vae_batch['contexts'] = np.expand_dims(action, 1)
+        vae_batch['outputs'] = action
+        vae_batch['context_lens'] = ctx_lens - 1
+
+        loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE)
+        self.cvae.backward(loss, batch_cnt)
+        nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip)
+        self.vae_optimizer.step()
+        vae_loss = self.cvae.valid_loss(loss) 
+
+        return vae_loss
+
+    def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False):
+        de_tknize = get_detokenize()
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+        next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu)
+
+        if self.is_gauss:
+            corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts)
+            _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state)
+            if not self.critic_dropout:
+                current_Q1, current_Q2 = self.critic(state, corpus_z_t)
+            else:
+                current_Q1 = self.critic(state, corpus_z_t)
+                
+        else:
+           corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True)
+           soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False)
+           if self.embed_z_for_critic:
+                soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+           else: 
+                soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size)
+           if self.critic_dropout:
+               current_Q1 = self.critic(state, soft_corpus_z_t)
+           else:
+               current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t)
+
+        _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp)
+
+        if type(corpus_a_t[0]) == int:
+            corpus_a_t = [corpus_a_t]
+        corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu)
+
+
+        if sl:
+            # use return from data
+            y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1)
+        else:
+            with th.no_grad():
+                # predict a_t+1
+                if self.is_gauss:
+                    z_t1, actor_mu_t1, actor_logvar_t1 = self.actor_target(next_state)
+                    # Critic Training
+                    if not self.critic_dropout:
+                        target_Q1, target_Q2 = self.critic_target(next_state, z_t1)
+                        # Soft Clipped Double Q-learning 
+                        target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                    else:
+
+                        if not self.fix_episode:
+                            Qs = [self.critic_target(next_state, z_t1) for z_t1 in z_t1s]
+                        else:
+                            # special forward to avoid using pseudo-trajectory
+                            # tic = time.perf_counter()
+                            Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for z_t1 in z_t1s]
+                            # toc = time.perf_counter()
+                            # print(f"One critic pass in {toc - tic:0.4f} seconds")
+
+                        if self.args.critic_dropout_agg == "min":
+                            target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                        elif self.args.critic_dropout_agg == "avg":
+                            target_Q = th.mean(th.cat(Qs, dim=1), 1)
+                else:
+                    z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state)
+                    if self.embed_z_for_critic:
+                        soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+                    else:
+                        soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size)
+
+                    if not self.critic_dropout:
+                        target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1)
+                        # Soft Clipped Double Q-learning 
+                        target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                    else:
+                        Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)]
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+
+
+            if self.args.critic_kl_loss:
+                critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach())
+            else:
+                critic_kl_loss = 0.0
+
+            # critic only ever receive supervision on final state
+            if not self.critic_dropout:
+                y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) 
+            else:
+                y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+                y = y.unsqueeze(1)
+
+
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        critic_loss = q1_loss + q2_loss + critic_kl_loss
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                print("pred: ", current_Q1[turn_id], "y: ", y[turn_id])
+
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward(retain_graph=True)
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+            critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks 
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+class LatentPLASAgent(object):
+    def __init__(self, cvae, corpus, args, evaluator, name, tune_pi_only):
+        if "gauss" in args.sv_config_path:
+            self.is_gauss = True 
+            self.is_stochastic = args.is_stochastic
+            if not self.is_stochastic:
+                self.actor = DeterministicGaussianActor(cvae, corpus, args)
+            else:
+                self.actor = StochasticGaussianActor(cvae, corpus, args)
+        else:
+            self.is_gauss = False 
+            self.is_stochastic = args.is_stochastic
+            self.actor = CatActor(cvae, corpus, args)
+            self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.actor_target = copy.deepcopy(self.actor)
+        self.opt = optim.SGD(self.actor.parameters(), lr = args.actor_rl_lr, weight_decay=0.01) 
+        self.importance_threshold = args.importance_threshold
+
+        self.evaluator = evaluator
+        
+        if "z_loss" not in args:
+            args['z_loss'] = False
+        else:
+            self.z_loss = args.z_loss
+            if self.is_gauss:
+                if not self.is_stochastic:
+                    self.z_lossf = nn.MSELoss()
+                else:
+                    self.z_lossf = NormKLLoss(unit_average=True)
+            else:
+                if not self.is_stochastic:
+                    self.z_lossf = nn.CrossEntropyLoss()
+                else:
+                    self.z_lossf = CatKLLoss()
+
+        self.critic_kl_loss = args.critic_kl_loss
+        if self.is_gauss:
+            self.kl_lossf = NormKLLoss(unit_average=True)
+        else:
+            self.kl_lossf = CatKLLoss()
+
+        self.args = args
+        self.discount = args.gamma
+        self.tau = args.tau
+        self.lmbda = args.lmbda
+        self.beta = args.beta
+        self.fix_episode = args.fix_episode
+
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            if self.fix_episode:
+                self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args)
+            else:
+                self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args)
+        else:
+            self.critic = RecurrentCritic(cvae, corpus, cvae.config, args)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01)
+
+        self.cvae = cvae #plas cvae model
+        for n, p in self.cvae.named_parameters():
+            if "aux_encoder" not in n:
+                p.requires_grad = False
+
+
+
+        self.train_vae = args.train_vae
+        if self.train_vae:
+            self.vae_optimizer =  optim.SGD([p for p in self.cvae.parameters() if p.requires_grad==True], lr = args.rl_lr) 
+            self.vae_loss = LossManager()
+            self.cvae.shared_train = False
+            self.cvae.config.beta = args.vae_beta
+            if args.weighted_vae_nll:
+                # set NLL to be weighted on requestable tokens
+                req_tokens = []
+                for d in REQ_TOKENS.keys():
+                    req_tokens.extend(REQ_TOKENS[d])
+                nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.cvae.vocab]))
+                print("req tokens assigned with special weights")
+                if args.use_gpu:
+                    nll_weight = nll_weight.cuda()
+                self.cvae.nll.avg_type = "weighted"
+                self.cvae.nll.set_weight(nll_weight)
+
+
+        self.q_lossf = nn.MSELoss()
+        self.regf = nn.MSELoss()
+
+        self.train_buffer = ReplayBuffer(args)
+        self.valid_buffer = ReplayBuffer(args)
+        self.test_buffer = ReplayBuffer(args)
+        self.corpus = corpus
+        self.name = name
+        self.raw_goal = None
+        self.vec_goals_list = None
+        self.logprobs = None
+        self.n_z = args.n_z 
+
+    def print_dialog(self, dialog, reward, stats):
+        for t_id, turn in enumerate(dialog):
+            if t_id % 2 == 0:
+                print("Usr: {}".format(' '.join([t for t in turn if t != '<pad>'])))
+            else:
+                print("Sys: {}".format(' '.join(turn)))
+        # report = ['{}: {}'.format(k, v) for k, v in stats.items()]
+        # print("Reward {}. {}".format(reward, report))
+        print("Reward {}\n{}".format(reward, stats))
+    
+    def run(self, batch):
+        """
+        run one dialogue and compute success rate with pseudo trajectory
+        """
+        self.logprobs = []
+        self.dlg_history =[]
+        batch_size = len(batch['keys'])
+        de_tknize = get_detokenize()
+
+        # ret_dict contains the z and the response 
+        z = self.actor(batch)
+        _, outputs, _, sample_y = self.cvae.decode_z(z, batch_size, batch, self.args.max_words, self.args.temperature) 
+        pred_labels = np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in outputs])
+
+        key = batch['keys'][0]
+
+        sys_turns = []
+        # construct the dialog history for printing and calculating reward
+        for turn_id, turn in enumerate(batch['contexts']):
+            # user_input = self.corpus.id2sent(turn[-1])
+            user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+            self.dlg_history.append(user_input)
+            sys_output = get_sent(self.cvae.vocab, de_tknize, pred_labels, turn_id)
+            # sys_output = self.corpus.id2sent(outs[turn_id])
+            self.dlg_history.append(sys_output)
+            sys_turns.append(sys_output)
+
+        # for log_prob in logprobs:
+            # self.logprobs.extend(log_prob)
+        # return the reward here
+        generated_dialog = {key: sys_turns}
+        task_report, success, match = self.evaluator.evaluateModel(generated_dialog, mode="offline_rl", verbose=False)
+
+        return sample_y,task_report, success, match
+
+    def train(self, verbose=False, max_words=None, temp=0.1, debug=False, n=0):
+        de_tknize = get_detokenize()
+        self.logprobs = []
+        generated_dialogs = defaultdict(list)
+
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        if self.fix_episode:
+            key = state['keys'][0]
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+        next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu)
+
+        # predict a_t
+        if self.is_gauss:
+            z_t, actor_mu_t, actor_logvar_t = self.actor(state)
+        else:
+            z_t, soft_z_t, log_pz_t, logits_pz_t = self.actor(state)
+            # soft_z_t = self.z_embedding(soft_z_t)
+            if self.embed_z_for_critic:
+                soft_z_t = self.actor.z_embedding(soft_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+            else:
+                soft_z_t = soft_z_t.view(-1, self.actor.y_size * self.actor.k_size)
+
+        logprobs_t, a_prime_t = self.cvae.decode_z(z_t, batch_size, state, max_words, temp)
+        if type(a_prime_t[0]) == int:
+            a_prime_t = [a_prime_t]
+            logprobs_t = [logprobs_t]
+        a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, use_gpu=self.args.use_gpu)
+
+        # predict get at using VAE, and predict Q(st, at)
+        if self.is_gauss:
+            corpus_z_t, corpus_mu_t, corpus_logvar_t = self.cvae.get_z_via_vae(out_utts)
+            corpus_z_t1, corpus_mu_t1, corpus_logvar_t1 = self.cvae.get_z_via_vae(next_out_utts)
+            _, rg_mu_t, rg_logvar_t = self.cvae.get_z_via_rg(state) # for kl loss
+            _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state) # for kl loss
+            if not self.critic_dropout:
+                current_Q1, current_Q2 = self.critic(state, corpus_z_t)
+            else:
+                current_Q1 = self.critic(state, corpus_z_t)
+        else:
+           corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True)
+           soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False)
+           _, rg_log_pz_t, _ = self.cvae.get_z_via_rg(state) # for kl loss
+           _, rg_log_pz_t1, _ = self.cvae.get_z_via_rg(next_state) # for kl loss
+
+           if self.embed_z_for_critic:
+                soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+           else: 
+                soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size)
+           if self.critic_dropout:
+               current_Q1 = self.critic(state, soft_corpus_z_t)
+           else:
+               current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t)
+
+
+        # for logging purposes only
+        _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp)
+        if type(corpus_a_t[0]) == int:
+            corpus_a_t = [corpus_a_t]
+        corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu)
+
+
+        with th.no_grad():
+            # predict a_t+1 and Q(s_t+1, a_t+1)
+            if self.is_gauss:
+                z_t1, actor_mu_t1, actor_logvar_t1 = self.actor_target(next_state)
+                # Critic Training
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, z_t1)
+                    # Soft Clipped Double Q-learning 
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    if not self.fix_episode:
+                        Qs = [self.critic_target(next_state, z_t1) for _ in range(5)]
+                    else:
+                        # special forward to avoid using pseudo-trajectory
+                        # tic = time.perf_counter()
+                        Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for _ in range(5)]
+                        # toc = time.perf_counter()
+                        # print(f"One batch critic forward pass in {toc - tic:0.4f} seconds")
+
+                    if self.args.critic_dropout_agg == "min":
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                    elif self.args.critic_dropout_agg == "avg":
+                        target_Q = th.mean(th.cat(Qs, dim=1), 1)
+
+            else:
+                z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state)
+                if self.embed_z_for_critic:
+                    soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+                else:
+                    soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size)
+
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1)
+                    # Soft Clipped Double Q-learning 
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)]
+                    target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+
+
+
+        if self.args.critic_kl_loss:
+            if self.is_gauss:
+                critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach())
+            else:
+                critic_kl_loss = self.kl_lossf(log_pz_t1, rg_log_pz_t1.detach(), unit_average=True)
+        else:
+            critic_kl_loss = 0.0
+
+        if not self.critic_dropout:
+            y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) 
+        else:
+            y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+            y = y.unsqueeze(1)
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        critic_loss = q1_loss + q2_loss + critic_kl_loss
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward(retain_graph=True)
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+            critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+
+        # Update through DPG
+        actor_loss = 0
+        if self.is_gauss:
+            if self.critic_dropout:
+                q_values = self.critic(state, z_t)
+            else:
+                q_values = self.critic.q1(state, z_t)
+        else:
+            if self.critic_dropout:
+                q_values = self.critic(state, soft_z_t)
+            else:
+                q_values = self.critic.q1(state, soft_z_t)
+
+        if debug:
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("Pred_Sys: {}".format(sys_output))
+                # print(q_values[turn_id], Return[turn_id])
+                print("pred_Q: ", current_Q1[turn_id], "y_Q: ", y[turn_id], "target: ", target_Q[turn_id], "Return: ", Return[turn_id])
+                print()
+
+        if not self.is_stochastic:
+            actor_loss = - th.mean(q_values) 
+        else:
+            if self.is_gauss:
+                # off-policy policy gradient
+                pi_a_s = self.cvae.gaussian_prob(actor_mu_t, actor_logvar_t, corpus_z_t) + 1e-15 # batch_size, 200
+                pib_a_s = self.cvae.gaussian_prob(corpus_mu_t, corpus_logvar_t, corpus_z_t) # batch_size, 200
+                threshold =  th.ones([batch_size, self.cvae.y_size]) * self.importance_threshold
+                if self.args.use_gpu:
+                    threshold = threshold.cuda()
+                importance_ratio = th.min(pi_a_s / pib_a_s, threshold)
+
+                actor_loss = - th.mean(importance_ratio.detach() * th.log(pi_a_s) * q_values.detach())
+
+            else:
+                log_pi_a_s = self.cvae.categorical_logprob(logits_pz_t, corpus_z_t, temp=1.0) # batch_size, 200
+                log_pib_a_s = self.cvae.categorical_logprob(corpus_logits_pz_t, corpus_z_t, temp=1.0) # batch_size, 200
+                threshold =  th.ones([batch_size]) * self.importance_threshold
+                if self.args.use_gpu:
+                    threshold = threshold.cuda()
+                importance_ratio = th.min(th.exp(log_pi_a_s) / th.exp(log_pib_a_s), threshold)
+                actor_loss = - th.mean(importance_ratio.detach() * log_pi_a_s * q_values.detach())
+
+
+
+        if self.args.z_loss:
+            if self.is_gauss:
+                if not self.is_stochastic:
+                    z_loss = self.z_lossf(z_t, corpus_z_t.detach())
+                else:
+                    z_loss = 0.01 *  self.z_lossf(actor_mu_t, actor_logvar_t, corpus_mu_t.detach(), corpus_logvar_t.detach())
+                actor_loss += z_loss
+            else:
+                if not self.is_stochastic:
+                    z_loss = self.z_lossf(th.exp(log_pz_t), th.argmax(corpus_z_t, dim=1))
+                else:
+                    z_loss = self.z_lossf(log_pz_t, corpus_log_pz_t, batch_size, unit_average=True)
+                actor_loss += z_loss
+
+
+        self.opt.zero_grad()
+        actor_loss.backward()
+        actor_total_norm = 0
+        parameters = [p for p in self.actor.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            actor_total_norm += param_norm.item() ** 2
+            actor_total_norm = actor_total_norm ** 0.5
+        nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.rl_clip)
+        self.opt.step()
+
+
+        # Update Target Networks 
+        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0, 
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm,
+                    "actor_loss": actor_loss.item(), 
+                    "actor_grad_norm":actor_total_norm,
+                    "z_loss":z_loss.item() if self.z_loss else 0.0}
+
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+    def train_vae_model(self, batch_cnt):
+        de_tknize = get_detokenize()
+
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+
+        vae_batch = copy.deepcopy(state)
+        vae_batch['contexts'] = np.expand_dims(action, 1)
+        vae_batch['outputs'] = action
+        vae_batch['context_lens'] = ctx_lens - 1
+
+        loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE)
+        self.cvae.backward(loss, batch_cnt)
+        nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip)
+        self.vae_optimizer.step()
+        vae_loss = self.cvae.valid_loss(loss) 
+
+        return vae_loss
+
+    def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False):
+        de_tknize = get_detokenize()
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+        next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu)
+
+        if self.is_gauss:
+            corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts)
+            _, rg_mu_t1, rg_logvar_t1 = self.cvae.get_z_via_rg(next_state)
+            if not self.critic_dropout:
+                current_Q1, current_Q2 = self.critic(state, corpus_z_t)
+            else:
+                current_Q1 = self.critic(state, corpus_z_t)
+                
+        else:
+           corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True)
+           soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False)
+           if self.embed_z_for_critic:
+                # corpus_z_t = self.z_embedding(corpus_z_t)
+                soft_corpus_z_t = self.actor.z_embedding(soft_corpus_z_t.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+           else: 
+                soft_corpus_z_t = soft_corpus_z_t.view(-1, self.actor.y_size * self.actor.k_size)
+           if self.critic_dropout:
+               current_Q1 = self.critic(state, soft_corpus_z_t)
+           else:
+               current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t)
+
+        _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp)
+
+        if type(corpus_a_t[0]) == int:
+            corpus_a_t = [corpus_a_t]
+        corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu)
+
+
+        if sl:
+            # use return from data
+            y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1)
+        else:
+            with th.no_grad():
+                # predict a_t+1
+                if self.is_gauss:
+                    z_t1, actor_mu_t1, actor_logvar_t1 = self.actor_target(next_state)
+                    # Critic Training
+                    if not self.critic_dropout:
+                        target_Q1, target_Q2 = self.critic_target(next_state, z_t1)
+                        # Soft Clipped Double Q-learning 
+                        target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                    else:
+
+                        if not self.fix_episode:
+                            Qs = [self.critic_target(next_state, z_t1) for _ in range(5)]
+                        else:
+                            # special forward to avoid using pseudo-trajectory
+                            # tic = time.perf_counter()
+                            Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for _ in range(5)]
+                            # toc = time.perf_counter()
+                            # print(f"One critic pass in {toc - tic:0.4f} seconds")
+
+                        if self.args.critic_dropout_agg == "min":
+                            target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                        elif self.args.critic_dropout_agg == "avg":
+                            target_Q = th.mean(th.cat(Qs, dim=1), 1)
+                else:
+                    z_t1, soft_z_t1, log_pz_t1, logits_pz_t1 = self.actor_target(next_state)
+                    if self.embed_z_for_critic:
+                        soft_z_t1 = self.actor_target.z_embedding(soft_z_t1.view(1, -1, self.actor.y_size * self.actor.k_size)).squeeze(0)
+                    else:
+                        soft_z_t1 = soft_z_t1.view(-1, self.actor.y_size * self.actor.k_size)
+
+                    if not self.critic_dropout:
+                        target_Q1, target_Q2 = self.critic_target(next_state, soft_z_t1)
+                        # Soft Clipped Double Q-learning 
+                        target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                    else:
+                        Qs = [self.critic_target(next_state, soft_z_t1) for _ in range(5)]
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+
+
+            if self.args.critic_kl_loss:
+                critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, rg_mu_t1.detach(), rg_logvar_t1.detach())
+            else:
+                critic_kl_loss = 0.0
+
+            # critic only ever receive supervision on final state
+            if not self.critic_dropout:
+                y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) 
+            else:
+                y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+                y = y.unsqueeze(1)
+
+
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        critic_loss = q1_loss + q2_loss + critic_kl_loss
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                print("pred: ", current_Q1[turn_id], "y: ", y[turn_id])
+
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward(retain_graph=True)
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+            critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks 
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+class LatentCriticAgent(object):
+    def __init__(self, cvae, corpus, args, evaluator, name, vae=None):
+        if "gauss" in args.config_path:
+            self.is_gauss = True 
+            self.kl_lossf = NormKLLoss(unit_average=True)
+        else:
+            self.is_gauss = False 
+            self.kl_lossf = CatKLLoss()
+            self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.evaluator = evaluator
+        self.corpus = corpus
+        
+        if "raw_response" in args:
+            self.raw_response = args.raw_response
+        else:
+            self.raw_response = False
+
+        if self.raw_response:
+            self._read_responses(args.response_path)
+
+        if "corpus_response" in args:
+            self.corpus_response = args.corpus_response
+        else:
+            self.corpus_response = False
+
+        assert self.corpus_response != self.raw_response
+
+        if self.corpus_response or self.raw_response:
+            self.vae = vae
+        else:
+            self.vae=None
+        
+        self.args = args
+        self.discount = args.gamma
+        self.tau = args.tau
+        self.lmbda = args.lmbda
+        self.beta = args.beta
+        self.fix_episode = args.fix_episode
+
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            if self.fix_episode:
+                self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args)
+            else:
+                self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args)
+        else:
+            self.critic = RecurrentCritic(cvae, corpus, cvae.config, args)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01)
+
+        self.cvae = cvae #lava model
+        for n, p in self.cvae.named_parameters():
+            p.requires_grad = False # no param is trained
+
+        self.train_vae = args.train_vae
+        if self.train_vae:
+            self.vae_optimizer =  optim.SGD([p for p in self.cvae.parameters() if p.requires_grad==True], lr = args.rl_lr) 
+            self.vae_loss = LossManager()
+            self.cvae.shared_train = False
+            self.cvae.config.beta = args.vae_beta
+            if args.weighted_vae_nll:
+                # set NLL to be weighted on requestable tokens
+                req_tokens = []
+                for d in REQ_TOKENS.keys():
+                    req_tokens.extend(REQ_TOKENS[d])
+                nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.cvae.vocab]))
+                print("req tokens assigned with special weights")
+                if args.use_gpu:
+                    nll_weight = nll_weight.cuda()
+                self.cvae.nll.avg_type = "weighted"
+                self.cvae.nll.set_weight(nll_weight)
+
+        self.q_lossf = nn.MSELoss()
+        self.regf = nn.MSELoss()
+
+        self.train_buffer = ReplayBuffer(args)
+        self.valid_buffer = ReplayBuffer(args)
+        self.test_buffer = ReplayBuffer(args)
+        self.corpus = corpus
+        self.name = name
+        self.raw_goal = None
+        self.vec_goals_list = None
+        self.logprobs = None
+        self.n_z = args.n_z 
+
+    def _read_responses(self, json_path):
+        """
+        following shades of bleu format
+        """
+
+        self.raw_responses = defaultdict(list)
+
+        # read responses
+        json_list = [json_path, json_path.replace("test", "val"), json_path.replace("test", "train")]
+        for j in json_list:
+            print(f"Reading responses from {j}...")
+            with open(j) as f:
+                raw_data = json.load(f)
+
+            # convert to ID of critic encoder
+            for k in raw_data.keys():
+                if "augpt" in json_path:
+                    self.raw_responses[k] = [self.corpus._sent2id(replace_augpt_tokens(augpt_normalize(raw_data[k]['response'][i]), raw_data[k]['active_domain'][i]).split()) for i in range(len(raw_data[k]['response']))]
+                elif "HDSA" in json_path or "hdsa" in json_path:
+                    self.raw_responses[k + ".json"] = [self.corpus._sent2id(replace_hdsa_tokens(raw_data[k][i])) for i in range(len(raw_data[k]))]
+
+            print(len(raw_data.keys()), "dialogues read")
+
+        print(f"responses from {len(self.raw_responses.keys())} dialogues read")
+
+
+    def train_vae_model(self, batch_cnt):
+        de_tknize = get_detokenize()
+
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+
+        vae_batch = copy.deepcopy(state)
+        vae_batch['contexts'] = np.expand_dims(action, 1)
+        vae_batch['outputs'] = action
+        vae_batch['context_lens'] = ctx_lens - 1
+
+        loss = self.cvae.forward_aux(vae_batch, mode=TEACH_FORCE)
+        self.cvae.backward(loss, batch_cnt)
+        nn.utils.clip_grad_norm_(self.cvae.parameters(),self.cvae.config.grad_clip)
+        self.vae_optimizer.step()
+        vae_loss = self.cvae.valid_loss(loss) 
+
+        return vae_loss
+
+    def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False, n=0):
+        de_tknize = get_detokenize()
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        key = state['keys'][0]
+
+        if self.raw_response:
+            while key not in self.raw_responses.keys():
+                experiences = self.train_buffer.sample()
+                state, action, reward, next_state, expert_next_action, done, Return = experiences
+                key = state['keys'][0]
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+        next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu)
+        if self.raw_response:
+            a_prime_t1 =  self.cvae.np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in self.raw_responses[key][1:] + [[]]]), LONG)
+        elif self.corpus_response:
+            a_prime_t1 = next_out_utts
+
+        if self.is_gauss:
+            corpus_z_t, corpus_mu, corpus_logvar = self.cvae.get_z_via_vae(out_utts)
+            _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp)
+            if not self.critic_dropout:
+                current_Q1, current_Q2 = self.critic(state, corpus_z_t)
+            else:
+                current_Q1 = self.critic(state, corpus_z_t)
+                
+        else:
+           corpus_z_t, _, _= self.cvae.get_z_via_vae(out_utts, hard=True)
+           soft_corpus_z_t, corpus_logits_pz_t, corpus_log_pz_t = self.cvae.get_z_via_vae(out_utts, hard=False)
+           _, corpus_a_t = self.cvae.decode_z(corpus_z_t, batch_size, state, max_words, temp)
+           if self.embed_z_for_critic:
+                soft_corpus_z_t = self.cvae.z_embedding(soft_corpus_z_t.view(1, -1, self.cvae.y_size * self.cvae.k_size)).squeeze(0)
+                corpus_z_t = self.cvae.z_embedding(corpus_z_t.view(1, -1, self.cvae.y_size * self.cvae.k_size)).squeeze(0)
+           else: 
+                soft_corpus_z_t = soft_corpus_z_t.view(-1, self.cvae.y_size * self.cvae.k_size)
+                corpus_z_t = corpus_z_t.view(-1, self.cvae.y_size * self.cvae.k_size)
+           if self.critic_dropout:
+               current_Q1 = self.critic(state, soft_corpus_z_t)
+           else:
+               current_Q1, current_Q2 = self.critic(state, soft_corpus_z_t)
+
+
+
+        if type(corpus_a_t[0]) == int:
+            corpus_a_t = [corpus_a_t]
+        corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu)
+
+        z_t, actor_mu_t, actor_logvar_t = self.cvae.get_z_via_rg(state)
+        logprobs_t, a_prime_t  = self.cvae.decode_z(z_t, batch_size, state, max_words, temp)
+        if type(a_prime_t[0]) == int:
+            a_prime_t = [a_prime_t]
+        a_prime_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t]), LONG, use_gpu=self.args.use_gpu)
+
+
+        with th.no_grad():
+            if sl:
+                # use return from data
+                y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1)
+            else:
+                # predict a_t+1
+                if self.is_gauss:
+                    if not self.raw_response:
+                        z_t1, actor_mu_t1, actor_logvar_t1 = self.cvae.get_z_via_rg(next_state)
+                    else:
+                        z_t1, actor_mu_t1, actor_logvar_t1 = self.vae.get_z_via_vae(a_prime_t1)
+
+                    _, corpus_mu_t1, corpus_logvar_t1 = self.cvae.get_z_via_vae(next_out_utts)
+
+                else: #categorical
+                    z_t1, actor_log_pz_t1, actor_pz_t1 = self.cvae.get_z_via_rg(next_state, hard=True)
+                    if self.embed_z_for_critic:
+                        z_t1 = self.cvae.z_embedding(z_t1.view(1, -1, self.cvae.y_size * self.cvae.k_size)).squeeze(0)
+                    else:
+                        z_t1 = z_t1.view(-1, self.cvae.y_size * self.cvae.k_size)
+                    _, corpus_log_pz_t1, corpus_logits_pz_t1 = self.cvae.get_z_via_vae(next_out_utts)
+
+                # Critic Training
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, z_t1)
+                    # Soft Clipped Double Q-learning 
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    if not self.fix_episode:
+                        Qs = [self.critic_target(next_state, z_t1) for z_t1 in z_t1s]
+                    else:
+                        # special forward to avoid using pseudo-trajectory
+                        Qs = [self.critic_target.forward_target(next_state, z_t1, corpus_z_t) for z_t1 in z_t1s]
+
+                    if self.args.critic_dropout_agg == "min":
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                    elif self.args.critic_dropout_agg == "avg":
+                        target_Q = th.mean(th.cat(Qs, dim=1), 1)
+
+
+                if self.args.critic_kl_loss:
+                    if self.is_gauss:
+                        critic_kl_loss = self.kl_lossf(actor_mu_t1, actor_logvar_t1, corpus_mu_t1.detach(), corpus_logvar_t1.detach())
+                    else:
+                        critic_kl_loss = self.kl_lossf(actor_log_pz_t1, corpus_log_pz_t1.detach(), unit_average=True)
+                else:
+                    critic_kl_loss = 0.0
+
+                # critic only ever receive supervision on final state
+                if not self.critic_dropout:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss) 
+                else:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q - self.args.critic_kl_alpha * critic_kl_loss)
+                    y = y.unsqueeze(1)
+
+
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        critic_loss = q1_loss + q2_loss + critic_kl_loss
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                # sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                print("pred: ", current_Q1[turn_id], "y: ", y[turn_id])
+
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+            critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks 
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+
+class CriticAgent(LatentCriticAgent):
+
+    def __init__(self, cvae, corpus, args, evaluator, name):
+        if "gauss" in args.config_path:
+            self.is_gauss = True 
+            self.kl_lossf = NormKLLoss(unit_average=True)
+        else:
+            self.is_gauss = False 
+            self.kl_lossf = CatKLLoss()
+            self.embed_z_for_critic = args.embed_z_for_critic
+
+        self.evaluator = evaluator
+        self.corpus = corpus
+        
+        if "raw_response" in args:
+            self.raw_response = args.raw_response
+        else:
+            self.raw_response = False
+
+        if self.raw_response:
+            self._read_responses(args.response_path)
+
+        if "corpus_response" in args:
+            self.corpus_response = args.corpus_response
+        else:
+            self.corpus_response = False
+
+        if self.corpus_response or self.raw_response:
+            assert self.corpus_response != self.raw_response
+        
+        self.args = args
+        self.discount = args.gamma
+        self.tau = args.tau
+        self.lmbda = args.lmbda
+        self.beta = args.beta
+        self.fix_episode = args.fix_episode
+    
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            if self.fix_episode:
+                self.critic = SingleHierarchicalRecurrentCritic(cvae, corpus, cvae.config, args)
+            else:
+                self.critic = SingleRecurrentCritic(cvae, corpus, cvae.config, args)
+        else:
+            self.critic = RecurrentCritic(cvae, corpus, cvae.config, args)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = optim.SGD(self.critic.parameters(), lr=args.critic_rl_lr, weight_decay=0.01)
+
+        self.cvae = cvae #lava model
+        for n, p in self.cvae.named_parameters():
+            p.requires_grad = False # no param is trained
+
+        if "actor_path" in args and args.actor_path is not None:
+            with open(args.actor_config, "r") as f:
+                actor_config =  Pack(json.load(f))
+            self.is_stochastic = actor_config.is_stochastic
+            if self.is_gauss:
+                if not self.is_stochastic:
+                    self.actor = DeterministicGaussianActor(cvae, corpus, args)
+                else:
+                    self.actor = StochasticGaussianActor(cvae, corpus, args)
+            else:
+                self.actor = CatActor(cvae, corpus, args)
+                self.embed_z_for_critic = args.embed_z_for_critic
+
+            actor_dict = th.load(args.actor_path, map_location=lambda storage, location: storage)
+            self.actor.load_state_dict(actor_dict)
+        else:
+            self.actor = None
+
+
+        self.q_lossf = nn.MSELoss()
+        self.regf = nn.MSELoss()
+
+        self.train_buffer = ReplayBuffer(args)
+        self.valid_buffer = ReplayBuffer(args)
+        self.test_buffer = ReplayBuffer(args)
+        self.name = name
+        self.raw_goal = None
+        self.vec_goals_list = None
+        self.logprobs = None
+        self.n_z = args.n_z 
+
+    def train_critic(self, verbose=False, max_words=None, temp=0.1, sl=False, debug=False, n=0):
+        de_tknize = get_detokenize()
+        # Sample replay buffer / batch
+        experiences = self.train_buffer.sample()
+        state, action, reward, next_state, expert_next_action, done, Return = experiences
+        key = state['keys'][0]
+
+        if self.raw_response:
+            while key not in self.raw_responses.keys():
+                experiences = self.train_buffer.sample()
+                state, action, reward, next_state, expert_next_action, done, Return = experiences
+                key = state['keys'][0]
+
+        ctx_lens = state['context_lens']  # (batch_size, )
+        batch_size = len(state['context_lens'])
+
+        out_utts = np2var(action, LONG, use_gpu=self.args.use_gpu)
+        next_out_utts = np2var(expert_next_action, LONG, use_gpu=self.args.use_gpu)
+        corpus_a_t = action
+        corpus_a_t1 = expert_next_action
+
+        if type(corpus_a_t[0]) == int:
+            corpus_a_t = [corpus_a_t]
+        corpus_a_t = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in corpus_a_t]), LONG, use_gpu=self.args.use_gpu)
+
+        if not self.critic_dropout:
+            current_Q1, current_Q2 = self.critic(state, corpus_a_t)
+        else:
+            current_Q1 = self.critic(state, corpus_a_t)
+
+        # compute target
+        with th.no_grad():
+            if sl:
+                # use return from data
+                y = np2var(Return, FLOAT, use_gpu=self.args.use_gpu).unsqueeze(1)
+            else:
+                if self.raw_response:
+                    a_prime_t1 =  self.cvae.np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in self.raw_responses[key][1:] + [[]]]), LONG)
+                elif self.corpus_response:
+                    a_prime_t1 = next_out_utts
+                else:
+                    # predict a_t+1
+                    if self.actor is not None:
+                        z_t1, actor_mu_t1, actor_logvar_t1 = self.actor(next_state)
+                        _, a_prime_t1 = self.cvae.decode_z(z_t1, batch_size, next_state, self.args.max_words, self.args.temperature)
+                    else:
+                        logprobs, a_prime_t1, joint_logpz, sample_z = self.cvae.forward_rl(next_state, max_words=max_words, temp=temp)
+
+                    if type(a_prime_t1[0]) == int:
+                        a_prime_t1 = [a_prime_t1]
+                    a_prime_t1 = np2var(np.asarray([self.cvae.pad_to(self.args.max_words, a, do_pad=True) for a in a_prime_t1]), LONG, use_gpu=self.args.use_gpu)
+
+                # Critic Training
+                if not self.critic_dropout:
+                    target_Q1, target_Q2 = self.critic_target(next_state, a_prime_t1)
+                    # Soft Clipped Double Q-learning 
+                    target_Q = self.lmbda * th.min(target_Q1, target_Q2) + (1. - self.lmbda) * th.max(target_Q1, target_Q2)
+                else:
+                    if not self.fix_episode:
+                        Qs = [self.critic_target(next_state, a_prime_t1) for _ in range(5)]
+                    else:
+                        # special forward to avoid using pseudo-trajectory
+                        # tic = time.perf_counter()
+                        Qs = [self.critic_target.forward_target(next_state, a_prime_t1.unsqueeze(1), corpus_a_t) for _ in range(5)]
+                        # toc = time.perf_counter()
+                        # print(f"One batch of critic pass in {toc - tic:0.4f} seconds")
+                    if self.args.critic_dropout_agg == "min":
+                        target_Q = th.min(th.cat(Qs, dim=1), dim=1)[0]
+                    elif self.args.critic_dropout_agg == "avg":
+                        target_Q = th.mean(th.cat(Qs, dim=1), 1)
+
+                # critic only ever receive supervision on final state
+                if not self.critic_dropout:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu).unsqueeze(1) * self.args.gamma * (target_Q)
+                else:
+                    y = np2var(reward, LONG, use_gpu=self.args.use_gpu) + np2var(1 - done, LONG, use_gpu=self.args.use_gpu) * self.args.gamma * (target_Q)
+                    y = y.unsqueeze(1)
+
+
+
+        q1_loss = self.q_lossf(current_Q1, y)
+        q2_loss = self.q_lossf(current_Q2, y) if not self.critic_dropout else 0.0
+
+        critic_loss = q1_loss + q2_loss
+
+        if debug:
+
+            print("===TURN T===")
+            for turn_id, turn in enumerate(state['contexts']):
+                user_input = get_sent(self.cvae.vocab, de_tknize, turn, -1)
+                print("Usr: {}".format(user_input))
+                if turn_id > 0 and not sl:
+                    sys_output = get_sent(self.cvae.vocab, de_tknize, a_prime_t1, turn_id - 1)
+                    print("Pred_Sys: {}".format(sys_output))
+                true_output = get_sent(self.cvae.vocab, de_tknize, out_utts, turn_id)
+                print("True_Sys: {}".format(true_output))
+                corpus_output = get_sent(self.cvae.vocab, de_tknize, corpus_a_t, turn_id)
+                print("VAE_Sys: {}".format(corpus_output))
+                print("pred: ", current_Q1[turn_id], "y: ", y[turn_id])
+
+
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        critic_total_norm = 0
+        parameters = [p for p in self.critic.parameters() if p.grad is not None and p.requires_grad]
+        for p in parameters:
+            param_norm = p.grad.detach().data.norm(2)
+            critic_total_norm += param_norm.item() ** 2
+            critic_total_norm = critic_total_norm ** 0.5
+
+        nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.rl_clip)
+        self.critic_optimizer.step()
+
+        # Update Target Networks 
+        if sl:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(param.data)
+        else:
+            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        loss_dict =  {"critic_loss": critic_loss.item(), 
+                    "q1_loss": q1_loss.item(), 
+                    "q2_loss": q2_loss.item() if not self.critic_dropout else 0.0,
+                    "critic_kl_loss":critic_kl_loss.item() if self.args.critic_kl_loss else 0.0,
+                    "critic_grad_norm":critic_total_norm}
+
+        report = ", ".join([f"{k}: {v}" for k, v in loss_dict.items()])
+
+        if verbose:
+            logger.info(report)
+        
+        return report, loss_dict
+    
diff --git a/latent_dialog/augpt_utils.py b/latent_dialog/augpt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5c4c1db7358b6a7e321f56bb05a6db0ed637e6
--- /dev/null
+++ b/latent_dialog/augpt_utils.py
@@ -0,0 +1,458 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2021 lubis <lubis@hilbert242>
+#
+# Distributed under terms of the MIT license.
+
+"""
+utils from AuGPT codebase
+"""
+import re
+import os
+import sys
+import types
+import shutil
+import logging
+import requests
+import torch
+import zipfile
+import bisect
+import random
+import copy
+import json
+from collections import OrderedDict, defaultdict
+from typing import Callable, Union, Set, Optional, List, Dict, Any, Tuple, MutableMapping  # noqa: 401
+from dataclasses import dataclass
+import pdb
+
+DATASETS_PATH = os.path.join(os.path.expanduser(os.environ.get('DATASETS_PATH', '~/datasets')), 'augpt')
+
+pricepat = re.compile("\d{1,3}[.]\d{1,2}")
+
+fin = open(os.path.join(DATASETS_PATH, 'mapping.pair'))
+replacements = []
+for line in fin.readlines():
+    tok_from, tok_to = line.replace('\n', '').split('\t')
+    replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
+
+
+def insertSpace(token, text):
+    sidx = 0
+    while True:
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
+                re.match('[0-9]', text[sidx + 1]):
+            sidx += 1
+            continue
+        if text[sidx - 1] != ' ':
+            text = text[:sidx] + ' ' + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
+            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
+        sidx += 1
+    return text
+
+
+
+class AutoDatabase:
+    @staticmethod
+    def load(pretrained_model_name_or_path):
+        database_file = os.path.join(pretrained_model_name_or_path, 'database.zip')
+
+        with zipfile.ZipFile(database_file) as zipf:
+            def _build_database():
+                module = types.ModuleType('database')
+                exec(zipf.read('database.py').decode('utf-8'), module.__dict__)
+                return module.Database(zipf)
+
+            database = _build_database()
+
+
+        return database
+
+class BeliefParser:
+    def __init__(self):
+        self.slotval_re = re.compile(r"(\w[\w ]*\w) = ([\w\d: |']+)")
+        self.domain_re = re.compile(r"(\w+) {\s*([\w,= :\d|']*)\s*}", re.IGNORECASE)
+
+    def __call__(self, raw_belief: str):
+        belief = OrderedDict()
+        for match in self.domain_re.finditer(raw_belief):
+            domain, domain_bs = match.group(1), match.group(2)
+            belief[domain] = {}
+            for slot_match in self.slotval_re.finditer(domain_bs):
+                slot, val = slot_match.group(1), slot_match.group(2)
+                belief[domain][slot] = val
+        return belief
+
+class AutoLexicalizer:
+    @staticmethod
+    def load(pretrained_model_name_or_path):
+        lexicalizer_file = os.path.join(pretrained_model_name_or_path, 'lexicalizer.zip')
+        
+        with zipfile.ZipFile(lexicalizer_file) as zipf:
+            def _build_lexicalizer():
+                module = types.ModuleType('lexicalizer')
+                exec(zipf.read('lexicalizer.py').decode('utf-8'), module.__dict__)
+                return module.Lexicalizer(zipf)
+
+            lexicalizer = _build_lexicalizer()
+
+
+        return lexicalizer
+
+def build_blacklist(items, domains=None):
+    for i, (dialogue, items) in enumerate(items):
+        if domains is not None and set(dialogue['domains']).difference(domains):
+            yield i
+        elif items[-1]['speaker'] != 'system':
+            yield i
+
+class BlacklistItemsWrapper:
+    def __init__(self, items, blacklist):
+        self.items = items
+        self.key2idx = items.key2idx
+        self._indexmap = []
+        blacklist_pointer = 0
+        for i in range(len(items)):
+            if blacklist_pointer >= len(blacklist):
+                self._indexmap.append(i)
+            elif i < blacklist[blacklist_pointer]:
+                self._indexmap.append(i)
+            elif i == blacklist[blacklist_pointer]:
+                blacklist_pointer += 1
+        assert len(self._indexmap) == len(items) - len(blacklist)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            idx = self.key2idx[idx]
+        return self.items[self._indexmap[idx]]
+
+    def __len__(self):
+        return len(self._indexmap)
+def split_name(dataset_name: str):
+    split = dataset_name.rindex('/')
+    return dataset_name[:split], dataset_name[split + 1:]
+
+@dataclass
+class DialogDatasetItem:
+    context: Union[List[str], str]
+    belief: Union[Dict[str, Dict[str, str]], str] = None
+    database: Union[List[Tuple[str, int]], List[Tuple[str, int, Any]], None, str] = None
+    response: str = None
+    positive: bool = True
+    raw_belief: Any = None
+    raw_response: str = None
+    key: str = None
+
+    def __getattribute__(self, name):
+        val = object.__getattribute__(self, name)
+        if name == 'belief' and val is None and self.raw_belief is not None:
+            val = format_belief(self.raw_belief)
+            self.belief = val
+        return val
+
+@dataclass
+class DialogDataset(torch.utils.data.Dataset):
+    items: List[any]
+    database: Any = None
+    domains: List[str] = None
+    lexicalizer: Any = None
+    transform: Callable[[Any], Any] = None
+    normalize_input: Callable[[str], str] = None
+    ontology: Dict[Tuple[str, str], Set[str]] = None
+
+    @staticmethod
+    def build_dataset_without_database(items, *args, **kwargs):
+        return DialogDataset(items, FakeDatabase(), *args, **kwargs)
+
+    def __getitem__(self, index):
+        item = self.items[index]
+        if self.transform is not None:
+            item = self.transform(item)
+        return item
+
+    def __len__(self):
+        return len(self.items)
+
+    def map(self, transformation):
+        def trans(x):
+            x = self.transform(x)
+            x = transformation(x)
+            return x
+        return dataclasses.replace(self, transform=trans)
+
+    def finish(self, progressbar: Union[str, bool] = False):
+        if self.transform is None:
+            return self
+
+        ontology = defaultdict(lambda: set())
+        domains = set(self.domains) if self.domains else set()
+
+        items = []
+        for i in trange(len(self),
+                        desc=progressbar if isinstance(progressbar, str) else 'loading dataset',
+                        disable=not progressbar):
+            item = self[i]
+            for k, bs in item.raw_belief.items():
+                domains.add(k)
+                for k2, val in bs.items():
+                    ontology[(k, k2)].add(val)
+            items.append(item)
+        if self.ontology:
+            ontology = merge_ontologies((self.ontology, ontology))
+        return dataclasses.replace(self, items=items, transform=None, domains=domains, ontology=ontology)
+
+class DialogueItems:
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            r.append(e + s)
+            s += e
+        return r
+
+    def __init__(self, dialogues):
+        lengths = [len(x['items']) for x in dialogues]
+        self.keys = [x['name'] for x in dialogues]
+        self.key2idx = {k:i for (i, k) in enumerate(self.keys)}
+        self.cumulative_sizes = DialogueItems.cumsum(lengths)
+        self.dialogues = dialogues
+
+    def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dialogue_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dialogue_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dialogue_idx - 1]
+        return self.dialogues[dialogue_idx], self.dialogues[dialogue_idx]['items'][:sample_idx + 1]
+
+    def __len__(self):
+        if not self.cumulative_sizes:
+            return 0
+        return self.cumulative_sizes[-1]
+
+
+def load_dataset(name, use_goal=False, context_window_size=15, domains=None, **kwargs) -> DialogDataset:
+    name, split = split_name(name)
+    path = os.path.join(DATASETS_PATH, name)
+    with open(os.path.join(path, f'{split}.json'), 'r') as f:
+        data = json.load(f, object_pairs_hook=OrderedDict)
+    dialogues = data['dialogues']
+    items = DialogueItems(dialogues)
+    items = BlacklistItemsWrapper(items, list(build_blacklist(items, domains)))
+
+    def transform(x):
+        dialogue, items = x
+        context = [s['text'] for s in items[:-1]]
+        if context_window_size is not None and context_window_size > 0:
+            context = context[-context_window_size:]
+        belief = items[-1]['belief']
+        database = items[-1]['database']
+        item = DialogDatasetItem(context,
+                        raw_belief=belief,
+                        database=database,
+                        response=items[-1]['delexicalised_text'],
+                        raw_response=items[-1]['text'],
+                        key=dialogue['name'])
+        if use_goal:
+            setattr(item, 'goal', dialogue['goal'])
+            # MultiWOZ evaluation uses booked domains property
+            if 'booked_domains' in items[-1]:
+                setattr(item, 'booked_domains', items[-1]['booked_domains'])
+            setattr(item, 'dialogue_act', items[-1]['dialogue_act'])
+        setattr(item, 'active_domain', items[-1]['active_domain'])
+        return item
+
+    dataset = DialogDataset(items, transform=transform, domains=data['domains'])
+    if os.path.exists(os.path.join(path, 'database.zip')):
+        dataset.database = AutoDatabase.load(path)
+
+    if os.path.exists(os.path.join(path, 'lexicalizer.zip')):
+        dataset.lexicalizer = AutoLexicalizer.load(path)
+
+    return dataset
+
+def format_belief(belief: OrderedDict) -> str:
+    assert isinstance(belief, OrderedDict)
+    str_bs = []
+    for domain, domain_bs in belief.items():
+        domain_bs = ', '.join([f'{slot} = {val}' for slot, val in sorted(domain_bs.items(), key=lambda x: x[0])])
+        str_bs.extend([domain, '{' + domain_bs + '}'])
+    return ' '.join(str_bs)
+
+def augpt_normalize(text, delexicalize=True):
+    # lower case every word
+    text = text.lower()
+
+    # replace white spaces in front and end
+    text = re.sub(r'^\s*|\s*$', '', text)
+
+    # hotel domain pfb30
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # normalize phone number
+    ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text)
+    if ms:
+        sidx = 0
+        for m in ms:
+            sidx = text.find(m[0], sidx)
+            if text[sidx - 1] == '(':
+                sidx -= 1
+            eidx = text.find(m[-1], sidx) + len(m[-1])
+            text = text.replace(text[sidx:eidx], ''.join(m))
+
+    # normalize postcode
+    ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
+                    text)
+    if ms:
+        sidx = 0
+        for m in ms:
+            sidx = text.find(m, sidx)
+            eidx = sidx + len(m)
+            text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:]
+
+    # weird unicode bug
+    text = re.sub(u"(\u2018|\u2019)", "'", text)
+
+    # replace time and and price
+    if delexicalize:
+        text = re.sub(pricepat, ' [price] ', text)
+        #text = re.sub(pricepat2, '[value_price]', text)
+
+    # replace st.
+    text = text.replace(';', ',')
+    text = re.sub('$\/', '', text)
+    text = text.replace('/', ' and ')
+
+    # replace other special characters
+    text = text.replace('-', ' ')
+    if delexicalize:
+        text = re.sub('[\":\<>@\(\)]', '', text)
+    else:
+        text = re.sub('[\"\<>@\(\)]', '', text)
+
+    # insert white space before and after tokens:
+    for token in ['?', '.', ',', '!']:
+        text = insertSpace(token, text)
+
+    # insert white space for 's
+    text = insertSpace('\'s', text)
+
+    # replace it's, does't, you'd ... etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+    for fromx, tox in replacements:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # concatenate numbers
+    tmp = text
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
+
+def replace_augpt_tokens(text, active_domains):
+    if type(active_domains) == list:
+        if len(active_domains) > 1:
+            pdb.set_trace()
+        elif len(active_domains) == 0:
+            active_domain = None
+        else:
+            active_domain = active_domains[0]
+    else:
+        active_domain = active_domains
+
+    tokens = text.split()
+
+    ret = []
+    for t in tokens:
+        if t[0] == "[" and len(t) > 1:
+            if t in ["[name]", "[address]", "[phone]", "[postcode]", "[reference]", "[id]"]:
+                if active_domain is None:
+                    continue
+                else:
+                    ret.append(f'[{active_domain}_{t[1:]}')
+            elif t == "[car]":
+                ret.append("[taxi_type]")
+            elif t in ["[departure]", "[destination]"]:
+                ret.append("[value_place]")
+            elif t in ["[leave", "[arrive"]:
+                ret.append("[value_time]")
+            elif t in ["at]", "by]", "range]"]:
+                continue
+            elif t in ["[food]", "[area]", "[time]", "[day]", "[price]"]:
+                ret.append(f"[value_{t[1:]}")
+            elif t == "[price":
+                ret.append("[value_pricerange]")
+            elif t in ["[duration]", "[stars]", "[stay]", "[people]"]:
+                ret.append("[value_count]")
+            elif t == "[type]":
+                ret.append("hotel")
+            else:
+                pdb.set_trace()
+        else:
+            ret.append(t)
+
+    ret = " ".join(ret)
+    ret = re.sub("([0-9])+", "[value_count]", ret)
+
+    return ret
+
+
+def replace_hdsa_tokens(text):
+    tokens = text.split()
+    ret = []
+    for t in tokens:
+        if t[0] == "[" and len(t) > 1:
+            if t == "[train_trainid]":
+                ret.append("[train_id]")
+            elif t in ["[train_arriveby]", "[train_leaveat]"]:
+                ret.append("[value_time]")
+            elif t in ["[train_departure]", "[train_destination]"]:
+                ret.append("[value_place]")
+            elif t in ["[hotel_pricerange]", "[restaurant_pricerange]"]:
+                ret.append("[value_pricerange]")
+            elif t in ["[hotel_area]", "[attraction_area]", "[restaurant_area]"]:
+                ret.append("[value_area]")
+            elif t in ["[train_price]"]:
+                ret.append("[value_price]")
+            elif t == "[train_day]":
+                ret.append("[value_day]")
+            elif t == "[restaurant_food]":
+                ret.append("[value_food]")
+            else:
+                ret.append(t)
+                # skipped.append(t)
+        else:
+            ret.append(t)
+
+    # ret = " ".join(ret)
+    # ret = re.sub("([0-9])+", "[value_count]", ret)
+
+    # return ret, skipped
+    return ret
+
diff --git a/latent_dialog/base_data_loaders.py b/latent_dialog/base_data_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..e96372acf9fcdbbf17bf393dc58158751212dacc
--- /dev/null
+++ b/latent_dialog/base_data_loaders.py
@@ -0,0 +1,183 @@
+import numpy as np
+import logging
+
+
+class BaseDataLoaders(object):
+    def __init__(self, name):
+        self.data_size = None
+        self.indexes = None
+        self.name = name
+
+    def _shuffle_indexes(self):
+        np.random.shuffle(self.indexes)
+
+    def _shuffle_batch_indexes(self):
+        np.random.shuffle(self.batch_indexes)
+
+    def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
+        self.ptr = 0
+        self.batch_size = config.batch_size
+        self.num_batch = self.data_size // config.batch_size
+
+        if verbose:
+            print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
+
+        if shuffle and not fix_batch:
+            self._shuffle_indexes()
+
+        self.batch_indexes = []
+        for i in range(self.num_batch):
+            self.batch_indexes.append(self.indexes[i*self.batch_size: (i+1)*self.batch_size])
+
+        if shuffle and fix_batch:
+            self._shuffle_batch_indexes()
+
+        if verbose:
+            print('%s begins with %d batches' % (self.name, self.num_batch))
+
+    def next_batch(self):
+        if self.ptr < self.num_batch:
+            selected_ids = self.batch_indexes[self.ptr]
+            self.ptr += 1
+            return self._prepare_batch(selected_index=selected_ids)
+        else:
+            return None
+
+    def _prepare_batch(self, *args, **kwargs):
+        raise NotImplementedError('Have to override _prepare_batch()')
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+
+class LongDataLoader(object):
+    """A special efficient data loader for TBPTT. Assume the data contains
+    N long sequences, each sequence has length k_i
+
+    :ivar batch_size: the size of a minibatch
+    :ivar backward_size: how many steps in time to do BP
+    :ivar step_size: how fast we move the window
+    :ivar ptr: the current idx of batch
+    :ivar num_batch: the total number of batch
+    :ivar batch_indexes: a list of list. Each item is the IDs in this batch
+    :ivar grid_indexes: a list of (b_id, s_id, e_id). b_id is the index of
+    batch, s_id is the starting time id in that batch and e_id is the ending
+    time id.
+    :ivar indexes: a list, the ordered of sequences ID it should go through
+    :ivar data_size: the number of sequences, N.
+    :ivar data_lens: a list containing k_i
+    :ivar prev_alive_size:
+    :ivar name: the name of the this data loader
+    """
+    logger = logging.getLogger()
+
+    def __init__(self, name):
+        self.batch_size = 0
+        self.backward_size = 0
+        self.step_size = 0
+        self.ptr = 0
+        self.num_batch = None
+        self.batch_indexes = None  # one batch is a dialog
+        self.grid_indexes = None  # grid is the tokenized versiion
+        self.indexes = None
+        self.data_lens = None
+        self.data_size = None
+        self.name = name
+
+    def _shuffle_batch_indexes(self):
+        np.random.shuffle(self.batch_indexes)
+
+    def _shuffle_grid_indexes(self):
+        np.random.shuffle(self.grid_indexes)
+
+    def _prepare_batch(self, cur_grid, prev_grid):
+        raise NotImplementedError("Have to override prepare batch")
+
+    def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
+
+        assert len(self.indexes) == self.data_size and \
+               len(self.data_lens) == self.data_size
+        # make sure backward_size can be divided by step size
+        assert config.backward_size % config.step_size == 0
+
+        self.ptr = 0
+        self.batch_size = config.batch_size
+        self.backward_size = config.backward_size
+        self.step_size = config.step_size
+
+        # create batch indexes
+        temp_num_batch = self.data_size // config.batch_size
+        self.batch_indexes = []
+        for i in range(temp_num_batch):
+            self.batch_indexes.append(
+                self.indexes[i * self.batch_size:(i + 1) * self.batch_size])
+
+        left_over = self.data_size - temp_num_batch * config.batch_size
+        if shuffle:
+            self._shuffle_batch_indexes()
+
+        # create grid indexes
+        self.grid_indexes = []
+        for idx, b_ids in enumerate(self.batch_indexes):
+            # assume the b_ids are sorted
+            all_lens = [self.data_lens[i] for i in b_ids]
+            max_len = self.data_lens[b_ids[0]]
+            min_len = self.data_lens[b_ids[-1]]
+            assert np.max(all_lens) == max_len
+            assert np.min(all_lens) == min_len
+            num_seg = (max_len - self.backward_size - self.step_size) // self.step_size
+            cut_start, cut_end = [], []
+            if num_seg > 1:
+                cut_start = list(range(config.step_size, num_seg * config.step_size, config.step_size))
+                cut_end = list(range(config.backward_size + config.step_size,
+                                num_seg * config.step_size + config.backward_size,
+                                config.step_size))
+                assert cut_end[-1] < max_len
+
+            actual_size = min(max_len, config.backward_size)
+            temp_end = list(range(2, actual_size, config.step_size))
+            temp_start = [0] * len(temp_end)
+
+            cut_start = temp_start + cut_start
+            cut_end = temp_end + cut_end
+
+            assert len(cut_end) == len(cut_start)
+            new_grids = [(idx, s_id, e_id) for s_id, e_id in
+                         zip(cut_start, cut_end) if s_id < min_len - 1]
+
+            self.grid_indexes.extend(new_grids)
+
+        # shuffle batch indexes
+        if shuffle:
+            self._shuffle_grid_indexes()
+
+        self.num_batch = len(self.grid_indexes)
+        if verbose:
+            self.logger.info("%s init with %d batches with %d left over samples" %
+                             (self.name, self.num_batch, left_over))
+
+    def next_batch(self):
+        if self.ptr < self.num_batch:
+            current_grid = self.grid_indexes[self.ptr]
+            if self.ptr > 0:
+                prev_grid = self.grid_indexes[self.ptr - 1]
+            else:
+                prev_grid = None
+            self.ptr += 1
+            return self._prepare_batch(cur_grid=current_grid,
+                                       prev_grid=prev_grid)
+        else:
+            return None
+
+    def pad_to(self, max_len, tokens, do_pad=True):
+        if len(tokens) >= max_len:
+            return tokens[0:max_len - 1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
diff --git a/latent_dialog/base_models.py b/latent_dialog/base_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..c793e8c732bbdb4220b1adfae375e2beb69808bc
--- /dev/null
+++ b/latent_dialog/base_models.py
@@ -0,0 +1,116 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.autograd import Variable
+import numpy as np
+from latent_dialog.utils import INT, FLOAT, LONG, cast_type
+import pdb
+
+
+class BaseModel(nn.Module):
+    def __init__(self, config):
+        super(BaseModel, self).__init__()
+        self.use_gpu = config.use_gpu
+        self.config = config
+        self.kl_w = 0.0
+
+    def np2var(self, inputs, dtype):
+        if inputs is None:
+            return None
+        return cast_type(Variable(th.from_numpy(inputs)), 
+                         dtype, 
+                         self.use_gpu)
+
+    def forward(self, *inputs):
+        raise NotImplementedError
+
+    def backward(self, loss, batch_cnt):
+        total_loss = self.valid_loss(loss, batch_cnt)
+        total_loss.backward()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        total_loss = 0.0
+        for k, l in loss.items():
+            if l is not None:
+                total_loss += l
+        return total_loss
+
+    def get_optimizer(self, config, verbose=True):
+        if config.op == 'adam':
+            if verbose:
+                print('Use Adam')
+            return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=config.init_lr, weight_decay=config.l2_norm)
+        elif config.op == 'sgd':
+            print('Use SGD')
+            return optim.SGD(self.parameters(), lr=config.init_lr, momentum=config.momentum)
+        elif config.op == 'rmsprop':
+            print('Use RMSProp')
+            return optim.RMSprop(self.parameters(), lr=config.init_lr, momentum=config.momentum)
+
+    def get_clf_optimizer(self, config):
+        params = []
+        params.extend(self.gru_attn_encoder.parameters())
+        params.extend(self.feat_projecter.parameters())
+        params.extend(self.sel_classifier.parameters())
+
+        if config.fine_tune_op == 'adam':
+            print('Use Adam')
+            return optim.Adam(params, lr=config.fine_tune_lr)
+        elif config.fine_tune_op == 'sgd':
+            print('Use SGD')
+            return optim.SGD(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum)
+        elif config.fine_tune_op == 'rmsprop':
+            print('Use RMSProp')
+            return optim.RMSprop(params, lr=config.fine_tune_lr, momentum=config.fine_tune_momentum)
+
+        
+    def model_sel_loss(self, loss, batch_cnt):
+        return self.valid_loss(loss, batch_cnt)
+
+
+    def extract_short_ctx(self, context, context_lens, backward_size=1):
+        utts = []
+        for b_id in range(context.shape[0]):
+            utts.append(context[b_id, context_lens[b_id]-1])
+            # utt = []
+            # for i in range(context_lens[b_id]):
+                # utt.extend(context[b_id, i])
+            # utts.append(utt)
+        return np.array(utts)
+
+    def flatten_context(self, context, context_lens, align_right=False):
+        utts = []
+        temp_lens = []
+        for b_id in range(context.shape[0]):
+            temp = []
+            for t_id in range(context_lens[b_id]):
+                for token in context[b_id, t_id]:
+                    if token != 0:
+                        temp.append(token)
+            temp_lens.append(len(temp))
+            utts.append(temp)
+        max_temp_len = np.max(temp_lens)
+        results = np.zeros((context.shape[0], max_temp_len))
+        for b_id in range(context.shape[0]):
+            if align_right:
+                results[b_id, -temp_lens[b_id]:] = utts[b_id]
+            else:
+                results[b_id, 0:temp_lens[b_id]] = utts[b_id]
+
+        return results
+
+
+def frange_cycle_linear(n_iter, start=0.0, stop=1.0,  n_cycle=4, ratio=0.5):
+    L = np.ones(n_iter) * stop
+    period = n_iter/n_cycle
+    step = (stop-start)/(period*ratio) # linear schedule
+
+    for c in range(n_cycle):
+        v, i = start, 0
+        while v <= stop and (int(i+c*period) < n_iter):
+            L[int(i+c*period)] = v
+            v += step
+            i += 1
+    return L 
+
diff --git a/latent_dialog/corpora.py b/latent_dialog/corpora.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd1f8ca4e2bd46e2de1c574b65a4c2d53aa0609f
--- /dev/null
+++ b/latent_dialog/corpora.py
@@ -0,0 +1,573 @@
+from __future__ import unicode_literals
+import numpy as np
+from collections import Counter
+from latent_dialog.utils import Pack, get_tokenize, get_chat_tokenize, missingdict
+from latent_dialog.augpt_utils import AutoDatabase, AutoLexicalizer, augpt_normalize
+import json
+from nltk.tokenize import WordPunctTokenizer
+import logging
+from collections import defaultdict
+import pdb
+import os
+import re
+
+PAD = '<pad>'
+UNK = '<unk>'
+USR = 'YOU:'
+SYS = 'THEM:'
+BOD = '<d>'
+EOD = '</d>'
+BOS = '<s>'
+EOS = '<eos>'
+SEL = '<selection>'
+SEP = "|"
+REQ = "<requestable>"
+INF = "<informable>"
+WILD = "%s"
+SPECIAL_TOKENS = [PAD, UNK, USR, SYS, BOS, BOD, EOS, EOD]
+STOP_TOKENS = [EOS, SEL]
+DECODING_MASKED_TOKENS = [PAD, UNK, USR, SYS, BOD]
+
+REQ_TOKENS = {}
+DOMAIN_REQ_TOKEN = ['restaurant', 'hospital', 'hotel','attraction', 'train', 'police', 'taxi']
+ACTIVE_BS_IDX = [13, 30, 35, 61, 72, 91, 93] #indexes in the BS indicating if domain is active
+NO_MATCH_DB_IDX = [-1, 0, -1, 6, 12, 18, -1] # indexes in DB pointer indicating 0 match is found for that domain, -1 mean that domain has no DB
+REQ_TOKENS['attraction'] = ["[attraction_address]", "[attraction_name]", "[attraction_phone]", "[attraction_postcode]", "[attraction_reference]", "[attraction_type]"]
+REQ_TOKENS['hospital'] = ["[hospital_address]", "[hospital_department]", "[hospital_name]", "[hospital_phone]", "[hospital_postcode]"] #, "[hospital_reference]"
+REQ_TOKENS['hotel'] = ["[hotel_address]", "[hotel_name]", "[hotel_phone]", "[hotel_postcode]", "[hotel_reference]", "[hotel_type]"]
+REQ_TOKENS['restaurant'] = ["[restaurant_name]", "[restaurant_address]", "[restaurant_phone]", "[restaurant_postcode]", "[restaurant_reference]"]
+REQ_TOKENS['train'] = ["[train_id]", "[train_reference]"]
+REQ_TOKENS['police'] = ["[police_address]", "[police_phone]", "[police_postcode]"] #"[police_name]", 
+REQ_TOKENS['taxi'] = ["[taxi_phone]", "[taxi_type]"]
+
+GENERIC_TOKENS = ["[value_area]", "[value_count]", "[value_day]", "[value_food]", "[value_place]", "[value_price]", "[value_pricerange]", "[value_time]"]
+
+
+class NormMultiWozCorpus(object):
+    logger = logging.getLogger()
+
+    def __init__(self, config):
+        self.bs_size = 94
+        self.db_size = 30
+        self.goal_size = 77
+        self.bs_types =['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b']
+        self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
+        self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt']
+        self.config = config
+        self.tokenize = lambda x: x.split()
+        self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config)
+        self._extract_vocab()
+        self._extract_goal_vocab()
+        self.logger.info('Loading corpus finished.')
+
+    def _read_file(self, config):
+        train_data = json.load(open(config.train_path))
+        valid_data = json.load(open(config.valid_path))
+        test_data = json.load(open(config.test_path))
+        
+        train_data = self._process_dialogue(train_data)
+        valid_data = self._process_dialogue(valid_data)
+        test_data = self._process_dialogue(test_data)
+
+        return train_data, valid_data, test_data
+
+    def _process_dialogue(self, data):
+        new_dlgs = []
+        all_sent_lens = []
+        all_dlg_lens = []
+
+        for key, raw_dlg in data.items():
+            norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)]
+            for t_id in range(len(raw_dlg['db'])):
+                usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
+                sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
+                norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
+                norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
+                all_sent_lens.extend([len(usr_utt), len(sys_utt)])
+            # To stop dialog
+            norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
+            # if self.config.to_learn == 'usr':
+            #     norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
+            all_dlg_lens.append(len(raw_dlg['db']))
+            processed_goal = self._process_goal_simple(raw_dlg['goal'])
+            new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key))
+
+        self.logger.info('Max utt len = %d, mean utt len = %.2f' % (
+            np.max(all_sent_lens), float(np.mean(all_sent_lens))))
+        self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % (
+            np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
+        return new_dlgs
+
+    def _extract_vocab(self):
+        all_words = []
+        for dlg in self.train_corpus:
+            for turn in dlg.dlg:
+                all_words.extend(turn.utt)
+        vocab_count = Counter(all_words).most_common()
+        raw_vocab_size = len(vocab_count)
+        keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size)
+        oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words))
+
+        self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0],
+                                                               vocab_count[keep_vocab_size - 1][1]) +
+              'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100))
+
+        vocab_count = vocab_count[0:keep_vocab_size]
+        self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS]
+        self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)}
+        self.unk_id = self.vocab_dict[UNK]
+        self.logger.info("Raw vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab)))
+
+    def _process_goal(self, raw_goal):
+        res = {}
+        for domain in self.domains:
+            all_words = []
+            d_goal = raw_goal[domain]
+            if d_goal:
+                for info_type in self.info_types:
+                    sv_info = d_goal.get(info_type, dict())
+                    if info_type == 'reqt' and isinstance(sv_info, list):
+                        all_words.extend([info_type + '|' + item for item in sv_info])
+                    elif isinstance(sv_info, dict):
+                        all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()])
+                    else:
+                        print('Fatal Error!')
+                        exit(-1)
+            res[domain] = all_words
+        return res
+
+    def _process_goal_simple(self, raw_goal):
+        res = {}
+        for domain in self.domains:
+            all_words = []
+            if domain in raw_goal:
+                d_goal = raw_goal[domain]
+                for info_type in ['book', 'info', 'reqt']:
+                    sv_info = d_goal.get(info_type, dict())
+                    if info_type == 'reqt' and isinstance(sv_info, list):
+                        all_words.extend([info_type + '|' + item for item in sv_info])
+                    elif isinstance(sv_info, dict):
+                        all_words.extend([info_type + '|' + k for k, v in sv_info.items()])
+                    else:
+                        print('Fatal Error!')
+                        exit(-1)
+            res[domain] = all_words
+        return res
+
+
+    def _extract_goal_vocab(self):
+        self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {}
+        for domain in self.domains:
+            all_words = []
+            for dlg in self.train_corpus:
+                all_words.extend(dlg.goal[domain])
+            vocab_count = Counter(all_words).most_common()
+            raw_vocab_size = len(vocab_count)
+            discard_wc = np.sum([c for t, c in vocab_count])
+
+            self.logger.info('================= domain = {}, \n'.format(domain) +
+                  'goal vocab size of train set = %d, \n' % (raw_vocab_size,) +
+                  'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) +
+                  'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),))
+
+            self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count]
+            self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])}
+            self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK]
+
+    def get_corpus(self):
+        id_train = self._to_id_corpus('Train', self.train_corpus)
+        id_val = self._to_id_corpus('Valid', self.val_corpus)
+        id_test = self._to_id_corpus('Test', self.test_corpus)
+        return id_train, id_val, id_test
+
+    def _to_id_corpus(self, name, data):
+        results = []
+        for dlg in data:
+            if len(dlg.dlg) < 1:
+                continue
+            id_dlg = []
+            for turn in dlg.dlg:
+                id_turn = Pack(utt=self._sent2id(turn.utt),
+                               speaker=turn.speaker,
+                               db=turn.db, bs=turn.bs)
+                id_dlg.append(id_turn)
+            id_goal = self._goal2id(dlg.goal)
+            results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key))
+        return results
+
+    def _sent2id(self, sent):
+        return [self.vocab_dict.get(t, self.unk_id) for t in sent]
+
+    def _goal2id(self, goal):
+        res = {}
+        for domain in self.domains:
+            d_bow = [0.0] * len(self.goal_vocab[domain])
+            for word in goal[domain]:
+                word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain])
+                d_bow[word_id] += 1.0
+            res[domain] = d_bow
+        return res
+
+    def id2sent(self, id_list):
+        return [self.vocab[i] for i in id_list]
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+def get_summary_bstate(bstate):
+    """Based on the mturk annotations we form multi-domain belief state"""
+    domains = [u'taxi', u'restaurant',  u'hospital',
+               u'hotel', u'attraction', u'train', u'police']
+    book_keys = {
+        'taxi': ['booked'],
+        'police': ['booked'],
+        'hospital': ['booked'],
+        'restaurant': ['booked', 'time', 'day', 'people'],
+        'hotel': ['booked', 'stay', 'day', 'people'],
+        'attraction': ['booked'],
+        'train': ['booked', 'people']
+        }
+
+    semi_keys = {
+            'taxi': ['leave at', 'destination', 'departure', 'arrive by'],
+            'police': [],
+            'restaurant': ['food', 'price range', 'name', 'area'],
+            'hospital': ['department'],
+            'hotel': ['name', 'area', 'parking', 'price range', 'stars', 'internet', 'type'],
+            'attraction': ['type', 'name', 'area'],
+            'train': ['leave at', 'destination', 'day', 'arriveby', 'departure'],
+            }
+    # {'taxi': {'book': {'booked': []}, 'semi': {'leaveAt': '', 'destination': '', 'departure': '', 'arriveBy': ''}}, 
+    # 'police': {'book': {'booked': []}, 'semi': {}}, 
+    # 'restaurant': {'book': {'booked': [], 'time': '', 'day': '', 'people': ''}, 'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}}, 
+    # 'hospital': {'book': {'booked': []}, 'semi': {'department': ''}}, 
+    # 'hotel': {'book': {'booked': [], 'stay': '', 'day': '', 'people': ''}, 'semi': {'name': 'not mentioned', 'area': 'not mentioned', 'parking': 'not mentioned', 'pricerange': 'cheap', 'stars': 'not mentioned', 'internet': 'not mentioned', 'type': 'hotel'}}, 
+    #'attraction': {'book': {'booked': []}, 'semi': {'type': '', 'name': '', 'area': ''}}, 
+    #'train': {'book': {'booked': [], 'people': ''}, 'semi': {'leaveAt': '', 'destination': '', 'day': '', 'arriveBy': '', 'departure': ''}}}
+    summary_bstate = []
+    pdb.set_trace()
+    for domain in domains:
+        domain_active = False
+
+        booking = []
+        if domain in bstate:
+            # for slot in sorted(bstate[domain]['book'].keys()):
+            for slot in sorted(book_keys[domain]):
+                if slot == 'booked':
+                    if bstate[domain]['book']['booked']:
+                        booking.append(1)
+                    else:
+                        booking.append(0)
+                else:
+                    if bstate[domain]['book'][slot] != "":
+                        booking.append(1)
+                    else:
+                        booking.append(0)
+            if domain == 'train':
+                if 'people' not in bstate[domain]['book'].keys():
+                    booking.append(0)
+                if 'ticket' not in bstate[domain]['book'].keys():
+                    booking.append(0)
+        else:
+            if domain == "train":
+                booking = [0, 0, 0]
+            else:
+                booking = [0]
+        summary_bstate += booking
+
+        # for slot in bstate[domain]['semi']:
+        for slot in semi_keys(domain):
+            slot_enc = [0, 0, 0]
+            if slot in bstate[domain]:
+                slot_enc[2] = 1
+            else:
+                slot_enc[0] = 1
+            # if bstate[domain]['semi'][slot] == 'not mentioned':
+                # slot_enc[0] = 1
+            # elif bstate[domain]['semi'][slot] == 'dont care' or bstate[domain]['semi'][slot] == 'dontcare' or bstate[domain]['semi'][slot] == "don't care":
+                # slot_enc[1] = 1
+            # elif bstate[domain]['semi'][slot]:
+            #     slot_enc[2] = 1
+            if slot_enc != [0, 0, 0]:
+                domain_active = True
+            summary_bstate += slot_enc
+
+        # quasi domain-tracker
+        if domain_active:
+            summary_bstate += [1]
+        else:
+            summary_bstate += [0]
+
+
+    # print(len(summary_bstate))
+    assert len(summary_bstate) == 94
+    return summary_bstate
+
+def addDBPointer(db_text, booked_text):
+    """Create database pointer for all related domains."""
+    matches = {
+            "restaurant": [0, 0, 0, 0, 0, 1],
+            "hotel" : [0, 0, 0, 0, 0, 1],
+            "attraction" : [0, 0, 0, 0, 0, 1],
+            "train" : [0, 0, 0, 0, 0, 1]
+            }
+
+    booked = {
+            "restaurant": [1, 0],
+            "hotel": [1, 0],
+            "train": [1, 0]
+            }
+
+    
+    if booked_text[0] != "none":
+        for domain in booked_text:
+            booked[domain] = [0, 1]
+
+
+    if len(db_text) > 0:
+        for i in range(int(len(db_text)/2)):
+            dom = db_text[i * 2]
+            num = db_text[i * 2 + 1]
+            if dom != "train":
+                if num == "0":
+                    matches[dom] = [1, 0, 0, 0, 0, 0]
+                elif num == "1":
+                    matches[dom] = [0, 1, 0, 0, 0, 0]
+                elif num == "2":
+                    matches[dom] = [0, 0, 1, 0, 0, 0]
+                elif num == "3":
+                    matches[dom] = [0, 0, 0, 1, 0, 0]
+                elif num == "4":
+                    matches[dom] = [0, 0, 0, 0, 1, 0]
+            else:
+                if num == "0":
+                    matches[dom] = [1, 0, 0, 0, 0, 0]
+                elif int(num) <= 2:
+                    matches[dom] = [0, 1, 0, 0, 0, 0]
+                elif int(num) <= 5:
+                    matches[dom] = [0, 0, 1, 0, 0, 0]
+                elif int(num) <= 10:
+                    matches[dom] = [0, 0, 0, 1, 0, 0]
+                elif int(num) <= 40:
+                    matches[dom] = [0, 0, 0, 0, 1, 0]
+
+
+    metadata =  matches["restaurant"] + matches["hotel"] + matches["attraction"] + matches["train"] + booked["restaurant"] + booked["hotel"] + booked["train"]
+
+    return metadata
+
+class NormMultiWozCorpusAE(object):
+    logger = logging.getLogger()
+
+    def __init__(self, config):
+        self.bs_size = 94
+        self.db_size = 30
+        self.goal_size = 77
+        self.bs_types =['b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'b', 'b', 'b']
+        self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
+        self.info_types = ['book', 'fail_book', 'fail_info', 'info', 'reqt']
+        # self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome']
+        # self.act2id = {a:i for i, a in enumerate(self.act_types)}
+        # self.id2act = {i:a for i, a in enumerate(self.act_types)}
+        # self.act_size = len(self.act_types) #domain agnostic act 
+        # self.act_size = len(domain) * len(self.act_types) #domain dependent act 
+        self.config = config
+        self.tokenize = lambda x: x.split()
+        self.train_corpus, self.val_corpus, self.test_corpus = self._read_file(self.config)
+        self._extract_vocab()
+        self._extract_goal_vocab()
+        self.logger.info('Loading corpus finished.')
+
+    def _read_file(self, config):
+        train_data = json.load(open(config.train_path))
+        valid_data = json.load(open(config.valid_path))
+        test_data = json.load(open(config.test_path))
+        
+        train_data = self._process_dialogue(train_data)
+        valid_data = self._process_dialogue(valid_data)
+        test_data = self._process_dialogue(test_data)
+
+        return train_data, valid_data, test_data
+
+    def _process_dialogue(self, data):
+        new_dlgs = []
+        all_sent_lens = []
+        all_dlg_lens = []
+
+        for key, raw_dlg in data.items():
+            norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size)]
+            for t_id in range(len(raw_dlg['db'])):
+                usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
+                sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
+                norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
+                norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
+                all_sent_lens.extend([len(usr_utt), len(sys_utt)])
+            # To stop dialog
+            norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
+            # if self.config.to_learn == 'usr':
+            #     norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0]*self.bs_size, db=[0.0]*self.db_size))
+            all_dlg_lens.append(len(raw_dlg['db']))
+            processed_goal = self._process_goal(raw_dlg['goal'])
+            new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key))
+
+        self.logger.info('Max utt len = %d, mean utt len = %.2f' % (
+            np.max(all_sent_lens), float(np.mean(all_sent_lens))))
+        self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % (
+            np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
+        return new_dlgs
+
+    def _extract_vocab(self):
+        all_words = []
+        for dlg in self.train_corpus:
+            for turn in dlg.dlg:
+                all_words.extend(turn.utt)
+        vocab_count = Counter(all_words).most_common()
+        raw_vocab_size = len(vocab_count)
+        keep_vocab_size = min(self.config.max_vocab_size, raw_vocab_size)
+        oov_rate = np.sum([c for t, c in vocab_count[0:keep_vocab_size]]) / float(len(all_words))
+
+        self.logger.info('cut off at word {} with frequency={},\n'.format(vocab_count[keep_vocab_size - 1][0],
+                                                               vocab_count[keep_vocab_size - 1][1]) +
+              'OOV rate = {:.2f}%'.format(100.0 - oov_rate * 100))
+
+        vocab_count = vocab_count[0:keep_vocab_size]
+        self.vocab = SPECIAL_TOKENS + [t for t, cnt in vocab_count if t not in SPECIAL_TOKENS]
+        self.vocab_dict = {t: idx for idx, t in enumerate(self.vocab)}
+        self.unk_id = self.vocab_dict[UNK]
+        self.logger.info("Raw vocab size {} in train set and final vocab size {}".format(raw_vocab_size, len(self.vocab)))
+
+    def _process_goal(self, raw_goal):
+        res = {}
+        for domain in self.domains:
+            all_words = []
+            d_goal = raw_goal[domain]
+            if d_goal:
+                for info_type in self.info_types:
+                    sv_info = d_goal.get(info_type, dict())
+                    if info_type == 'reqt' and isinstance(sv_info, list):
+                        all_words.extend([info_type + '|' + item for item in sv_info])
+                    elif isinstance(sv_info, dict):
+                        all_words.extend([info_type + '|' + k + '|' + str(v) for k, v in sv_info.items()])
+                    else:
+                        print('Fatal Error!')
+                        exit(-1)
+            res[domain] = all_words
+        return res
+    
+    def _process_multidomain_summary_acts(self, dact):
+        """
+        process dialogue action dictionary into binary vector representation
+        each domain has its own vector, and final output is the flattened respresentation of each domain's action
+        """
+        res = {}
+        # dact = {domain:{action:[slot]}, domain:{action:[slot]}}
+        for domain in self.domains:
+            res[domain] = np.zeros(len(self.act_types))
+            if domain in dact.keys(): 
+                for i in range(len(self.act_types)):
+                    if self.act_types[i] in dact[domain].keys():
+                        res[domain][i] = 1
+
+
+        # multiwoz dact = {domain-act:[[slot, value], [slot, value]]}
+        # for domain in self.domains:
+            # res[domain] = np.zeros(len(self.act_types))
+        # for k in dact.keys():
+            # d = k.split("-")[0].lower()
+            # a = k.split("-")[1].lower()
+
+            # res[d][self.act2id[a]] = 1
+
+        flat_res = [act for domain in sorted(self.domains) for act in res[domain]]
+        return flat_res
+    
+    def _process_summary_acts(self, dact):
+        """
+        process dialogue action dictionary into binary vector representation, ignoring domain information
+        """
+        res = np.zeros(len(self.act_types))
+        # damd dact = {domain:{action:[slot]}, domain:{action:[slot]}}
+        for domain in self.domains:
+            if domain in dact.keys(): 
+                for i in range(len(self.act_types)):
+                    if self.act_types[i] in dact[domain].keys():
+                        res[i] = 1
+
+        # multiwoz dact = {domain-act:[[slot, value], [slot, value]]}
+        # for k in dact.keys():
+            # # d = k.split("-")[0].lower()
+            # a = k.split("-")[1].lower()
+
+           #  res[self.act2id[a]] = 1
+
+        return list(res)
+
+    def _extract_goal_vocab(self):
+        self.goal_vocab, self.goal_vocab_dict, self.goal_unk_id = {}, {}, {}
+        for domain in self.domains:
+            all_words = []
+            for dlg in self.train_corpus:
+                all_words.extend(dlg.goal[domain])
+            vocab_count = Counter(all_words).most_common()
+            raw_vocab_size = len(vocab_count)
+            discard_wc = np.sum([c for t, c in vocab_count])
+
+            self.logger.info('================= domain = {}, \n'.format(domain) +
+                  'goal vocab size of train set = %d, \n' % (raw_vocab_size,) +
+                  'cut off at word %s with frequency = %d, \n' % (vocab_count[-1][0], vocab_count[-1][1]) +
+                  'OOV rate = %.2f' % (1 - float(discard_wc) / len(all_words),))
+
+            self.goal_vocab[domain] = [UNK] + [g for g, cnt in vocab_count]
+            self.goal_vocab_dict[domain] = {t: idx for idx, t in enumerate(self.goal_vocab[domain])}
+            self.goal_unk_id[domain] = self.goal_vocab_dict[domain][UNK]
+
+    def get_corpus(self):
+        id_train = self._to_id_corpus('Train', self.train_corpus)
+        id_val = self._to_id_corpus('Valid', self.val_corpus)
+        id_test = self._to_id_corpus('Test', self.test_corpus)
+        return id_train, id_val, id_test
+
+    def _to_id_corpus(self, name, data):
+        results = []
+        for dlg in data:
+            if len(dlg.dlg) < 1:
+                continue
+            id_dlg = []
+            for turn in dlg.dlg:
+                id_turn = Pack(utt=self._sent2id(turn.utt),
+                               speaker=turn.speaker,
+                               db=turn.db, bs=turn.bs) #, act=turn.act)
+                id_dlg.append(id_turn)
+            id_goal = self._goal2id(dlg.goal)
+            results.append(Pack(dlg=id_dlg, goal=id_goal, key=dlg.key))
+        return results
+
+    def _sent2id(self, sent):
+        return [self.vocab_dict.get(t, self.unk_id) for t in sent]
+
+    def _goal2id(self, goal):
+        res = {}
+        for domain in self.domains:
+            d_bow = [0.0] * len(self.goal_vocab[domain])
+            for word in goal[domain]:
+                word_id = self.goal_vocab_dict[domain].get(word, self.goal_unk_id[domain])
+                d_bow[word_id] += 1.0
+            res[domain] = d_bow
+        return res
+
+    def id2sent(self, id_list):
+        return [self.vocab[i] for i in id_list]
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
diff --git a/latent_dialog/criterions.py b/latent_dialog/criterions.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3d1e78b8aa8aeb80fe5108fb481a95c804822f3
--- /dev/null
+++ b/latent_dialog/criterions.py
@@ -0,0 +1,184 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+import numpy as np
+import pdb
+import math
+from latent_dialog import domain
+from latent_dialog.utils import LONG
+
+class NLLEntropy(_Loss):
+    def __init__(self, padding_idx, avg_type):
+        super(NLLEntropy, self).__init__()
+        self.padding_idx = padding_idx
+        self.avg_type = avg_type
+
+    def set_weight(self, weight):
+        self.weight = weight
+
+    def forward(self, net_output, labels, weight=None):
+        batch_size = net_output.size(0)
+        pred = net_output.view(-1, net_output.size(-1))
+        target = labels.view(-1)
+
+        if self.avg_type is None:
+            loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx)
+        elif self.avg_type == 'seq':
+            loss = F.nll_loss(pred, target, size_average=False, ignore_index=self.padding_idx)
+            loss = loss / batch_size
+        elif self.avg_type == 'real_word':
+            loss = F.nll_loss(pred, target, ignore_index=self.padding_idx, reduce=False)
+            loss = loss.view(-1, net_output.size(1))
+            loss = th.sum(loss, dim=1)
+            word_cnt = th.sum(th.sign(labels), dim=1).float()
+            loss = loss / word_cnt
+            loss = th.mean(loss)
+        elif self.avg_type == 'word':
+            loss = F.nll_loss(pred, target, reduction='mean', ignore_index=self.padding_idx)
+        elif self.avg_type == 'weighted':
+            loss = F.nll_loss(pred, target, weight=self.weight, reduction='mean', ignore_index=self.padding_idx)
+        else:
+            raise ValueError('Unknown average type')
+
+        return loss
+ 
+class NLLEntropy4CLF(_Loss):
+    def __init__(self, dictionary, bad_tokens=['<disconnect>', '<disagree>'], reduction='elementwise_mean'):
+        super(NLLEntropy4CLF, self).__init__()
+        w = th.Tensor(len(dictionary)).fill_(1)
+        for token in bad_tokens:
+            w[dictionary[token]] = 0.0
+        self.crit = nn.CrossEntropyLoss(w, reduction=reduction)
+
+    def forward(self, preds, labels):
+        # preds: (batch_size, outcome_len, outcome_vocab_size)
+        # labels: (batch_size, outcome_len)
+        preds = preds.view(-1, preds.size(-1))
+        labels = labels.view(-1)
+        return self.crit(preds, labels)
+
+class CombinedNLLEntropy4CLF(_Loss):
+    def __init__(self, dictionary, corpus, np2var, bad_tokens=['<disconnect>', '<disagree>']):
+        super(CombinedNLLEntropy4CLF, self).__init__()
+        self.dictionary = dictionary
+        self.domain = domain.get_domain('object_division')
+        self.corpus = corpus
+        self.np2var = np2var
+        self.bad_tokens = bad_tokens
+
+    def forward(self, preds, goals_id, outcomes_id):
+        # preds: (batch_size, outcome_len, outcome_vocab_size)
+        # goals_id: list of list, id, batch_size*goal_len
+        # outcomes_id: list of list, id, batch_size*outcome_len
+        batch_size = len(goals_id)
+        losses = []
+        for bth in range(batch_size):
+            pred = preds[bth] # (outcome_len, outcome_vocab_size)
+            goal = goals_id[bth] # list, id, len=goal_len
+            goal_str = self.corpus.id2goal(goal) # list, str, len=goal_len
+            outcome = outcomes_id[bth] # list, id, len=outcome_len
+            outcome_str = self.corpus.id2outcome(outcome) # list, str, len=outcome_len
+
+            if outcome_str[0] in self.bad_tokens:
+                continue
+
+            # get all the possible choices
+            choices = self.domain.generate_choices(goal_str)
+            sel_outs = [pred[i] for i in range(pred.size(0))] # outcome_len*(outcome_vocab_size, )
+
+            choices_logits = [] # outcome_len*(option_amount, 1)
+            for i in range(self.domain.selection_length()):
+                idxs = np.array([self.dictionary[c[i]] for c in choices])
+                idxs_var = self.np2var(idxs, LONG) # (option_amount, )
+                choices_logits.append(th.gather(sel_outs[i], 0, idxs_var).unsqueeze(1))
+
+            choice_logit = th.sum(th.cat(choices_logits, 1), 1, keepdim=False) # (option_amount, )
+            choice_logit = choice_logit.sub(choice_logit.max().item()) # (option_amount, )
+            prob = F.softmax(choice_logit, dim=0) # (option_amount, )
+
+            label = choices.index(outcome_str)
+            target_prob = prob[label]
+            losses.append(-th.log(target_prob))
+        return sum(losses) / float(len(losses))
+
+class CatKLLoss(_Loss):
+    def __init__(self):
+        super(CatKLLoss, self).__init__()
+
+    def forward(self, log_qy, log_py, batch_size=None, unit_average=False):
+        """
+        qy * log(q(y)/p(y))
+        """
+        qy = th.exp(log_qy)
+        y_kl = th.sum(qy * (log_qy - log_py), dim=1)
+        if unit_average:
+            return th.mean(y_kl)
+        else:
+            return th.sum(y_kl)/batch_size
+
+class Entropy(_Loss):
+    def __init__(self):
+        super(Entropy, self).__init__()
+
+    def forward(self, log_qy, batch_size=None, unit_average=False):
+        """
+        -qy log(qy)
+        """
+        if log_qy.dim() > 2:
+            log_qy = log_qy.squeeze()
+        qy = th.exp(log_qy)
+        h_q = th.sum(-1 * log_qy * qy, dim=1)
+        if unit_average:
+            return th.mean(h_q)
+        else:
+            return th.sum(h_q) / batch_size
+
+class GaussianEntropy(_Loss):
+    def __init__(self):
+        super(GaussianEntropy, self).__init__()
+
+    def forward(self, mu, logvar):
+        """
+        0.5 (log(mu*var)) + 0.5
+        """
+        std = th.exp(0.5 * logvar)
+        var = th.square(std)
+        h_q = 0.5 * (th.log(2 * math.pi * var)) + 0.5
+
+        return th.mean(h_q)
+
+class BinaryNLLEntropy(_Loss):
+
+    def __init__(self, size_average=True):
+        super(BinaryNLLEntropy, self).__init__()
+        self.size_average = size_average
+
+    def forward(self, net_output, label_output):
+        """
+        :param net_output: batch_size x
+        :param labels:
+        :return:
+        """
+        batch_size = net_output.size(0)
+        loss = F.binary_cross_entropy_with_logits(net_output, label_output, size_average=self.size_average)
+        if self.size_average is False:
+            loss /= batch_size
+        return loss
+
+class NormKLLoss(_Loss):
+    def __init__(self, unit_average=False):
+        super(NormKLLoss, self).__init__()
+        self.unit_average = unit_average
+
+    def forward(self, recog_mu, recog_logvar, prior_mu, prior_logvar):
+        # find the KL divergence between two Gaussian distribution
+        loss = 1.0 + (recog_logvar - prior_logvar)
+        loss -= th.div(th.pow(prior_mu - recog_mu, 2), th.exp(prior_logvar))
+        loss -= th.div(th.exp(recog_logvar), th.exp(prior_logvar))
+        if self.unit_average:
+            kl_loss = -0.5 * th.mean(loss, dim=1)
+        else:
+            kl_loss = -0.5 * th.sum(loss, dim=1)
+        avg_kl_loss = th.mean(kl_loss)
+        return avg_kl_loss
diff --git a/latent_dialog/data_loaders.py b/latent_dialog/data_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..212d4fe871bdfc8a6f587219284d7d3e7aee0a87
--- /dev/null
+++ b/latent_dialog/data_loaders.py
@@ -0,0 +1,434 @@
+import numpy as np
+import copy
+import random
+import traceback
+from enum import Enum
+from collections import defaultdict
+import pdb
+from latent_dialog.utils import Pack
+from latent_dialog.base_data_loaders import BaseDataLoaders, LongDataLoader
+from latent_dialog.corpora import USR, SYS, UNK, SPECIAL_TOKENS
+import json
+
+
+class NoiseType(Enum):
+    REMOVAL = "tokens_removal"
+    CHANGE  = "tokens_switch"
+
+class BeliefDbDataLoaders(BaseDataLoaders):
+    def __init__(self, name, data, config, pad_context=True, noise=None, ind_voc=None, logger=None):
+        super(BeliefDbDataLoaders, self).__init__(name)
+        self.max_utt_len = config.max_utt_len
+        self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size)
+        self.data_size = len(self.data)
+        self.pad_context = pad_context
+        self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
+
+        self.noise_type = noise
+
+        if self.noise_type is not None:
+            self.noise_p = config.noise_p
+            self.remove_tokens = config.remove_tokens
+
+            self.ind_voc = copy.deepcopy(ind_voc)
+            if config.no_special:
+                for token in SPECIAL_TOKENS:
+                    if token != UNK:
+                        del self.ind_voc[token]
+            self.tokens = list(self.ind_voc.keys())
+
+            self.num_noised_tokens = 0
+            self.noised_tokens_dist = defaultdict(int)
+            self.num_examples = 0
+
+        self.logger = logger
+
+
+    def flatten_dialog( self, data, backward_size):
+        results = []
+        indexes = []
+        batch_indexes = []
+        resp_set = set()
+        for dlg in data:
+            goal = dlg.goal
+            key = dlg.key
+            batch_index = []
+            for i in range(1, len(dlg.dlg)):
+                if dlg.dlg[i].speaker == USR or dlg.dlg[i].speaker == "user":
+                    continue
+                e_idx = i
+                s_idx = max(0, e_idx - backward_size)
+                response = dlg.dlg[i].copy()
+                response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
+                resp_set.add(json.dumps(response.utt))
+                context = []
+                for turn in dlg.dlg[s_idx: e_idx]:
+                    turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
+                    context.append(turn)
+                results.append(Pack(context=context, response=response, goal=goal, key=key))
+                indexes.append(len(indexes))
+                batch_index.append(indexes[-1])
+            if len(batch_index) > 0:
+                batch_indexes.append(batch_index)
+        print("Unique resp {}".format(len(resp_set)))
+        return results, indexes, batch_indexes
+
+    def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
+        self.ptr = 0
+        if fix_batch:
+            self.batch_size = None
+            self.num_batch = len(self.batch_indexes)
+        else:
+            self.batch_size = config.batch_size
+            self.num_batch = self.data_size // config.batch_size
+            self.batch_indexes = []
+            for i in range(self.num_batch):
+                self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size])
+            if verbose:
+                print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
+        if shuffle:
+            if fix_batch:
+                self._shuffle_batch_indexes()
+            else:
+                self._shuffle_indexes()
+
+        if verbose:
+            print('%s begins with %d batches' % (self.name, self.num_batch))
+
+    def _prepare_batch(self, selected_index):
+        rows = [self.data[idx] for idx in selected_index]
+
+        ctx_utts, ctx_lens = [], []
+        out_utts, out_lens = [], []
+
+        out_bs, out_db = [] , []
+        raw_bs = []
+        goals, goal_lens = [], [[] for _ in range(len(self.domains))]
+        keys = []
+
+        for row in rows:
+            in_row, out_row, goal_row = row.context, row.response, row.goal
+
+            # source context
+            keys.append(row.key)
+            batch_ctx = []
+            for turn in in_row:
+                ctx_utt = copy.deepcopy(turn.utt)
+
+                if self.noise_type is not None:
+                    try:
+                        ctx_utt = self._add_noise(ctx_utt)
+                    except Exception as error:
+                        self.logger.warn(traceback.print_exc())
+
+                batch_ctx.append(self.pad_to(self.max_utt_len, ctx_utt, do_pad=self.pad_context))
+
+            ctx_utts.append(batch_ctx)
+            ctx_lens.append(len(batch_ctx))
+
+            # target response
+            out_utt = [t for idx, t in enumerate(out_row.utt)]
+            out_utts.append(out_utt)
+            out_lens.append(len(out_utt))
+
+            out_bs.append(out_row.bs)
+            out_db.append(out_row.db)
+            
+            if "raw_bs" in out_row:
+                raw_bs.append(out_row["raw_bs"])
+
+            # goal
+            goals.append(goal_row)
+            for i, d in enumerate(self.domains):
+                goal_lens[i].append(len(goal_row[d]))
+
+        batch_size = len(ctx_lens)
+        vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns
+        max_ctx_len = np.max(vec_ctx_lens)
+        if self.pad_context:
+            vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
+        else:
+            vec_ctx_utts = []
+        vec_out_bs = np.array(out_bs) # (batch_size, 94)
+        vec_out_db = np.array(out_db) # (batch_size, 30)
+        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
+        max_out_len = np.max(vec_out_lens)
+        vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)
+
+        max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens]
+        if max_goal_lens != min_goal_lens:
+            print('Fatal Error!')
+            exit(-1)
+        self.goal_lens = max_goal_lens
+        vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens]
+
+        for b_id in range(batch_size):
+            if self.pad_context:
+                vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
+            else:
+                vec_ctx_utts.append(ctx_utts[b_id])
+
+            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
+            for i, d in enumerate(self.domains):
+                vec_goals_list[i][b_id, :] = goals[b_id][d]
+
+        return Pack(context_lens=vec_ctx_lens, # (batch_size, )
+                    contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len)
+                    output_lens=vec_out_lens, # (batch_size, )
+                    outputs=vec_out_utts, # (batch_size, max_out_len)
+                    bs=vec_out_bs, # (batch_size, 94)
+                    raw_bs=raw_bs,
+                    db=vec_out_db, # (batch_size, 30)
+                    goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain
+                    keys=keys)
+
+    def _add_noise(self, sequence):
+        num_tokens = np.ceil(self.noise_p * (len(sequence)-2))
+        indices = np.random.choice(range(1, len(sequence)-1), size=int(num_tokens), replace=False)
+        self.num_noised_tokens += num_tokens
+        self.noised_tokens_dist[num_tokens] += 1
+        self.num_examples += 1
+
+        if self.noise_type == NoiseType.REMOVAL.value:
+            return self._remove_tokens(sequence, indices)
+        elif self.noise_type == NoiseType.CHANGE.value:
+            return self._change_tokens(sequence, indices)
+
+    def _remove_tokens(self, sequence, indices):
+        for index in indices:
+            sequence[index] = self.ind_voc[UNK]
+
+        if self.remove_tokens:
+            sequence = list(filter(lambda x: x != self.ind_voc[UNK], sequence))
+
+        return sequence
+
+    def _change_tokens(self, sequence, indices):
+        for index in indices:
+            new_token_index = sequence[index]
+
+            while new_token_index == sequence[index]:
+                ind = int(len(self.tokens) * random.random())
+                new_token_index = self.ind_voc[self.tokens[ind]]
+            
+            sequence[index] = new_token_index
+        
+        return sequence
+
+class BeliefDbDataLoadersAE(BaseDataLoaders):
+    def __init__(self, name, data, config, noise=None, ind_voc=None, logger=None):
+        super(BeliefDbDataLoadersAE, self).__init__(name)
+        self.max_utt_len = config.max_utt_len
+        self.data, self.indexes, self.batch_indexes = self.flatten_dialog(data, config.backward_size)
+        self.data_size = len(self.data)
+        self.domains = ['hotel', 'restaurant', 'train', 'attraction', 'hospital', 'police', 'taxi']
+        # self.act_types = ['bye', 'inform', 'nobook', 'nooffer', 'offerbook', 'offerbooked', 'recommend', 'reqmore', 'request', 'select', 'welcome']
+        if "ae_zero_pad" in config.keys():
+            self.zero_pad = config.ae_zero_pad
+        else:
+            self.zero_pad = False
+        
+        self.noise_type = noise
+
+        if self.noise_type is not None:
+            self.noise_p = config.noise_p
+            self.remove_tokens = config.remove_tokens
+
+            self.ind_voc = copy.deepcopy(ind_voc)
+            if config.no_special:
+                for token in SPECIAL_TOKENS:
+                    if token != UNK:
+                        del self.ind_voc[token]
+            self.tokens = list(self.ind_voc.keys())
+
+            self.num_noised_tokens = 0
+            self.noised_tokens_dist = defaultdict(int)
+            self.num_examples = 0
+
+        self.logger = logger
+
+
+
+    def flatten_dialog(self, data, backward_size):
+        results = []
+        indexes = []
+        batch_indexes = []
+        resp_set = set()
+        for dlg in data:
+            goal = dlg.goal
+            key = dlg.key
+            batch_index = []
+            for i in range(1, len(dlg.dlg)):
+                if dlg.dlg[i].speaker == USR:
+                    continue
+                e_idx = i
+                s_idx = max(0, e_idx - backward_size)
+                response = dlg.dlg[i].copy()
+                response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
+                resp_set.add(json.dumps(response.utt))
+                context = []
+                for turn in dlg.dlg[s_idx: e_idx]:
+                    turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
+                    context.append(turn)
+                results.append(Pack(context=context, response=response, goal=goal, key=key))
+                indexes.append(len(indexes))
+                batch_index.append(indexes[-1])
+            if len(batch_index) > 0:
+                batch_indexes.append(batch_index)
+        print("Unique resp {}".format(len(resp_set)))
+        return results, indexes, batch_indexes
+
+    def epoch_init(self, config, shuffle=True, verbose=True, fix_batch=False):
+        self.ptr = 0
+        if fix_batch:
+            self.batch_size = None
+            self.num_batch = len(self.batch_indexes)
+        else:
+            self.batch_size = config.batch_size
+            self.num_batch = self.data_size // config.batch_size
+            self.batch_indexes = []
+            for i in range(self.num_batch):
+                self.batch_indexes.append(self.indexes[i * self.batch_size: (i + 1) * self.batch_size])
+            if verbose:
+                print('Number of left over sample = %d' % (self.data_size - config.batch_size * self.num_batch))
+        if shuffle:
+            if fix_batch:
+                self._shuffle_batch_indexes()
+            else:
+                self._shuffle_indexes()
+
+        if verbose:
+            print('%s begins with %d batches' % (self.name, self.num_batch))
+
+    def _prepare_batch(self, selected_index):
+        rows = [self.data[idx] for idx in selected_index]
+
+        ctx_utts, ctx_lens = [], []
+        out_utts, out_lens = [], []
+        # out_act = []
+        out_bs, out_db = [] , []
+        goals, goal_lens = [], [[] for _ in range(len(self.domains))]
+        keys = []
+
+        raw_bs = []
+        
+        for row in rows:
+            in_row, out_row, goal_row = row.context, row.response, row.goal
+
+            # source context
+            keys.append(row.key)
+            
+            # batch_ctx = []
+            # for turn in in_row:
+                # batch_ctx.append(self.pad_to(self.max_utt_len, turn.utt, do_pad=True))
+            # ctx_utts.append(batch_ctx)
+            # ctx_lens.append(len(batch_ctx))
+            
+            # target response
+            out_utt = [t for idx, t in enumerate(out_row.utt)]
+            out_utts.append(out_utt)
+            out_lens.append(len(out_utt))
+
+            if not self.zero_pad:
+                out_bs.append(out_row.bs)
+                out_db.append(out_row.db)
+            else:
+                out_bs.append([0] * 94)
+                out_db.append([0] * 30)
+            # out_act.append(out_row.act)
+
+            # for AE, input = output
+            # print(out_row.utt)
+            ctx_utt = copy.deepcopy(out_row.utt)
+            if self.noise_type is not None:
+                try:
+                    ctx_utt = self._add_noise(ctx_utt)
+                except Exception as error:
+                    self.logger.warn(traceback.print_exc())
+
+            batch_ctx = self.pad_to(self.max_utt_len, ctx_utt, do_pad=True)
+
+            # print(out_row.utt)
+            # pdb.set_trace()
+
+            ctx_utts.append(batch_ctx)
+            ctx_lens.append(len(batch_ctx))
+
+
+            if "raw_bs" in out_row:
+                raw_bs.append(out_row["raw_bs"])
+
+            # goal
+            goals.append(goal_row)
+            for i, d in enumerate(self.domains):
+                goal_lens[i].append(len(goal_row[d]))
+
+
+        batch_size = len(ctx_lens)
+        vec_ctx_lens = np.array(ctx_lens) # (batch_size, ), number of turns
+        max_ctx_len = np.max(vec_ctx_lens)
+        vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
+        vec_out_bs = np.array(out_bs) # (batch_size, 94)
+        vec_out_db = np.array(out_db) # (batch_size, 30)
+        # vec_out_act = np.array(out_act) # (batch_size, 11)
+        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
+        max_out_len = np.max(vec_out_lens)
+        vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)
+
+        max_goal_lens, min_goal_lens = [max(ls) for ls in goal_lens], [min(ls) for ls in goal_lens]
+        if max_goal_lens != min_goal_lens:
+            print('Fatal Error!')
+            exit(-1)
+        self.goal_lens = max_goal_lens
+        vec_goals_list = [np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens]
+
+        for b_id in range(batch_size):
+            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
+            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
+            for i, d in enumerate(self.domains):
+                vec_goals_list[i][b_id, :] = goals[b_id][d]
+
+        return Pack(context_lens=vec_ctx_lens, # (batch_size, )
+                    contexts=vec_ctx_utts, # (batch_size, max_ctx_len, max_utt_len)
+                    output_lens=vec_out_lens, # (batch_size, )
+                    outputs=vec_out_utts, # (batch_size, max_out_len)
+                    raw_bs=raw_bs,
+                    bs=vec_out_bs, # (batch_size, 94)
+                    db=vec_out_db, # (batch_size, 30)
+                    # act=vec_out_act, #(batch_size, 11)
+                    goals_list=vec_goals_list, # 7*(batch_size, bow_len), bow_len differs w.r.t. domain
+                    keys=keys)
+
+    def _add_noise(self, sequence):
+        num_tokens = np.ceil(self.noise_p * (len(sequence)-2))
+        indices = np.random.choice(range(1, len(sequence)-1), size=int(num_tokens), replace=False)
+        self.num_noised_tokens += num_tokens
+        self.noised_tokens_dist[num_tokens] += 1
+        self.num_examples += 1
+
+        if self.noise_type == NoiseType.REMOVAL.value:
+            return self._remove_tokens(sequence, indices)
+        elif self.noise_type == NoiseType.CHANGE.value:
+            return self._change_tokens(sequence, indices)
+
+    def _remove_tokens(self, sequence, indices):
+        for index in indices:
+            sequence[index] = self.ind_voc[UNK]
+
+        if self.remove_tokens:
+            sequence = list(filter(lambda x: x != self.ind_voc[UNK], sequence))
+
+        return sequence
+
+    def _change_tokens(self, sequence, indices):
+        for index in indices:
+            new_token_index = sequence[index]
+
+            while new_token_index == sequence[index]:
+                ind = int(len(self.tokens) * random.random())
+                new_token_index = self.ind_voc[self.tokens[ind]]
+            
+            sequence[index] = new_token_index
+        
+        return sequence
+
diff --git a/latent_dialog/dialog_task.py b/latent_dialog/dialog_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..f985d83bb985bb2a48626616ba34d34f714e0b72
--- /dev/null
+++ b/latent_dialog/dialog_task.py
@@ -0,0 +1,129 @@
+from latent_dialog.metric import MetricsContainer
+from latent_dialog.corpora import EOD, EOS
+from latent_dialog import evaluators
+
+
+class Dialog(object):
+    """Dialogue runner."""
+    def __init__(self, agents, args):
+        assert len(agents) == 2
+        self.agents = agents
+        self.system, self.user = agents
+        self.args = args
+        self.metrics = MetricsContainer()
+        self.dlg_evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', args)
+        self._register_metrics()
+
+    def _register_metrics(self):
+        """Registers valuable metrics."""
+        self.metrics.register_average('dialog_len')
+        self.metrics.register_average('sent_len')
+        self.metrics.register_average('reward')
+        self.metrics.register_time('time')
+
+    def _is_eod(self, out):
+        return len(out) == 2 and out[0] == EOD and out[1] == EOS
+
+    def _eval_dialog(self, conv, g_key, goal):
+        generated_dialog = dict()
+        generated_dialog[g_key] = {'goal': goal, 'log': list()}
+        for t_id, (name, utt) in enumerate(conv):
+            # assert utt[-1] == EOS, utt
+            if t_id % 2 == 0:
+                assert name == 'Baozi'
+            utt = ' '.join(utt[:-1])
+            if utt == EOD:
+                continue
+            generated_dialog[g_key]['log'].append({'text': utt})
+        report, success_r, match_r = self.dlg_evaluator.evaluateModel(generated_dialog, mode='rollout')
+        return success_r + match_r
+
+    def show_metrics(self):
+        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])
+
+    def run(self, g_key, goal):
+        """Runs one instance of the dialogue."""
+        # initialize agents by feeding in the goal
+        # initialize BOD utterance for each agent
+        for agent in self.agents:
+            agent.feed_goal(goal)
+            agent.bod_init()
+
+        # role assignment
+        reader, writer = self.system, self.user
+        begin_name = writer.name
+        print('begin_name = {}'.format(begin_name))
+
+        conv = []
+        # reset metrics
+        self.metrics.reset()
+        nturn = 0
+        while True:
+            nturn += 1
+            # produce an utterance
+            out_words = writer.write() # out: list of word, str, len = max_words
+            print('\t{} out_words = {}'.format(writer.name, ' '.join(out_words)))
+
+            self.metrics.record('sent_len', len(out_words))
+            # self.metrics.record('%s_unique' % writer.name, out_words)
+
+            # append the utterance to the conversation
+            conv.append((writer.name, out_words))
+            # make the other agent to read it
+            reader.read(out_words)
+            # check if the end of the conversation was generated
+            if self._is_eod(out_words):
+                break
+
+            if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
+                # return conv, 0
+                break
+
+            writer, reader = reader, writer
+
+        # evaluate dialog and produce success
+        reward = self._eval_dialog(conv, g_key, goal)
+        print('Reward = {}'.format(reward))
+        # perform update
+        self.system.update(reward)
+        self.metrics.record('time')
+        self.metrics.record('dialog_len', len(conv))
+        self.metrics.record('reward', int(reward))
+
+        print('='*50)
+        print(self.show_metrics())
+        print('='*50)
+        return conv, reward
+
+
+class DialogEval(Dialog):
+    def run(self, g_key, goal):
+        """Runs one instance of the dialogue."""
+        # initialize agents by feeding in the goal
+        # initialize BOD utterance for each agent
+        for agent in self.agents:
+            agent.feed_goal(goal)
+            agent.bod_init()
+
+        # role assignment
+        reader, writer = self.system, self.user
+        conv = []
+        nturn = 0
+        while True:
+            nturn += 1
+            # produce an utterance
+            out_words = writer.write()  # out: list of word, str, len = max_words
+            conv.append((writer.name, out_words))
+            # make the other agent to read it
+            reader.read(out_words)
+            # check if the end of the conversation was generated
+            if self._is_eod(out_words):
+                break
+
+            writer, reader = reader, writer
+            if self.args.max_nego_turn > 0 and nturn >= self.args.max_nego_turn:
+                return conv, 0
+
+        # evaluate dialog and produce success
+        reward = self._eval_dialog(conv, g_key, goal)
+        return conv, reward
diff --git a/latent_dialog/domain.py b/latent_dialog/domain.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d3ff99bcf99042bc65ec3f3fc0da96734b391b
--- /dev/null
+++ b/latent_dialog/domain.py
@@ -0,0 +1,124 @@
+import re
+import random
+import json
+
+
+def get_domain(name):
+    if name == 'object_division':
+        return ObjectDivisionDomain()
+    raise()
+
+
+class ObjectDivisionDomain(object):
+    def __init__(self):
+        self.item_pattern = re.compile('^item([0-9])=([0-9\-])+$')
+
+    def input_length(self):
+        return 3
+
+    def selection_length(self):
+        return 6
+
+    def generate_choices(self, inpt):
+        cnts, _ = self.parse_context(inpt)
+
+        def gen(cnts, idx=0, choice=[]):
+            if idx >= len(cnts):
+                left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)]
+                right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))]
+                return [left_choice + right_choice]
+            choices = []
+            for c in range(cnts[idx] + 1):
+                choice.append(c)
+                choices += gen(cnts, idx + 1, choice)
+                choice.pop()
+            return choices
+        choices = gen(cnts)
+        choices.append(['<no_agreement>'] * self.selection_length())
+        choices.append(['<disconnect>'] * self.selection_length())
+        return choices
+
+    def parse_context(self, ctx):
+        cnts = [int(n) for n in ctx[0::2]]
+        vals = [int(v) for v in ctx[1::2]]
+        return cnts, vals
+
+    def _to_int(self, x):
+        try:
+            return int(x)
+        except:
+            return 0
+
+    def score_choices(self, choices, ctxs):
+        assert len(choices) == len(ctxs)
+        # print('choices = {}'.format(choices))
+        # print('ctxs = {}'.format(ctxs))
+        cnts = [int(x) for x in ctxs[0][0::2]]
+        agree, scores = True, [0 for _ in range(len(ctxs))]
+        for i, n in enumerate(cnts):
+            for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)):
+                # taken = self._to_int(choice[i+3][-1])
+                taken = self._to_int(choice[i][-1])
+                n -= taken
+                scores[agent_id] += int(ctx[2 * i + 1]) * taken
+            agree = agree and (n == 0)
+        return agree, scores
+
+
+class ContextGenerator(object):
+    """Dialogue context generator. Generates contexes from the file."""
+    def __init__(self, context_file):
+        self.ctxs = []
+        with open(context_file, 'r') as f:
+            ctx_pair = []
+            for line in f:
+                ctx = line.strip().split()
+                ctx_pair.append(ctx)
+                if len(ctx_pair) == 2:
+                    self.ctxs.append(ctx_pair)
+                    ctx_pair = []
+
+    def sample(self):
+        return random.choice(self.ctxs)
+
+    def iter(self, nepoch=1):
+        for e in range(nepoch):
+            random.shuffle(self.ctxs)
+            for ctx in self.ctxs:
+                yield ctx
+
+    def total_size(self, nepoch):
+        return nepoch*len(self.ctxs)
+
+
+class ContextGeneratorEval(object):
+    """Dialogue context generator. Generates contexes from the file."""
+    def __init__(self, context_file):
+        self.ctxs = []
+        with open(context_file, 'r') as f:
+            ctx_pair = []
+            for line in f:
+                ctx = line.strip().split()
+                ctx_pair.append(ctx)
+                if len(ctx_pair) == 2:
+                    self.ctxs.append(ctx_pair)
+                    ctx_pair = []
+
+
+class TaskGoalGenerator(object):
+    def __init__(self, goal_file):
+        self.goals = []
+        data = json.load(open(goal_file))
+        for key, raw_dlg in data.items():
+            self.goals.append((key, raw_dlg['goal']))
+
+    def sample(self):
+        return random.choice(self.goals)
+
+    def iter(self, nepoch=1):
+        for e in range(nepoch):
+            random.shuffle(self.goals)
+            for goal in self.goals:
+                yield goal
+
+
diff --git a/latent_dialog/enc2dec/__init__.py b/latent_dialog/enc2dec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6caafab6a74e286872551db1a2edcc848eb00fd0
--- /dev/null
+++ b/latent_dialog/enc2dec/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# Author: Tiancheng Zhao
+# Date: 9/15/18
diff --git a/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e5601a38d9b18e633e2c8eb076ee990266f342a
Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/__init__.cpython-36.pyc differ
diff --git a/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0dc8e9504000fe70600b605639cc96f6a2bc5197
Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/base_modules.cpython-36.pyc differ
diff --git a/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c65603128e11acfb9750a779b4dae7be502e9f79
Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/decoders.cpython-36.pyc differ
diff --git a/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99cf032d687ea874ca05a9f135d616435cd10234
Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/encoders.cpython-36.pyc differ
diff --git a/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc b/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..887fe1642e54c64711c41b5f6aeb07360d40607b
Binary files /dev/null and b/latent_dialog/enc2dec/__pycache__/masked_decoder.cpython-36.pyc differ
diff --git a/latent_dialog/enc2dec/base_modules.py b/latent_dialog/enc2dec/base_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc7d98601596f147a25aa47332483f5502eaf9d3
--- /dev/null
+++ b/latent_dialog/enc2dec/base_modules.py
@@ -0,0 +1,68 @@
+import torch as th
+import torch.nn as nn
+import numpy as np
+from torch.nn.modules.module import _addindent
+
+def summary(model, show_weights=True, show_parameters=True):
+    """
+    Summarizes torch model by showing trainable parameters and weights.
+    """
+    tmpstr = model.__class__.__name__ + ' (\n'
+    total_params = 0
+    for key, module in model._modules.items():
+        # if it contains layers let call it recursively to get params
+        # and weights
+        if type(module) in [
+            th.nn.modules.container.Container,
+            th.nn.modules.container.Sequential
+        ]:
+            modstr = summary(module)
+        else:
+            modstr = module.__repr__()
+        modstr = _addindent(modstr, 2)
+
+        params = sum([np.prod(p.size()) for p in module.parameters()])
+        weights = tuple([tuple(p.size()) for p in module.parameters()])
+        total_params += params
+
+        tmpstr += '  (' + key + '): ' + modstr
+        if show_weights:
+            tmpstr += ', weights={}'.format(weights)
+        if show_parameters:
+            tmpstr += ', parameters={}'.format(params)
+        tmpstr += '\n'
+
+    tmpstr = tmpstr + ') Total Parameters={}'.format(total_params)
+    return tmpstr
+
+
+class BaseRNN(nn.Module):
+    KEY_ATTN_SCORE = 'attention_score'
+    KEY_SEQUENCE = 'sequence'
+
+    def __init__(self, input_dropout_p, rnn_cell, 
+                     input_size, hidden_size, num_layers, 
+                     output_dropout_p, bidirectional):
+        super(BaseRNN, self).__init__()
+        self.input_dropout = nn.Dropout(p=input_dropout_p)
+        if rnn_cell.lower() == 'lstm':
+            self.rnn_cell = nn.LSTM
+        elif rnn_cell.lower() == 'gru':
+            self.rnn_cell = nn.GRU
+        else:
+            raise ValueError('Unsupported RNN Cell Type: {0}'.format(rnn_cell))
+        self.rnn = self.rnn_cell(input_size=input_size, 
+                                 hidden_size=hidden_size,
+                                 num_layers=num_layers, 
+                                 batch_first=True, 
+                                 dropout=output_dropout_p, 
+                                 bidirectional=bidirectional)
+
+        # TODO Trick for initializing LSTM gate parameters
+        if rnn_cell.lower() == 'lstm':
+            for names in self.rnn._all_weights:
+                for name in filter(lambda n: 'bias' in n, names):
+                    bias = getattr(self.rnn, name)
+                    n = bias.size(0)
+                    start, end = n // 4, n // 2
+                    bias.data[start:end].fill_(1.)
diff --git a/latent_dialog/enc2dec/classifier.py b/latent_dialog/enc2dec/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..359d500a3db66136a3063300b5b8f562f6d975ab
--- /dev/null
+++ b/latent_dialog/enc2dec/classifier.py
@@ -0,0 +1,103 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from latent_dialog.enc2dec.base_modules import BaseRNN
+
+
+class EncoderGRUATTN(BaseRNN):
+    def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths):
+        super(EncoderGRUATTN, self).__init__(input_dropout_p=input_dropout_p, 
+                                             rnn_cell=rnn_cell, 
+                                             input_size=input_size, 
+                                             hidden_size=hidden_size, 
+                                             num_layers=num_layers, 
+                                             output_dropout_p=output_dropout_p, 
+                                             bidirectional=bidirectional)
+        self.variable_lengths = variable_lengths
+        self.nhid_attn = hidden_size
+        self.output_size = hidden_size*2 if bidirectional else hidden_size
+
+        # attention to combine selection hidden states
+        self.attn = nn.Sequential(
+            nn.Linear(2 * hidden_size, hidden_size), 
+            nn.Tanh(), 
+            nn.Linear(hidden_size, 1)
+        )
+
+    def forward(self, residual_var, input_var, turn_feat, mask=None, init_state=None, input_lengths=None):
+        # residual_var: (batch_size, max_dlg_len, 2*utt_cell_size)
+        # input_var: (batch_size, max_dlg_len, dlg_cell_size)
+
+        # TODO switch of mask
+        # mask = None
+        
+        require_embed = True
+        if require_embed:
+            # input_cat = th.cat([input_var, residual_var], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size)
+            input_cat = th.cat([input_var, residual_var, turn_feat], 2) # (batch_size, max_dlg_len, dlg_cell_size+2*utt_cell_size)
+        else:
+            # input_cat = th.cat([input_var], 2)
+            input_cat = th.cat([input_var, turn_feat], 2)
+        if mask is not None:
+            input_mask = mask.view(input_cat.size(0), input_cat.size(1), 1) # (batch_size, max_dlg_len*max_utt_len, 1)
+            input_cat = th.mul(input_cat, input_mask)
+        embedded = self.input_dropout(input_cat)
+        
+        require_rnn = True
+        if require_rnn:
+            if init_state is not None:
+                h, _ = self.rnn(embedded, init_state)
+            else:
+                h, _ = self.rnn(embedded) # (batch_size, max_dlg_len, 2*nhid_attn)
+    
+            logit = self.attn(h.contiguous().view(-1, 2*self.nhid_attn)).view(h.size(0), h.size(1)) # (batch_size, max_dlg_len)
+            # if mask is not None:
+            #     logit_mask = mask.view(input_cat.size(0), input_cat.size(1))
+            #     logit_mask = -999.0 * logit_mask
+            #     logit = logit_mask + logit
+    
+            prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(h) # (batch_size, max_dlg_len, 2*nhid_attn)
+            attn = th.sum(th.mul(h, prob), 1) # (batch_size, 2*nhid_attn)
+            
+            return attn
+
+        else:
+            logit = self.attn(embedded.contiguous().view(input_cat.size(0)*input_cat.size(1), -1)).view(input_cat.size(0), input_cat.size(1))
+            if mask is not None:
+                logit_mask = mask.view(input_cat.size(0), input_cat.size(1))
+                logit_mask = -999.0 * logit_mask
+                logit = logit_mask + logit
+
+            prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(embedded) # (batch_size, max_dlg_len, 2*nhid_attn)
+            attn = th.sum(th.mul(embedded, prob), 1) # (batch_size, 2*nhid_attn)
+            
+            return attn
+
+
+class FeatureProjecter(nn.Module):
+    def __init__(self, input_dropout_p, input_size, output_size):
+        super(FeatureProjecter, self).__init__()
+        self.input_dropout = nn.Dropout(p=input_dropout_p)
+        self.sel_encoder = nn.Sequential(
+            nn.Linear(input_size, output_size), 
+            nn.Tanh()
+        )
+
+    def forward(self, goals_h, attn_outs):
+        h = th.cat([attn_outs, goals_h], 1) # (batch_size, 2*nhid_attn+goal_nhid)
+        h = self.input_dropout(h)
+        h = self.sel_encoder.forward(h) # (batch_size, nhid_sel)
+        return h
+
+
+class SelectionClassifier(nn.Module):
+    def __init__(self, selection_length, input_size, output_size):
+        super(SelectionClassifier, self).__init__()
+        self.sel_decoders = nn.ModuleList()
+        for _ in range(selection_length):
+            self.sel_decoders.append(nn.Linear(input_size, output_size))
+
+    def forward(self, proj_outs):
+        outs = [decoder.forward(proj_outs).unsqueeze(1) for decoder in self.sel_decoders] # outcome_len*(batch_size, 1, outcome_vocab_size)
+        outs = th.cat(outs, 1) # (batch_size, outcome_len, outcome_vocab_size)
+        return outs
diff --git a/latent_dialog/enc2dec/decoders.py b/latent_dialog/enc2dec/decoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08a3492506fbf7aa5ae9f5325a096e8e0698c82
--- /dev/null
+++ b/latent_dialog/enc2dec/decoders.py
@@ -0,0 +1,575 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.autograd import Variable
+import numpy as np
+from latent_dialog.enc2dec.base_modules import BaseRNN
+from latent_dialog.utils import cast_type, LONG, FLOAT
+from latent_dialog.corpora import DECODING_MASKED_TOKENS, EOS
+import pdb
+
+
+TEACH_FORCE = 'teacher_forcing'
+TEACH_GEN = 'teacher_gen'
+GEN = 'gen'
+GEN_VALID = 'gen_valid'
+
+
+class Attention(nn.Module):
+    def __init__(self, dec_cell_size, ctx_cell_size, attn_mode, project):
+        super(Attention, self).__init__()
+        self.dec_cell_size = dec_cell_size
+        self.ctx_cell_size = ctx_cell_size
+        self.attn_mode = attn_mode
+        if project:
+            self.linear_out = nn.Linear(dec_cell_size+ctx_cell_size, dec_cell_size)
+        else:
+            self.linear_out = None
+
+        if attn_mode == 'general':
+            self.dec_w = nn.Linear(dec_cell_size, ctx_cell_size)
+        elif attn_mode == 'cat':
+            self.dec_w = nn.Linear(dec_cell_size, dec_cell_size)
+            self.attn_w = nn.Linear(ctx_cell_size, dec_cell_size)
+            self.query_w = nn.Linear(dec_cell_size, 1)
+
+    def forward(self, output, context):
+        # output: (batch_size, output_seq_len, dec_cell_size)
+        # context: (batch_size, max_ctx_len, ctx_cell_size)
+        batch_size = output.size(0)
+        max_ctx_len = context.size(1)
+
+        if self.attn_mode == 'dot':
+            attn = th.bmm(output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len)
+        elif self.attn_mode == 'general':
+            mapped_output = self.dec_w(output) # (batch_size, output_seq_len, ctx_cell_size)
+            attn = th.bmm(mapped_output, context.transpose(1, 2)) # (batch_size, output_seq_len, max_ctx_len)
+        elif self.attn_mode == 'cat':
+            mapped_output = self.dec_w(output) # (batch_size, output_seq_len, dec_cell_size)
+            mapped_attn = self.attn_w(context) # (batch_size, max_ctx_len, dec_cell_size)
+            tiled_output = mapped_output.unsqueeze(2).repeat(1, 1, max_ctx_len, 1) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size)
+            tiled_attn = mapped_attn.unsqueeze(1) # (batch_size, 1, max_ctx_len, dec_cell_size)
+            fc1 = th.tanh(tiled_output+tiled_attn) # (batch_size, output_seq_len, max_ctx_len, dec_cell_size)
+            attn = self.query_w(fc1).squeeze(-1) # (batch_size, otuput_seq_len, max_ctx_len)
+        else:
+            raise ValueError('Unknown attention mode')
+
+        # TODO mask
+        # if self.mask is not None:
+
+        attn = F.softmax(attn.view(-1, max_ctx_len), dim=1).view(batch_size, -1, max_ctx_len) # (batch_size, output_seq_len, max_ctx_len)
+        mix = th.bmm(attn, context) # (batch_size, output_seq_len, ctx_cell_size)
+        combined = th.cat((mix, output), dim=2) # (batch_size, output_seq_len, dec_cell_size+ctx_cell_size)
+        if self.linear_out is None:
+            return combined, attn
+        else:
+            output = th.tanh(
+                self.linear_out(combined.view(-1, self.dec_cell_size+self.ctx_cell_size))).view(
+                batch_size, -1, self.dec_cell_size) # (batch_size, output_seq_len, dec_cell_size)
+            return output, attn
+
+
+class DecoderRNN(BaseRNN):
+    def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p,
+                 bidirectional, vocab_size, use_attn, ctx_cell_size, attn_mode, sys_id, eos_id, use_gpu,
+                 max_dec_len, embedding=None):
+
+        super(DecoderRNN, self).__init__(input_dropout_p=input_dropout_p, 
+                                         rnn_cell=rnn_cell, 
+                                         input_size=input_size, 
+                                         hidden_size=hidden_size, 
+                                         num_layers=num_layers, 
+                                         output_dropout_p=output_dropout_p, 
+                                         bidirectional=bidirectional)
+
+        # TODO embedding is None or not
+        if embedding is None:
+            self.embedding = nn.Embedding(vocab_size, input_size)
+        else:
+            self.embedding = embedding
+
+        # share parameters between encoder and decoder
+        # self.rnn = ctx_encoder.rnn
+        # self.FC = nn.Linear(input_size, utt_encoder.output_size)
+
+        self.use_attn = use_attn
+        if self.use_attn:
+            self.attention = Attention(dec_cell_size=hidden_size, 
+                                       ctx_cell_size=ctx_cell_size, 
+                                       attn_mode=attn_mode, 
+                                       project=True)
+        
+        self.dec_cell_size = hidden_size
+        self.output_size = vocab_size
+        self.project = nn.Linear(self.dec_cell_size, self.output_size)
+        self.log_softmax = F.log_softmax
+
+        self.sys_id = sys_id
+        self.eos_id = eos_id
+        self.use_gpu = use_gpu
+        self.max_dec_len = max_dec_len
+
+    def forward(self, batch_size, dec_inputs, dec_init_state, attn_context, mode, gen_type, beam_size, goal_hid=None):
+        # dec_inputs: (batch_size, response_size-1)
+        # attn_context: (batch_size, max_ctx_len, ctx_cell_size)
+        # goal_hid: (batch_size, goal_nhid)
+
+        ret_dict = dict()
+
+        if self.use_attn:
+            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()
+
+        if mode == GEN:
+            dec_inputs = None
+
+        if gen_type != 'beam':
+            beam_size = 1
+
+        if dec_inputs is not None:
+            decoder_input = dec_inputs
+        else:
+            # prepare the BOS inputs
+            with th.no_grad():
+                bos_var = Variable(th.LongTensor([self.sys_id]))
+            bos_var = cast_type(bos_var, LONG, self.use_gpu)
+            decoder_input = bos_var.expand(batch_size*beam_size, 1) # (batch_size, 1)
+
+        if mode == GEN and gen_type == 'beam':
+            # TODO if beam search, repeat the initial states of the RNN
+            pass
+        else:
+            decoder_hidden_state = dec_init_state
+
+        prob_outputs = [] # list of logprob | max_dec_len*(batch_size, 1, vocab_size)
+        symbol_outputs = [] # list of word ids | max_dec_len*(batch_size, 1)
+        # back_pointers = []
+        # lengths = blabla...
+
+        def decode(step, cum_sum, step_output, step_attn):
+            prob_outputs.append(step_output)
+            step_output_slice = step_output.squeeze(1) # (batch_size, vocab_size)
+            if self.use_attn:
+                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
+
+            if gen_type == 'greedy':
+                _, symbols = step_output_slice.topk(1) # (batch_size, 1)
+            elif gen_type == 'sample':
+                # TODO FIXME
+                # symbols = self.gumbel_max(step_output_slice)
+                pass
+            elif gen_type == 'beam':
+                # TODO
+                pass
+            else:
+                raise ValueError('Unsupported decoding mode')
+
+            symbol_outputs.append(symbols)
+
+            return cum_sum, symbols
+
+        if mode == TEACH_FORCE:
+            prob_outputs, decoder_hidden_state, attn = self.forward_step(input_var=decoder_input, hidden_state=decoder_hidden_state, encoder_outputs=attn_context, goal_hid=goal_hid)
+        else:
+            # do free running here
+            cum_sum = None
+            for step in range(self.max_dec_len):
+                # Input:
+                #   decoder_input: (batch_size, 1)
+                #   decoder_hidden_state: tuple: (h, c)
+                #   attn_context: (batch_size, max_ctx_len, ctx_cell_size)
+                #   goal_hid: (batch_size, goal_nhid)
+                # Output:
+                #   decoder_output: (batch_size, 1, vocab_size)
+                #   decoder_hidden_state: tuple: (h, c)
+                #   step_attn: (batch_size, 1, max_ctx_len)
+                decoder_output, decoder_hidden_state, step_attn = self.forward_step(decoder_input, decoder_hidden_state, attn_context, goal_hid=goal_hid)
+                cum_sum, symbols = decode(step, cum_sum, decoder_output, step_attn)
+                decoder_input = symbols
+
+            prob_outputs = th.cat(prob_outputs, dim=1) # (batch_size, max_dec_len, vocab_size)
+
+            # back tracking to recover the 1-best in beam search
+            # if gen_type == 'beam':
+
+        ret_dict[DecoderRNN.KEY_SEQUENCE] = symbol_outputs
+
+        # prob_outputs: (batch_size, max_dec_len, vocab_size)
+        # decoder_hidden_state: tuple: (h, c)
+        # ret_dict[DecoderRNN.KEY_ATTN_SCORE]: max_dec_len*(batch_size, 1, max_ctx_len)
+        # ret_dict[DecoderRNN.KEY_SEQUENCE]: max_dec_len*(batch_size, 1) 
+        return prob_outputs, decoder_hidden_state, ret_dict
+
+    def forward_step(self, input_var, hidden_state, encoder_outputs, goal_hid):
+        # input_var: (batch_size, response_size-1 i.e. output_seq_len)
+        # hidden_state: tuple: (h, c)
+        # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size)
+        # goal_hid: (batch_size, goal_nhid)
+        batch_size, output_seq_len = input_var.size()
+        embedded = self.embedding(input_var) # (batch_size, output_seq_len, embedding_dim)
+
+        # add goals
+        if goal_hid is not None:
+            goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (batch_size, 1, goal_nhid)
+            goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (batch_size, output_seq_len, goal_nhid)
+            embedded = th.cat([embedded, goal_rep], dim=2) # (batch_size, output_seq_len, embedding_dim+goal_nhid)
+
+        embedded = self.input_dropout(embedded)
+
+        # ############
+        # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1)
+
+        # output: (batch_size, output_seq_len, dec_cell_size)
+        # hidden: tuple: (h, c)
+        output, hidden_s = self.rnn(embedded, hidden_state)
+
+        attn = None
+        if self.use_attn:
+            # output: (batch_size, output_seq_len, dec_cell_size)
+            # encoder_outputs: (batch_size, max_ctx_len, ctx_cell_size)
+            # attn: (batch_size, output_seq_len, max_ctx_len)
+            output, attn = self.attention(output, encoder_outputs)
+
+        logits = self.project(output.contiguous().view(-1, self.dec_cell_size)) # (batch_size*output_seq_len, vocab_size)
+        prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size)
+        return prediction, hidden_s, attn
+
+    # special for rl
+    def _step(self, input_var, hidden_state, encoder_outputs, goal_hid):
+        # input_var: (1, 1)
+        # hidden_state: tuple: (h, c)
+        # encoder_outputs: (1, max_dlg_len, dlg_cell_size)
+        # goal_hid: (1, goal_nhid)
+        batch_size, output_seq_len = input_var.size()
+        embedded = self.embedding(input_var) # (1, 1, embedding_dim)
+
+        if goal_hid is not None:
+            goal_hid = goal_hid.view(goal_hid.size(0), 1, goal_hid.size(1)) # (1, 1, goal_nhid)
+            goal_rep = goal_hid.repeat(1, output_seq_len, 1) # (1, 1, goal_nhid)
+            embedded = th.cat([embedded, goal_rep], dim=2) # (1, 1, embedding_dim+goal_nhid)
+
+        embedded = self.input_dropout(embedded)
+
+        # ############
+        # embedded = self.FC(embedded.view(-1, embedded.size(-1))).view(batch_size, output_seq_len, -1)
+
+        # output: (1, 1, dec_cell_size)
+        # hidden: tuple: (h, c)
+        output, hidden_s = self.rnn(embedded, hidden_state)
+
+        attn = None
+        if self.use_attn:
+            # output: (1, 1, dec_cell_size)
+            # encoder_outputs: (1, max_dlg_len, dlg_cell_size)
+            # attn: (1, 1, max_dlg_len)
+            output, attn = self.attention(output, encoder_outputs)
+
+        logits = self.project(output.view(-1, self.dec_cell_size)) # (1*1, vocab_size)
+        prediction = logits.view(batch_size, output_seq_len, -1) # (1, 1, vocab_size)
+        # prediction = self.log_softmax(logits, dim=logits.dim()-1).view(batch_size, output_seq_len, -1) # (batch_size, output_seq_len, vocab_size)
+        return prediction, hidden_s
+
+    # special for rl
+    def write(self, input_var, hidden_state, encoder_outputs, max_words, vocab, stop_tokens, goal_hid=None, mask=True,
+              decoding_masked_tokens=DECODING_MASKED_TOKENS):
+        # input_var: (1, 1)
+        # hidden_state: tuple: (h, c)
+        # encoder_outputs: max_dlg_len*(1, 1, dlg_cell_size)
+        # goal_hid: (1, goal_nhid)
+        logprob_outputs = [] # list of logprob | max_dec_len*(1, )
+        symbol_outputs = [] # list of word ids | max_dec_len*(1, )
+        decoder_input = input_var
+        decoder_hidden_state = hidden_state
+        if type(encoder_outputs) is list:
+            encoder_outputs = th.cat(encoder_outputs, 1) # (1, max_dlg_len, dlg_cell_size)
+        # print('encoder_outputs.size() = {}'.format(encoder_outputs.size()))
+        
+        if mask:
+            special_token_mask = Variable(th.FloatTensor([-999. if token in decoding_masked_tokens else 0. for token in vocab]))
+            special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, )
+
+        def _sample(dec_output, num_i):
+            # dec_output: (1, 1, vocab_size), need to softmax and log_softmax
+            dec_output = dec_output.view(-1) # (vocab_size, )
+            # TODO temperature
+            prob = F.softmax(dec_output/0.6, dim=0) # (vocab_size, )
+            logprob = F.log_softmax(dec_output, dim=0) # (vocab_size, )
+            symbol = prob.multinomial(num_samples=1).detach() # (1, )
+            # _, symbol = prob.topk(1) # (1, )
+            _, tmp_symbol = prob.topk(1) # (1, )
+            # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()]))
+            # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()]))
+            logprob = logprob.gather(0, symbol) # (1, )
+            return logprob, symbol
+
+        for i in range(max_words):
+            decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid)
+            # disable special tokens from being generated in a normal turn
+            if mask:
+                decoder_output += special_token_mask.expand(1, 1, -1)
+            logprob, symbol = _sample(decoder_output, i)
+            logprob_outputs.append(logprob)
+            symbol_outputs.append(symbol)
+            decoder_input = symbol.view(1, -1)
+
+            if vocab[symbol.item()] in stop_tokens:
+                break
+
+        assert len(logprob_outputs) == len(symbol_outputs)
+        # logprob_list = [t.item() for t in logprob_outputs]
+        logprob_list = logprob_outputs
+        symbol_list = [t.item() for t in symbol_outputs]
+        return logprob_list, symbol_list
+
+    # For MultiWoz RL
+    def forward_rl(self, batch_size, dec_init_state, attn_context, vocab, max_words, goal_hid=None, mask=True, temp=0.1):
+        # prepare the BOS inputs
+        with th.no_grad():
+            bos_var = Variable(th.LongTensor([self.sys_id]))
+        bos_var = cast_type(bos_var, LONG, self.use_gpu)
+        decoder_input = bos_var.expand(batch_size, 1) # (1, 1)
+        decoder_hidden_state = dec_init_state # tuple: (h, c)
+        encoder_outputs = attn_context # (1, ctx_len, ctx_cell_size)
+
+        logprob_outputs = [] # list of logprob | max_dec_len*(1, )
+        symbol_outputs = [] # list of word ids | max_dec_len*(1, )
+
+        if mask:
+            special_token_mask = Variable(th.FloatTensor([-999. if token in DECODING_MASKED_TOKENS else 0. for token in vocab]))
+            special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu) # (vocab_size, )
+
+        def _sample(dec_output, num_i):
+            # dec_output: (1, 1, vocab_size), need to softmax and log_softmax
+            dec_output = dec_output.view(batch_size, -1) # (batch_size, vocab_size, )
+            prob = F.softmax(dec_output/temp, dim=1) # (batch_size, vocab_size, )
+            logprob = F.log_softmax(dec_output, dim=1) # (batch_size, vocab_size, )
+            symbol = prob.multinomial(num_samples=1).detach() # (batch_size, 1)
+            # _, symbol = prob.topk(1) # (1, )
+            _, tmp_symbol = prob.topk(1) # (1, )
+            # print('multinomial symbol = {}, prob = {}'.format(symbol, prob[symbol.item()]))
+            # print('topk symbol = {}, prob = {}'.format(tmp_symbol, prob[tmp_symbol.item()]))
+            logprob = logprob.gather(1, symbol) # (1, )
+            return logprob, symbol
+
+        stopped_samples = set()
+        for i in range(max_words):
+            decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state, encoder_outputs, goal_hid)
+            # disable special tokens from being generated in a normal turn
+            if mask:
+                decoder_output += special_token_mask.expand(1, 1, -1)
+            logprob, symbol = _sample(decoder_output, i)
+            logprob_outputs.append(logprob)
+            symbol_outputs.append(symbol)
+            decoder_input = symbol.view(batch_size, -1)
+            for b_id in range(batch_size):
+                if vocab[symbol[b_id].item()] == EOS:
+                    stopped_samples.add(b_id)
+
+            if len(stopped_samples) == batch_size:
+                break
+
+        assert len(logprob_outputs) == len(symbol_outputs)
+        symbol_outputs = th.cat(symbol_outputs, dim=1).cpu().data.numpy().tolist()
+        logprob_outputs = th.cat(logprob_outputs, dim=1)
+        logprob_list = []
+        symbol_list = []
+        for b_id in range(batch_size):
+            b_logprob = []
+            b_symbol = []
+            for t_id in range(logprob_outputs.shape[1]):
+                symbol = symbol_outputs[b_id][t_id]
+                if vocab[symbol] == EOS and t_id != 0:
+                    break
+
+                b_symbol.append(symbol_outputs[b_id][t_id])
+                b_logprob.append(logprob_outputs[b_id][t_id])
+
+            logprob_list.append(b_logprob)
+            symbol_list.append(b_symbol)
+
+        # TODO backward compatible, if batch_size == 1, we remove the nested structure
+        if batch_size == 1:
+            logprob_list = logprob_list[0]
+            symbol_list = symbol_list[0]
+
+        return logprob_list, symbol_list
+
+class DecoderPointerGen(BaseRNN):
+
+    def __init__(self, vocab_size, max_len, input_size, hidden_size, sos_id,
+                 eos_id, n_layers=1, rnn_cell='lstm', input_dropout_p=0,
+                 dropout_p=0, attn_mode='cat', attn_size=None, use_gpu=True,
+                 embedding=None):
+
+        super(DecoderPointerGen, self).__init__(vocab_size, input_size,
+                                                hidden_size, input_dropout_p,
+                                                dropout_p, n_layers, rnn_cell, False)
+
+        self.output_size = vocab_size
+        self.max_length = max_len
+        self.eos_id = eos_id
+        self.sos_id = sos_id
+        self.use_gpu = use_gpu
+        self.attn_size = attn_size
+
+        if embedding is None:
+            self.embedding = nn.Embedding(self.output_size, self.input_size)
+        else:
+            self.embedding = embedding
+
+        self.attention = Attention(self.hidden_size, attn_size, attn_mode,
+                                   project=True)
+
+        self.project = nn.Linear(self.hidden_size, self.output_size)
+        self.sentinel = nn.Parameter(torch.randn((1, 1, attn_size)), requires_grad=True)
+        self.register_parameter('sentinel', self.sentinel)
+
+    def forward_step(self, input_var, hidden, attn_ctxs, attn_words, ctx_embed=None):
+        """
+        attn_size: number of context to attend
+        :param input_var: 
+        :param hidden: 
+        :param attn_ctxs: batch_size x attn_size+1 x ctx_size. If None, then leave it empty
+        :param attn_words: batch_size x attn_size
+        :return: 
+        """
+        # we enable empty attention context
+        batch_size = input_var.size(0)
+        seq_len = input_var.size(1)
+        embedded = self.embedding(input_var)
+        if ctx_embed is not None:
+            embedded += ctx_embed
+
+        embedded = self.input_dropout(embedded)
+        output, hidden = self.rnn(embedded, hidden)
+
+        if attn_ctxs is None:
+            # pointer network here
+            logits = self.project(output.contiguous().view(-1, self.hidden_size))
+            predicted_softmax = F.log_softmax(logits, dim=1)
+            return predicted_softmax, None, hidden, None, None
+        else:
+            attn_size = attn_words.size(1)
+            combined_output, attn = self.attention(output, attn_ctxs)
+
+            # output: batch_size x seq_len x hidden_size
+            # attn: batch_size x seq_len x (attn_size+1)
+
+            # pointer network here
+            rnn_softmax = F.softmax(self.project(output.view(-1, self.hidden_size)), dim=1)
+            g = attn[:, :, 0].contiguous()
+            ptr_attn = attn[:, :, 1:].contiguous()
+            ptr_softmax = Variable(torch.zeros((batch_size * seq_len * attn_size, self.vocab_size)))
+            ptr_softmax = cast_type(ptr_softmax, FLOAT, self.use_gpu)
+
+            # convert words and ids into 1D
+            flat_attn_words = attn_words.unsqueeze(1).repeat(1, seq_len, 1).view(-1, 1)
+            flat_attn = ptr_attn.view(-1, 1)
+
+            # fill in the attention into ptr_softmax
+            ptr_softmax = ptr_softmax.scatter_(1, flat_attn_words, flat_attn)
+            ptr_softmax = ptr_softmax.view(batch_size * seq_len, attn_size, self.vocab_size)
+            ptr_softmax = torch.sum(ptr_softmax, dim=1)
+
+            # mix the softmax from rnn and pointer
+            mixture_softmax = rnn_softmax * g.view(-1, 1) + ptr_softmax
+
+            # take the log to get logsoftmax
+            logits = torch.log(mixture_softmax.clamp(min=1e-8))
+            predicted_softmax = logits.view(batch_size, seq_len, -1)
+            ptr_softmax = ptr_softmax.view(batch_size, seq_len, -1)
+
+            return predicted_softmax, ptr_softmax, hidden, ptr_attn, g
+
+    def forward(self, batch_size, attn_context, attn_words,
+                inputs=None, init_state=None, mode=TEACH_FORCE,
+                gen_type='greedy', ctx_embed=None):
+
+        # sanity checks
+        ret_dict = dict()
+
+        if mode == GEN:
+            inputs = None
+
+        if inputs is not None:
+            decoder_input = inputs
+        else:
+            # prepare the BOS inputs
+            bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True)
+            bos_var = cast_type(bos_var, LONG, self.use_gpu)
+            decoder_input = bos_var.expand(batch_size, 1)
+
+        # append sentinel to the attention
+        if attn_context is not None:
+            attn_context = torch.cat([self.sentinel.expand(batch_size, 1, self.attn_size),
+                                      attn_context], dim=1)
+
+        decoder_hidden = init_state
+        decoder_outputs = [] # a list of logprob
+        sequence_symbols = [] # a list word ids
+        attentions = []
+        pointer_gs = []
+        pointer_outputs = []
+        lengths = np.array([self.max_length] * batch_size)
+
+        def decode(step, step_output):
+            decoder_outputs.append(step_output)
+            step_output_slice = step_output.squeeze(1)
+
+            if gen_type == 'greedy':
+                symbols = step_output_slice.topk(1)[1]
+            elif gen_type == 'sample':
+                symbols = self.gumbel_max(step_output_slice)
+            else:
+                raise ValueError("Unsupported decoding mode")
+
+            sequence_symbols.append(symbols)
+
+            eos_batches = symbols.data.eq(self.eos_id)
+            if eos_batches.dim() > 0:
+                eos_batches = eos_batches.cpu().view(-1).numpy()
+                update_idx = ((lengths > di) & eos_batches) != 0
+                lengths[update_idx] = len(sequence_symbols)
+            return symbols
+
+        # Manual unrolling is used to support random teacher forcing.
+        # If teacher_forcing_ratio is True or False instead of a probability,
+        # the unrolling can be done in graph
+        if mode == TEACH_FORCE:
+            pred_softmax, ptr_softmax, decoder_hidden, attn, step_g = self.forward_step(
+                decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
+
+            # in teach forcing mode, we don't need symbols.
+            attentions = attn
+            decoder_outputs = pred_softmax
+            pointer_gs = step_g
+            pointer_outputs = ptr_softmax
+
+        else:
+            # do free running here
+            for di in range(self.max_length):
+                pred_softmax, ptr_softmax, decoder_hidden, step_attn, step_g = self.forward_step(
+                    decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
+
+                symbols = decode(di, pred_softmax)
+
+                # append the results into ctx dictionary
+                attentions.append(step_attn)
+                pointer_gs.append(step_g)
+                pointer_outputs.append(ptr_softmax)
+                decoder_input = symbols
+
+            # make list be a tensor
+            decoder_outputs = torch.cat(decoder_outputs, dim=1)
+            pointer_outputs = torch.cat(pointer_outputs, dim=1)
+            pointer_gs = torch.cat(pointer_gs, dim=1)
+
+        # save the decoded sequence symbols and sequence length
+        ret_dict[self.KEY_ATTN_SCORE] = attentions
+        ret_dict[self.KEY_SEQUENCE] = sequence_symbols
+        ret_dict[self.KEY_LENGTH] = lengths
+        ret_dict[self.KEY_G] = pointer_gs
+        ret_dict[self.KEY_PTR_SOFTMAX] = pointer_outputs
+        ret_dict[self.KEY_PTR_CTX] = attn_words
+
+        return decoder_outputs, decoder_hidden, ret_dict
diff --git a/latent_dialog/enc2dec/encoders.py b/latent_dialog/enc2dec/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73bbfd6488a17683f891cd8cefa699b015131a1
--- /dev/null
+++ b/latent_dialog/enc2dec/encoders.py
@@ -0,0 +1,215 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from latent_dialog.enc2dec.base_modules import BaseRNN
+
+
+class EncoderRNN(BaseRNN):
+    def __init__(self, input_dropout_p, rnn_cell, input_size, hidden_size, num_layers, output_dropout_p, bidirectional, variable_lengths):
+        super(EncoderRNN, self).__init__(input_dropout_p=input_dropout_p, 
+                                         rnn_cell=rnn_cell, 
+                                         input_size=input_size, 
+                                         hidden_size=hidden_size, 
+                                         num_layers=num_layers, 
+                                         output_dropout_p=output_dropout_p, 
+                                         bidirectional=bidirectional)
+        self.variable_lengths = variable_lengths
+        self.output_size = hidden_size*2 if bidirectional else hidden_size
+
+    def forward(self, input_var, init_state=None, input_lengths=None, goals=None):
+        # add goals
+        if goals is not None:
+            batch_size, max_ctx_len, ctx_nhid = input_var.size()
+            goals = goals.view(goals.size(0), 1, goals.size(1))
+            goals_rep = goals.repeat(1, max_ctx_len, 1).view(batch_size, max_ctx_len, -1) # (batch_size, max_ctx_len, goal_nhid)
+            input_var = th.cat([input_var, goals_rep], dim=2)
+
+        embedded = self.input_dropout(input_var)
+
+        if self.variable_lengths:
+            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths,
+                                                         batch_first=True)
+        if init_state is not None:
+            output, hidden = self.rnn(embedded, init_state)
+        else:
+            output, hidden = self.rnn(embedded)
+        if self.variable_lengths:
+            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
+
+        return output, hidden
+
+
+class RnnUttEncoder(nn.Module):
+    def __init__(self, vocab_size, embedding_dim, feat_size, goal_nhid, rnn_cell,
+                 utt_cell_size, num_layers, input_dropout_p, output_dropout_p,
+                 bidirectional, variable_lengths, use_attn, embedding=None):
+        super(RnnUttEncoder, self).__init__()
+        if embedding is None:
+            self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
+        else:
+            self.embedding = embedding
+
+        self.rnn = EncoderRNN(input_dropout_p=input_dropout_p,
+                              rnn_cell=rnn_cell, 
+                              input_size=embedding_dim+feat_size+goal_nhid, 
+                              hidden_size=utt_cell_size, 
+                              num_layers=num_layers, 
+                              output_dropout_p=output_dropout_p, 
+                              bidirectional=bidirectional, 
+                              variable_lengths=variable_lengths)
+
+        self.utt_cell_size = utt_cell_size
+        self.multiplier = 2 if bidirectional else 1
+        self.output_size = self.multiplier * self.utt_cell_size
+        self.use_attn = use_attn
+        if self.use_attn:
+            self.key_w = nn.Linear(self.output_size, self.utt_cell_size)
+            self.query = nn.Linear(self.utt_cell_size, 1)
+
+    def forward(self, utterances, feats=None, init_state=None, goals=None):
+        batch_size, max_ctx_len, max_utt_len = utterances.size()
+        # get word embeddings
+        flat_words = utterances.view(-1, max_utt_len) # (batch_size*max_ctx_len, max_utt_len)
+        word_embeddings = self.embedding(flat_words) # (batch_size*max_ctx_len, max_utt_len, embedding_dim)
+        flat_mask = th.sign(flat_words).float()
+        # add features
+        if feats is not None:
+            flat_feats = feats.view(-1, 1) # (batch_size*max_ctx_len, 1)
+            flat_feats = flat_feats.unsqueeze(1).repeat(1, max_utt_len, 1) # (batch_size*max_ctx_len, max_utt_len, 1)
+            word_embeddings = th.cat([word_embeddings, flat_feats], dim=2) # (batch_size*max_ctx_len, max_utt_len, embedding_dim+1)
+
+        # add goals
+        if goals is not None:
+            goals = goals.view(goals.size(0), 1, 1, goals.size(1))
+            goals_rep = goals.repeat(1, max_ctx_len, max_utt_len, 1).view(batch_size*max_ctx_len, max_utt_len, -1) # (batch_size*max_ctx_len, max_utt_len, goal_nhid)
+            word_embeddings = th.cat([word_embeddings, goals_rep], dim=2)
+
+        # enc_outs: (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size)
+        # enc_last: (num_layers*num_directions, batch_size*max_ctx_len, utt_cell_size)
+        enc_outs, enc_last = self.rnn(word_embeddings, init_state=init_state)
+
+        if self.use_attn:
+            fc1 = th.tanh(self.key_w(enc_outs)) # (batch_size*max_ctx_len, max_utt_len, utt_cell_size)
+            attn = self.query(fc1).squeeze(2)
+            # (batch_size*max_ctx_len, max_utt_len)
+            attn = F.softmax(attn, attn.dim()-1) # (batch_size*max_ctx_len, max_utt_len, 1)
+            attn = attn * flat_mask
+            attn = (attn / (th.sum(attn, dim=1, keepdim=True)+1e-10)).unsqueeze(2)
+            utt_embedded = attn * enc_outs # (batch_size*max_ctx_len, max_utt_len, num_directions*utt_cell_size)
+            utt_embedded = th.sum(utt_embedded, dim=1) # (batch_size*max_ctx_len, num_directions*utt_cell_size)
+        else:
+            # FIXME bug for multi-layer
+            attn = None
+            utt_embedded = enc_last.transpose(0, 1).contiguous() # (batch_size*max_ctx_lens, num_layers*num_directions, utt_cell_size)
+            utt_embedded = utt_embedded.view(-1, self.output_size) # (batch_size*max_ctx_len*num_layers, num_directions*utt_cell_size)
+
+        utt_embedded = utt_embedded.view(batch_size, max_ctx_len, self.output_size)
+        return utt_embedded, word_embeddings.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1), \
+               enc_outs.contiguous().view(batch_size, max_ctx_len*max_utt_len, -1)
+
+
+class MlpGoalEncoder(nn.Module):
+    def __init__(self, goal_vocab_size, k, nembed, nhid, init_range):
+        super(MlpGoalEncoder, self).__init__()
+
+        # create separate embedding for counts and values
+        self.cnt_enc = nn.Embedding(goal_vocab_size, nembed)
+        self.val_enc = nn.Embedding(goal_vocab_size, nembed)
+
+        self.encoder = nn.Sequential(
+            nn.Tanh(),
+            nn.Linear(k*nembed, nhid) 
+        )
+
+        self.cnt_enc.weight.data.uniform_(-init_range, init_range)
+        self.val_enc.weight.data.uniform_(-init_range, init_range)
+        self._init_cont(self.encoder, init_range)
+
+    def _init_cont(self, cont, init_range):
+        """initializes a container uniformly."""
+        for m in cont:
+            if hasattr(m, 'weight'):
+                m.weight.data.uniform_(-init_range, init_range)
+            if hasattr(m, 'bias'):
+                m.bias.data.fill_(0)
+
+    def forward(self, goal):
+        # goal: (batch_size, goal_len)
+        goal = goal.transpose(0, 1).contiguous() # (goal_len, batch_size)
+        idx = np.arange(goal.size(0) // 2)
+        
+        # extract counts and values
+        cnt_idx = Variable(th.from_numpy(2 * idx + 0))
+        val_idx = Variable(th.from_numpy(2 * idx + 1))
+
+        if goal.is_cuda:
+            cnt_idx = cnt_idx.type(th.cuda.LongTensor)
+            val_idx = val_idx.type(th.cuda.LongTensor)
+        else:
+            cnt_idx = cnt_idx.type(th.LongTensor)
+            val_idx = val_idx.type(th.LongTensor)
+
+        cnt = goal.index_select(0, cnt_idx) # (3, batch_size)
+        val = goal.index_select(0, val_idx) # (3, batch_size)
+
+        # embed counts and values
+        cnt_emb = self.cnt_enc(cnt) # (3, batch_size, nembed)
+        val_emb = self.val_enc(val) # (3, batch_size, nembed)
+
+        # element wise multiplication to get a hidden state
+        h = th.mul(cnt_emb, val_emb) # (3, batch_size, nembed)
+        # run the hidden state through the MLP
+        h = h.transpose(0, 1).contiguous().view(goal.size(1), -1) # (batch_size, 3*nembed)
+        goal_h = self.encoder(h) # (batch_size, nhid)
+
+        return goal_h
+
+
+class TaskMlpGoalEncoder(nn.Module):
+    def __init__(self, goal_vocab_sizes, nhid, init_range):
+        super(TaskMlpGoalEncoder, self).__init__()
+        
+        self.encoder = nn.ModuleList()
+        for v_size in goal_vocab_sizes:
+            domain_encoder = nn.Sequential(
+                nn.Linear(v_size, nhid), 
+                nn.Tanh()
+            )
+            self._init_cont(domain_encoder, init_range)
+            self.encoder.append(domain_encoder)
+
+    def _init_cont(self, cont, init_range):
+        """initializes a container uniformly."""
+        for m in cont:
+            if hasattr(m, 'weight'):
+                m.weight.data.uniform_(-init_range, init_range)
+            if hasattr(m, 'bias'):
+                m.bias.data.fill_(0)
+
+    def forward(self, goals_list):
+        # goals_list: list of tensor, 7*(batch_size, goal_len), goal_len varies among differnet domains
+        outs = [encoder.forward(goal) for goal, encoder in zip(goals_list, self.encoder)] # 7*(batch_size, goal_nhid)
+        outs = th.sum(th.stack(outs), dim=0) # (batch_size, goal_nhid)
+        return outs
+
+
+class SelfAttn(nn.Module):
+    def __init__(self, hidden_size):
+        super(SelfAttn, self).__init__()
+        self.query = nn.Linear(hidden_size, 1)
+
+    def forward(self, keys, values, attn_mask=None):
+        """
+        :param attn_inputs: batch_size x time_len x hidden_size
+        :param attn_mask: batch_size x time_len
+        :return: summary state
+        """
+        alpha = F.softmax(self.query(keys), dim=1)
+        if attn_mask is not None:
+            alpha = alpha * attn_mask.unsqueeze(2)
+            alpha = alpha / th.sum(alpha, dim=1, keepdim=True)
+
+        summary = th.sum(values * alpha, dim=1)
+        return summary
\ No newline at end of file
diff --git a/latent_dialog/evaluators.py b/latent_dialog/evaluators.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4a699cb2312fbfcc5ed047d9f008c965c8fae8c
--- /dev/null
+++ b/latent_dialog/evaluators.py
@@ -0,0 +1,1230 @@
+from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
+import math
+import numpy as np
+import scipy.stats as st
+import latent_dialog.normalizer.delexicalize as delex
+from latent_dialog.utils import get_tokenize, get_detokenize
+from collections import Counter, defaultdict
+from nltk.util import ngrams
+from latent_dialog.corpora import SYS, USR, BOS, EOS
+from sklearn.feature_extraction.text import CountVectorizer
+import json
+from latent_dialog.normalizer.delexicalize import normalize
+from latent_dialog.augpt_utils import *
+import json
+import sqlite3
+import os
+import random
+import logging
+import pdb
+import re
+from tqdm import tqdm
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.linear_model import SGDClassifier
+from sklearn import metrics
+from nltk.translate import bleu_score
+from nltk.translate.bleu_score import SmoothingFunction
+from scipy.stats import gmean
+
+
+class BaseEvaluator(object):
+    def initialize(self):
+        raise NotImplementedError
+
+    def add_example(self, ref, hyp):
+        raise NotImplementedError
+
+    def get_report(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def _get_prec_recall(tp, fp, fn):
+        precision = tp / (tp + fp + 10e-20)
+        recall = tp / (tp + fn + 10e-20)
+        f1 = 2 * precision * recall / (precision + recall + 1e-20)
+        return precision, recall, f1
+
+    @staticmethod
+    def _get_tp_fp_fn(label_list, pred_list):
+        tp = len([t for t in pred_list if t in label_list])
+        fp = max(0, len(pred_list) - tp)
+        fn = max(0, len(label_list) - tp)
+        return tp, fp, fn
+
+class BLEUScorer(object):
+    ## BLEU score calculator via GentScorer interface
+    ## it calculates the BLEU-4 by taking the entire corpus in
+    ## Calulate based multiple candidates against multiple references
+    def score(self, hypothesis, corpus, n=1):
+        # containers
+        count = [0, 0, 0, 0]
+        clip_count = [0, 0, 0, 0]
+        r = 0
+        c = 0
+        weights = [0.25, 0.25, 0.25, 0.25]
+
+        # accumulate ngram statistics
+        for hyps, refs in zip(hypothesis, corpus):
+            # if type(hyps[0]) is list:
+            #    hyps = [hyp.split() for hyp in hyps[0]]
+            # else:
+            #    hyps = [hyp.split() for hyp in hyps]
+
+            # refs = [ref.split() for ref in refs]
+            hyps = [hyps]
+            # Shawn's evaluation
+            # refs[0] = [u'GO_'] + refs[0] + [u'EOS_']
+            # hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_']
+
+            for idx, hyp in enumerate(hyps):
+                for i in range(4):
+                    # accumulate ngram counts
+                    hypcnts = Counter(ngrams(hyp, i + 1))
+                    cnt = sum(hypcnts.values())
+                    count[i] += cnt
+
+                    # compute clipped counts
+                    max_counts = {}
+                    for ref in refs:
+                        refcnts = Counter(ngrams(ref, i + 1))
+                        for ng in hypcnts:
+                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
+                    clipcnt = dict((ng, min(count, max_counts[ng])) \
+                                   for ng, count in hypcnts.items())
+                    clip_count[i] += sum(clipcnt.values())
+
+                # accumulate r & c
+                bestmatch = [1000, 1000]
+                for ref in refs:
+                    if bestmatch[0] == 0: break
+                    diff = abs(len(ref) - len(hyp))
+                    if diff < bestmatch[0]:
+                        bestmatch[0] = diff
+                        bestmatch[1] = len(ref)
+                r += bestmatch[1]
+                c += len(hyp)
+                if n == 1:
+                    break
+        # computing bleu score
+        p0 = 1e-7
+        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
+        p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \
+                for i in range(4)]
+        s = math.fsum(w * math.log(p_n) \
+                      for w, p_n in zip(weights, p_ns) if p_n)
+        bleu = bp * math.exp(s)
+        return bleu
+
+class BleuEvaluator(BaseEvaluator):
+    def __init__(self, data_name):
+        self.data_name = data_name
+        self.labels = list()
+        self.hyps = list()
+
+    def initialize(self):
+        self.labels = list()
+        self.hyps = list()
+
+    def add_example(self, ref, hyp):
+        self.labels.append(ref)
+        self.hyps.append(hyp)
+
+    def get_report(self):
+        tokenize = get_tokenize()
+        print('Generate report for {} samples'.format(len(self.hyps)))
+        refs, hyps = [], []
+        for label, hyp in zip(self.labels, self.hyps):
+            # label = label.replace(EOS, '')
+            # hyp = hyp.replace(EOS, '')
+            # ref_tokens = tokenize(label)[1:]
+            # hyp_tokens = tokenize(hyp)[1:]
+            ref_tokens = tokenize(label)
+            hyp_tokens = tokenize(hyp)
+            refs.append([ref_tokens])
+            hyps.append(hyp_tokens)
+        bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)
+        report = '\n===== BLEU = %f =====\n' % (bleu,)
+        return '\n===== REPORT FOR DATASET {} ====={}'.format(self.data_name, report)
+
+class MultiWozDB(object):
+    # loading databases
+    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']  # , 'police']
+    dbs = {}
+    CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '')
+
+    for domain in domains:
+        db = os.path.join(CUR_DIR, 'data/norm-multi-woz/db/{}-dbase.db'.format(domain))
+        conn = sqlite3.connect(db)
+        c = conn.cursor()
+        dbs[domain] = c
+
+    def queryResultVenues(self, domain, turn, real_belief=False):
+        # query the db
+        sql_query = "select * from {}".format(domain)
+
+        if real_belief == True:
+            items = turn.items()
+        else:
+            items = turn['metadata'][domain]['semi'].items()
+
+        flag = True
+        for key, val in items:
+            if val == "" or val == "dontcare" or val == 'not mentioned' or val == "don't care" or val == "dont care" or val == "do n't care":
+                pass
+            else:
+                if flag:
+                    sql_query += " where "
+                    val2 = val.replace("'", "''")
+                    val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += r" " + key + " > " + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += r" " + key + " < " + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r" " + key + "=" + r"'" + val2 + r"'"
+                    flag = False
+                else:
+                    val2 = val.replace("'", "''")
+                    val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
+
+        try:  # "select * from attraction  where name = 'queens college'"
+            return self.dbs[domain].execute(sql_query).fetchall()
+        except:
+            return []
+
+class MultiWozEvaluator(BaseEvaluator):
+    CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '')
+    logger = logging.getLogger()
+    def __init__(self, data_name, config):
+        self.data_name = data_name
+        self.slot_dict = delex.prepareSlotValuesIndependent()
+        delex_path = config.train_path.replace("train_dials.json", "") + [p for p in  os.listdir(config.train_path.replace("train_dials.json", "")) if 'delex' in p][0]
+        self.delex_dialogues = json.load(open(delex_path))
+        self.db = MultiWozDB()
+        self.labels = list()
+        self.hyps = list()
+
+    def initialize(self):
+        self.labels = list()
+        self.hyps = list()
+
+    def add_example(self, ref, hyp):
+        self.labels.append(ref)
+        self.hyps.append(hyp)
+
+    def _parseGoal(self, goal, d, domain):
+        """Parses user goal into dictionary format."""
+        goal[domain] = {}
+        goal[domain] = {'informable': [], 'requestable': [], 'booking': [], 'failed': {}}
+        if 'info' in d['goal'][domain]:
+        # if d['goal'][domain].has_key('info'):
+            if domain == 'train':
+                # we consider dialogues only where train had to be booked!
+                if 'book' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('book'):
+                    goal[domain]['requestable'].append('reference')
+                if 'reqt' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('reqt'):
+                    if 'trainID' in d['goal'][domain]['reqt']:
+                        goal[domain]['requestable'].append('id')
+            else:
+                if 'reqt' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('reqt'):
+                    for s in d['goal'][domain]['reqt']:  # addtional requests:
+                        if s in ['phone', 'address', 'postcode', 'reference', 'id']:
+                            # ones that can be easily delexicalized
+                            goal[domain]['requestable'].append(s)
+                if 'book' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('book'):
+                    goal[domain]['requestable'].append("reference")
+
+            goal[domain]["informable"] = d['goal'][domain]['info']
+            if 'book' in d['goal'][domain]:
+            # if d['goal'][domain].has_key('book'):
+                goal[domain]["booking"] = d['goal'][domain]['book']
+
+        return goal
+
+    def _evaluateGeneratedDialogue(self, dialog, goal, realDialogue, real_requestables, soft_acc=False):
+        """Evaluates the dialogue created by the model.
+        First we load the user goal of the dialogue, then for each turn
+        generated by the system we look for key-words.
+        For the Inform rate we look whether the entity was proposed.
+        For the Success rate we look for requestables slots"""
+        # for computing corpus success
+        requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+        # CHECK IF MATCH HAPPENED
+        provided_requestables = {}
+        venue_offered = {}
+        domains_in_goal = []
+
+        for domain in goal.keys():
+            venue_offered[domain] = []
+            provided_requestables[domain] = []
+            domains_in_goal.append(domain)
+
+        for t, sent_t in enumerate(dialog):
+            for domain in goal.keys():
+                # for computing success
+                if '[' + domain + '_name]' in sent_t or '_id' in sent_t: # undo delexicalization if system generates [domain_name] or [domain_id]
+                    if domain in ['restaurant', 'hotel', 'attraction', 'train']: 
+                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+                        # in this case, look for the actual offered venues based on true belief state
+                        venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1])
+                        # venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+
+                        # if venue has changed
+                        if len(venue_offered[domain]) == 0 and venues:
+                            venue_offered[domain] = random.sample(venues, 1)
+                        else:
+                            flag = False
+                            for ven in venues:
+                                if venue_offered[domain][0] == ven:
+                                    flag = True
+                                    break
+                            if not flag and venues:  # sometimes there are no results so sample won't work
+                                # print venues
+                                venue_offered[domain] = random.sample(venues, 1)
+                    else:  # not limited so we can provide one
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+                # ATTENTION: assumption here - we didn't provide phone or address twice! etc
+                for requestable in requestables:
+                    if requestable == 'reference':
+                        if domain + '_reference' in sent_t:
+                            if 'restaurant_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -5] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'hotel_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -3] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'train_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -1] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            else:
+                                provided_requestables[domain].append('reference')
+                    else:
+                        if '[' + domain + '_' + requestable + ']' in sent_t:
+                            provided_requestables[domain].append(requestable)
+
+        # if name was given in the task
+        for domain in goal.keys():
+            # if name was provided for the user, the match is being done automatically
+            # assumption doesn't always hold, maybe it's better if name is provided by user that it is ignored?
+            if 'info' in realDialogue['goal'][domain]:
+                if 'name' in realDialogue['goal'][domain]['info']:
+                    venue_offered[domain] = '[' + domain + '_name]'
+
+            # special domains - entity does not need to be provided
+            if domain in ['taxi', 'police', 'hospital']:
+                venue_offered[domain] = '[' + domain + '_name]'
+
+            if domain == 'train':
+                if not venue_offered[domain]:
+                    # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']:
+                    if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']:
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+        """
+        Given all inform and requestable slots
+        we go through each domain from the user goal
+        and check whether right entity was provided and
+        all requestable slots were given to the user.
+        The dialogue is successful if that's the case for all domains.
+        """
+        # HARD EVAL
+        stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                 'taxi': [0, 0, 0],
+                 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+        match = 0
+        success = 0
+        # MATCH (between offered venue by generated dialogue and venue actually fitting to the criteria)
+        for domain in goal.keys():
+            match_stat = 0
+            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+                # if venue offered is not dict
+                if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: # yields false positive, does not match what is offered with real dialogue?
+                    match += 1
+                    match_stat = 1
+                # if venue offered is dict
+                elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: # actually checks the offered venue
+                    match += 1
+                    match_stat = 1
+            # other domains
+            else:
+                if domain + '_name]' in venue_offered[domain]: # yields false positive, in terms of occurence and correctness
+                    match += 1
+                    match_stat = 1
+
+            stats[domain][0] = match_stat
+            stats[domain][2] = 1
+
+        if soft_acc:
+            match = float(match)/len(goal.keys())
+        else:
+            # only count success if all domain has matches
+            if match == len(goal.keys()):
+                match = 1.0
+            else:
+                match = 0.0
+
+        # SUCCESS (whether the requestable info in realDialogue is generated by the system)
+        # if no match, then success is assumed to be 0
+        if match == 1.0:
+            for domain in domains_in_goal:
+                success_stat = 0
+                domain_success = 0
+                if len(real_requestables[domain]) == 0: # if there is no requestable, assume to be succesful. incorrect, cause does not count false positives. 
+                    success += 1
+                    success_stat = 1
+                    stats[domain][1] = success_stat
+                    continue
+                # if values in sentences are super set of requestables
+                # pdb.set_trace()
+                # print(provided_requestables[domain], real_requestables[domain])
+                for request in set(provided_requestables[domain]):
+                    if request in real_requestables[domain]:
+                        domain_success += 1
+
+                if domain_success >= len(real_requestables[domain]):
+                    success += 1
+                    success_stat = 1
+
+                stats[domain][1] = success_stat
+
+            # final eval
+            if soft_acc:
+                success = float(success)/len(real_requestables)
+            else:
+                if success >= len(real_requestables):
+                    success = 1
+                else:
+                    success = 0
+
+        # rint requests, 'DIFF', requests_real, 'SUCC', success
+        return success, match, stats
+
+    def _evaluateGeneratedDialogue_new(self, dialog, goal, realDialogue, real_requestables, soft_acc=False):
+        """Evaluates the dialogue created by the model.
+        First we load the user goal of the dialogue, then for each turn
+        generated by the system we look for key-words.
+        For the Inform rate we look whether the entity was proposed.
+        For the Success rate we look for requestables slots"""
+        # for computing corpus success
+        requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+        # CHECK IF MATCH HAPPENED
+        provided_requestables = {}
+        venue_offered = {}
+        domains_in_goal = []
+
+        for domain in goal.keys():
+            venue_offered[domain] = []
+            provided_requestables[domain] = []
+            domains_in_goal.append(domain)
+
+        for t, sent_t in enumerate(dialog):
+            for domain in goal.keys():
+                # for computing success
+                if '[' + domain + '_name]' in sent_t or '_id' in sent_t:
+                    if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+                        venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1])
+                        # venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+
+                        # if venue has changed
+                        if len(venue_offered[domain]) == 0 and venues:
+                            venue_offered[domain] = random.sample(venues, 1)
+                        else:
+                            flag = False
+                            for ven in venues:
+                                if venue_offered[domain][0] == ven:
+                                    flag = True
+                                    break
+                            if not flag and venues:  # sometimes there are no results so sample won't work
+                                # print venues
+                                venue_offered[domain] = random.sample(venues, 1)
+                    else:  # not limited so we can provide one
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+                # ATTENTION: assumption here - we didn't provide phone or address twice! etc
+                for requestable in requestables:
+                    if requestable == 'reference':
+                        if domain + '_reference' in sent_t:
+                            if 'restaurant_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -5] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'hotel_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -3] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'train_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -1] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            else:
+                                provided_requestables[domain].append('reference')
+                    else:
+                        if domain + '_' + requestable + ']' in sent_t:
+                            provided_requestables[domain].append(requestable)
+
+        # if name was given in the task
+        for domain in goal.keys():
+            # if name was provided for the user, the match is being done automatically
+            # if realDialogue['goal'][domain].has_key('info'):
+            if 'info' in realDialogue['goal'][domain]:
+                # if realDialogue['goal'][domain]['info'].has_key('name'):
+                if 'name' in realDialogue['goal'][domain]['info']:
+                    venue_offered[domain] = '[' + domain + '_name]'
+
+            # special domains - entity does not need to be provided
+            if domain in ['taxi', 'police', 'hospital']:
+                venue_offered[domain] = '[' + domain + '_name]'
+
+            # the original method
+            # if domain == 'train':
+            #     if not venue_offered[domain]:
+            #         # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']:
+            #         if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']:
+            #             venue_offered[domain] = '[' + domain + '_name]'
+
+            # Wrong one in HDSA
+            # if domain == 'train':
+            #     if not venue_offered[domain]:
+            #         if goal[domain]['requestable'] and 'id' not in goal[domain]['requestable']:
+            #             venue_offered[domain] = '[' + domain + '_name]'
+
+            # if id was not requested but train was found we dont want to override it to check if we booked the right train
+            if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+                venue_offered[domain] = '[' + domain + '_name]'
+
+        """
+        Given all inform and requestable slots
+        we go through each domain from the user goal
+        and check whether right entity was provided and
+        all requestable slots were given to the user.
+        The dialogue is successful if that's the case for all domains.
+        """
+        # HARD EVAL
+        stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                 'taxi': [0, 0, 0],
+                 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+        match = 0
+        success = 0
+        # MATCH
+        for domain in goal.keys():
+            match_stat = 0
+            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+                if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+                elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues:
+                    match += 1
+                    match_stat = 1
+            else:
+                if domain + '_name]' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+
+            stats[domain][0] = match_stat
+            stats[domain][2] = 1
+
+        if soft_acc:
+            match = float(match)/len(goal.keys())
+        else:
+            if match == len(goal.keys()):
+                match = 1.0
+            else:
+                match = 0.0
+
+        # SUCCESS
+        if match == 1.0:
+            for domain in domains_in_goal:
+                success_stat = 0
+                domain_success = 0
+                if len(real_requestables[domain]) == 0:
+                    success += 1
+                    success_stat = 1
+                    stats[domain][1] = success_stat
+                    continue
+                # if values in sentences are super set of requestables
+                for request in set(provided_requestables[domain]):
+                    if request in real_requestables[domain]:
+                        domain_success += 1
+
+                if domain_success >= len(real_requestables[domain]):
+                    success += 1
+                    success_stat = 1
+
+                stats[domain][1] = success_stat
+
+            # final eval
+            if soft_acc:
+                success = float(success)/len(real_requestables)
+            else:
+                if success >= len(real_requestables):
+                    success = 1
+                else:
+                    success = 0
+
+        # print requests, 'DIFF', requests_real, 'SUCC', success
+
+       
+
+        return success, match, stats
+
+    def _evaluateRealDialogue(self, dialog, filename, soft=False):
+        """Evaluation of the real dialogue from corpus.
+        First we loads the user goal and then go through the dialogue history.
+        Similar to evaluateGeneratedDialogue above."""
+        domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
+        requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+        # get the list of domains in the goal
+        domains_in_goal = []
+        goal = {}
+        for domain in domains:
+            if dialog['goal'][domain]:
+                goal = self._parseGoal(goal, dialog, domain)
+                domains_in_goal.append(domain)
+
+        # compute corpus success
+        real_requestables = {}
+        provided_requestables = {}
+        venue_offered = {}
+        for domain in goal.keys():
+            provided_requestables[domain] = []
+            venue_offered[domain] = []
+            real_requestables[domain] = goal[domain]['requestable']
+
+        # iterate each turn
+        m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1]
+        for t in range(len(m_targetutt)):
+            for domain in domains_in_goal:
+                sent_t = m_targetutt[t]
+                # for computing match - where there are limited entities
+                if domain + '_name' in sent_t or '_id' in sent_t:
+                    if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+                        venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1])
+
+                        # if venue has changed
+                        if len(venue_offered[domain]) == 0 and venues:
+                            venue_offered[domain] = random.sample(venues, 1)
+                        else:
+                            flag = False
+                            for ven in venues:
+                                if venue_offered[domain][0] == ven:
+                                    flag = True
+                                    break
+                            if not flag and venues:  # sometimes there are no results so sample won't work
+                                # print venues
+                                venue_offered[domain] = random.sample(venues, 1)
+                    else:  # not limited so we can provide one
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+                for requestable in requestables:
+                    # check if reference could be issued
+                    if requestable == 'reference':
+                        if domain + '_reference' in sent_t:
+                            if 'restaurant_reference' in sent_t:
+                                if dialog['log'][t * 2]['db_pointer'][-5] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'hotel_reference' in sent_t:
+                                if dialog['log'][t * 2]['db_pointer'][-3] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                                    # return goal, 0, match, real_requestables
+                            elif 'train_reference' in sent_t:
+                                if dialog['log'][t * 2]['db_pointer'][-1] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            else:
+                                provided_requestables[domain].append('reference')
+                    else:
+                        if domain + '_' + requestable in sent_t:
+                            provided_requestables[domain].append(requestable)
+
+        # offer was made?
+        for domain in domains_in_goal:
+            # if name was provided for the user, the match is being done automatically
+            # if dialog['goal'][domain].has_key('info'):
+            if 'info' in dialog['goal'][domain]:
+                # if dialog['goal'][domain]['info'].has_key('name'):
+                if 'name' in dialog['goal'][domain]['info']:
+                    venue_offered[domain] = '[' + domain + '_name]'
+
+            # special domains - entity does not need to be provided
+            if domain in ['taxi', 'police', 'hospital']:
+                venue_offered[domain] = '[' + domain + '_name]'
+
+            # if id was not requested but train was found we dont want to override it to check if we booked the right train
+            if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+                venue_offered[domain] = '[' + domain + '_name]'
+
+        # HARD (0-1) EVAL
+        stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                 'taxi': [0, 0, 0],
+                 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+        match, success = 0, 0
+        # MATCH
+        for domain in goal.keys():
+            match_stat = 0
+            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True)
+                # print(goal_venues)
+                if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+                elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues:
+                    match += 1
+                    match_stat = 1
+
+            else:
+                if domain + '_name' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+
+            stats[domain][0] = match_stat
+            stats[domain][2] = 1
+
+        if not soft:
+            if match == len(goal.keys()):
+                match = 1
+            else:
+                match = 0
+        else:
+            match = float(match) / len(goal.keys())
+
+        # SUCCESS
+        if match == 1: # this or match > 0?
+            for domain in domains_in_goal:
+                domain_success = 0
+                success_stat = 0
+                if len(real_requestables[domain]) == 0:
+                    # check that
+                    success += 1
+                    success_stat = 1
+                    stats[domain][1] = success_stat
+                    continue
+                # if values in sentences are super set of requestables
+                for request in set(provided_requestables[domain]):
+                    if request in real_requestables[domain]:
+                        domain_success += 1
+
+                if domain_success >= len(real_requestables[domain]):
+                    success += 1
+                    success_stat = 1
+
+                stats[domain][1] = success_stat
+
+            # final eval
+            if success >= len(real_requestables):
+                success = 1
+            else:
+                if soft:
+                    success = float(success) / len(real_requestables)
+                else:
+                    success = 0
+
+        return goal, success, match, real_requestables, stats
+
+    def _evaluateRolloutDialogue(self, dialog):
+        domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
+        requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+        # get the list of domains in the goal
+        domains_in_goal = []
+        goal = {}
+        for domain in domains:
+            if dialog['goal'][domain]:
+                goal = self._parseGoal(goal, dialog, domain)
+                domains_in_goal.append(domain)
+
+        # compute corpus success
+        real_requestables = {}
+        provided_requestables = {}
+        venue_offered = {}
+        for domain in goal.keys():
+            provided_requestables[domain] = []
+            venue_offered[domain] = []
+            real_requestables[domain] = goal[domain]['requestable']
+
+        # iterate each turn
+        m_targetutt = [turn['text'] for idx, turn in enumerate(dialog['log']) if idx % 2 == 1]
+        for t in range(len(m_targetutt)):
+            for domain in domains_in_goal:
+                sent_t = m_targetutt[t]
+                # for computing match - where there are limited entities
+                if domain + '_name' in sent_t or domain+'_id' in sent_t:
+                    if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                        venue_offered[domain] = '[' + domain + '_name]'
+                        """
+                        venues = self.db.queryResultVenues(domain, dialog['log'][t * 2 + 1])
+                        if len(venue_offered[domain]) == 0 and venues:
+                            venue_offered[domain] = random.sample(venues, 1)
+                        else:
+                            flag = False
+                            for ven in venues:
+                                if venue_offered[domain][0] == ven:
+                                    flag = True
+                                    break
+                            if not flag and venues:  # sometimes there are no results so sample won't work
+                                # print venues
+                                venue_offered[domain] = random.sample(venues, 1)
+                        """
+                    else:  # not limited so we can provide one
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+                for requestable in requestables:
+                    # check if reference could be issued
+                    if requestable == 'reference':
+                        if domain + '_reference' in sent_t:
+                            if 'restaurant_reference' in sent_t:
+                                if True or dialog['log'][t * 2]['db_pointer'][-5] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'hotel_reference' in sent_t:
+                                if True or dialog['log'][t * 2]['db_pointer'][-3] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+                                    # return goal, 0, match, real_requestables
+                            elif 'train_reference' in sent_t:
+                                if True or dialog['log'][t * 2]['db_pointer'][-1] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            else:
+                                provided_requestables[domain].append('reference')
+                    else:
+                        if domain + '_' + requestable in sent_t:
+                            provided_requestables[domain].append(requestable)
+
+        # offer was made?
+        for domain in domains_in_goal:
+            # if name was provided for the user, the match is being done automatically
+            # if dialog['goal'][domain].has_key('info'):
+            if 'info' in dialog['goal'][domain]:
+                # if dialog['goal'][domain]['info'].has_key('name'):
+                if 'name' in dialog['goal'][domain]['info']:
+                    venue_offered[domain] = '[' + domain + '_name]'
+
+            # special domains - entity does not need to be provided
+            if domain in ['taxi', 'police', 'hospital']:
+                venue_offered[domain] = '[' + domain + '_name]'
+
+            # if id was not requested but train was found we dont want to override it to check if we booked the right train
+            if domain == 'train' and (not venue_offered[domain] and 'id' not in goal['train']['requestable']):
+                venue_offered[domain] = '[' + domain + '_name]'
+
+        # REWARD CALCULATION
+        stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                 'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+        match, success = 0.0, 0.0
+        # MATCH
+        for domain in goal.keys():
+            match_stat = 0
+            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                goal_venues = self.db.queryResultVenues(domain, dialog['goal'][domain]['info'], real_belief=True)
+                if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+                elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues:
+                    match += 1
+                    match_stat = 1
+            else:
+                if domain + '_name' in venue_offered[domain]:
+                    match += 1
+                    match_stat = 1
+
+            stats[domain][0] = match_stat
+            stats[domain][2] = 1
+
+        match = min(1.0, float(match) / len(goal.keys()))
+
+        # SUCCESS
+        if match:
+            for domain in domains_in_goal:
+                domain_success = 0
+                success_stat = 0
+                if len(real_requestables[domain]) == 0:
+                    # check that
+                    success += 1
+                    success_stat = 1
+                    stats[domain][1] = success_stat
+                    continue
+                # if values in sentences are super set of requestables
+                for request in set(provided_requestables[domain]):
+                    if request in real_requestables[domain]:
+                        domain_success += 1
+
+                if domain_success >= len(real_requestables[domain]):
+                    success += 1
+                    success_stat = 1
+
+                stats[domain][1] = success_stat
+
+            # final eval
+            success = min(1.0, float(success) / len(real_requestables))
+
+        return success, match, stats
+
+    def _parse_entities(self, tokens):
+        entities = []
+        for t in tokens:
+            if '[' in t and ']' in t:
+                entities.append(t)
+        return entities
+
+    def evaluateModel(self, dialogues, mode='valid', new_version=False, verbose=True):
+        """Gathers statistics for the whole sets."""
+        delex_dialogues = self.delex_dialogues
+        successes, matches = 0, 0
+        corpus_successes, corpus_matches = 0, 0
+        total = 0
+        data_succ, data_match = [], []
+        corpus_succ, corpus_match = [], []
+
+        gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                     'taxi': [0, 0, 0],
+                     'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+        sng_gen_stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0, 0],
+                         'taxi': [0, 0, 0], 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+        for filename, dial in dialogues.items():
+            if mode == 'rollout':
+                success, match, stats = self._evaluateRolloutDialogue(dial)
+            else:
+                # data is ground truth, dial is generated
+                data = delex_dialogues[filename]
+                goal, success, match, requestables, _ = self._evaluateRealDialogue(data, filename) # only goal and requestables are kept
+                corpus_successes += success
+                corpus_succ.append(success)
+                corpus_matches += match
+                corpus_match.append(match)
+                if new_version:
+                    success, match, stats = self._evaluateGeneratedDialogue_new(dial, goal, data, requestables,
+                                                                            soft_acc=mode =='offline_rl')
+                else:
+                    success, match, stats = self._evaluateGeneratedDialogue(dial, goal, data, requestables,
+                                                                            soft_acc=mode =='offline_rl')
+
+                # if success == 0:
+                    # pdb.set_trace()
+
+            successes += success
+            data_succ.append(success)
+            matches += match
+            data_match.append(match)
+            total += 1
+
+            for domain in gen_stats.keys():
+                gen_stats[domain][0] += stats[domain][0]
+                gen_stats[domain][1] += stats[domain][1]
+                gen_stats[domain][2] += stats[domain][2]
+
+            if 'SNG' in filename:
+                for domain in gen_stats.keys():
+                    sng_gen_stats[domain][0] += stats[domain][0]
+                    sng_gen_stats[domain][1] += stats[domain][1]
+                    sng_gen_stats[domain][2] += stats[domain][2]
+
+        report = ""
+        report += '{} Corpus Matches : {:2.2f}%, Groundtruth {} Matches : {:2.2f}%'.format(mode, (matches / float(total) * 100), mode, (corpus_matches / float(total) * 100)) + "\n"
+        report += '{} Corpus Success : {:2.2f}%, Groundtruth {} Success : {:2.2f}%'.format(mode, (successes / float(total) * 100), mode, (corpus_successes / float(total) * 100)) + "\n"
+        report += 'Total number of dialogues: %s, new version=%s ' % (total, new_version)
+        # compute 95% confidence interval
+        # for name, data in zip(["match", "success", "corpus_match", "corpus_success"], [data_match, data_succ, corpus_match, corpus_succ]):
+            # mean = np.mean(data)
+            # std = np.std(data)
+            # lower, upper = st.t.interval(0.95, len(data)-1, loc=np.mean(data), scale=st.sem(data))
+            # print(f"{name}: {mean * 100} +- {(mean-lower) * 100}")
+ 
+
+        if verbose:
+            self.logger.info(report)
+        return report, successes/float(total), matches/float(total)
+    
+    def get_report(self):
+        tokenize = lambda x: x.split()
+        print('Generate report for {} samples'.format(len(self.hyps)))
+        refs, hyps = [], []
+        tp, fp, fn = 0, 0, 0
+        for label, hyp in zip(self.labels, self.hyps):
+            ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS]
+            hyp_tokens = [BOS] + tokenize(hyp.replace(SYS, '').replace(USR, '').strip()) + [EOS]
+            refs.append([ref_tokens])
+            hyps.append(hyp_tokens)
+
+            ref_entities = self._parse_entities(ref_tokens)
+            hyp_entities = self._parse_entities(hyp_tokens)
+            tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, hyp_entities)
+            tp += tpp
+            fp += fpp
+            fn += fnn
+
+        # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)
+        bleu = BLEUScorer().score(hyps, refs) 
+        prec, rec, f1 = self._get_prec_recall(tp, fp, fn)
+        report = "\nBLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1)
+        return report, bleu, prec, rec, f1
+
+    def get_groundtruth_report(self):
+        tokenize = lambda x: x.split()
+        print('Generate report for {} samples'.format(len(self.hyps)))
+        refs, hyps = [], []
+        tp, fp, fn = 0, 0, 0
+        for label, hyp in zip(self.labels, self.hyps):
+            ref_tokens = [BOS] + tokenize(label.replace(SYS, '').replace(USR, '').strip()) + [EOS]
+            refs.append([ref_tokens])
+
+            ref_entities = self._parse_entities(ref_tokens)
+            tpp, fpp, fnn = self._get_tp_fp_fn(ref_entities, ref_entities)
+            tp += tpp
+            fp += fpp
+            fn += fnn
+
+        # bleu = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)
+        # bleu = BLEUScorer().score(refs, refs) 
+        prec, rec, f1 = self._get_prec_recall(tp, fp, fn)
+        # report = "\nGroundtruth BLEU score {}\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(bleu, prec, rec, f1)
+        report = "\nGroundtruth\nEntity precision {:.4f} recall {:.4f} and f1 {:.4f}\n".format(prec, rec, f1)
+        return report, 0, prec, rec, f1
+
+class MultiWozEvaluatorwPenalty(MultiWozEvaluator):
+    CUR_DIR = os.path.dirname(__file__).replace('latent_dialog', '')
+    logger = logging.getLogger()
+    def __init__(self, data_name, config):
+        super(MultiWozEvaluatorwPenalty, self).__init__(data_name, config)
+
+    def _parseGoal(self, goal, d, domain):
+        """Parses user goal into dictionary format."""
+        goal[domain] = {}
+        goal[domain] = {'informable': [], 'requestable': [], 'booking': [], 'failed': {}}
+        if 'info' in d['goal'][domain]:
+        # if d['goal'][domain].has_key('info'):
+            if domain == 'train':
+                # we consider dialogues only where train had to be booked!
+                if 'book' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('book'):
+                    goal[domain]['requestable'].append('reference')
+                if 'reqt' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('reqt'):
+                    if 'trainID' in d['goal'][domain]['reqt']:
+                        goal[domain]['requestable'].append('id')
+                goal[domain]['failed'] = d['goal'][domain]['fail_info']
+            else:
+                if 'reqt' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('reqt'):
+                    for s in d['goal'][domain]['reqt']:  # addtional requests:
+                        if s in ['phone', 'address', 'postcode', 'reference', 'id']:
+                            # ones that can be easily delexicalized
+                            goal[domain]['requestable'].append(s)
+                goal[domain]['failed'] = d['goal'][domain]['fail_info']
+                if 'book' in d['goal'][domain]:
+                # if d['goal'][domain].has_key('book'):
+                    goal[domain]['requestable'].append("reference")
+
+            goal[domain]["informable"] = d['goal'][domain]['info']
+            if 'book' in d['goal'][domain]:
+            # if d['goal'][domain].has_key('book'):
+                goal[domain]["booking"] = d['goal'][domain]['book']
+
+        return goal
+
+    def _evaluateGeneratedDialogue(self, dialog, goal, realDialogue, real_requestables, soft_acc=False):
+        """Evaluates the dialogue created by the model.
+        First we load the user goal of the dialogue, then for each turn
+        generated by the system we look for key-words.
+        For the Inform rate we look whether the entity was proposed.
+        For the Success rate we look for requestables slots"""
+        # for computing corpus success
+        requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+
+        # CHECK IF MATCH HAPPENED
+        provided_requestables = {}
+        venue_offered = {}
+        fail_info_penalty = {}
+        domains_in_goal = []
+
+        for domain in goal.keys():
+            venue_offered[domain] = []
+            provided_requestables[domain] = []
+            fail_info_penalty[domain] = 0
+            domains_in_goal.append(domain)
+
+        for t, sent_t in enumerate(dialog): # go turn by turn
+            for domain in goal.keys(): # for each domain in goal
+                # for computing success
+                if '[' + domain + '_name]' in sent_t or '_id' in sent_t: # undo delexicalization if system generates [domain_name] or [domain_id]
+                    if domain in ['restaurant', 'hotel', 'attraction', 'train']: 
+                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
+                        # in this case, look for the actual offered venues based on true belief state
+                        venues = self.db.queryResultVenues(domain, realDialogue['log'][t * 2 + 1])
+
+                        # if venue has changed
+                        if len(venue_offered[domain]) == 0 and venues:
+                            venue_offered[domain] = random.sample(venues, 1)
+                        else:
+                            flag = False
+                            for ven in venues:
+                                if venue_offered[domain][0] == ven:
+                                    flag = True
+                                    break
+                            if not flag and venues:  # sometimes there are no results so sample won't work
+                                # print venues
+                                venue_offered[domain] = random.sample(venues, 1)
+                        if goal[domain]['failed']:
+                            if any([req in sent_t for req in ['[' + domain + "_" + req + ']' for req in requestables + ['name']]]) and len(venues) == 0:
+                                fail_info_penalty[domain] = 1
+                                # print(sent_t, len(venues))
+
+                    else:  # not limited so we can provide one
+                        venue_offered[domain] = '[' + domain + '_name]'
+                        
+                # ATTENTION: assumption here - we didn't provide phone or address twice! etc
+                for requestable in requestables:
+                    if requestable == 'reference':
+                        if domain + '_reference' in sent_t:
+                            if 'restaurant_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -5] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'hotel_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -3] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            elif 'train_reference' in sent_t:
+                                if realDialogue['log'][t * 2]['db_pointer'][
+                                    -1] == 1:  # if pointer was allowing for that?
+                                    provided_requestables[domain].append('reference')
+
+                            else:
+                                provided_requestables[domain].append('reference')
+                    else:
+                        if '[' + domain + '_' + requestable + ']' in sent_t:
+                            provided_requestables[domain].append(requestable)
+
+        # if name was given in the task
+        for domain in goal.keys():
+            # if name was provided for the user, the match is being done automatically
+            # assumption doesn't always hold, maybe it's better if name is provided by user that it is ignored?
+            if 'info' in realDialogue['goal'][domain]:
+                if 'name' in realDialogue['goal'][domain]['info']:
+                    venue_offered[domain] = '[' + domain + '_name]'
+
+            # special domains - entity does not need to be provided
+            if domain in ['taxi', 'police', 'hospital']:
+                venue_offered[domain] = '[' + domain + '_name]'
+
+            if domain == 'train':
+                if not venue_offered[domain]:
+                    # if realDialogue['goal'][domain].has_key('reqt') and 'id' not in realDialogue['goal'][domain]['reqt']:
+                    if 'reqt' in realDialogue['goal'][domain] and 'id' not in realDialogue['goal'][domain]['reqt']:
+                        venue_offered[domain] = '[' + domain + '_name]'
+
+        """
+        Given all inform and requestable slots
+        we go through each domain from the user goal
+        and check whether right entity was provided and
+        all requestable slots were given to the user.
+        The dialogue is successful if that's the case for all domains.
+        """
+        # HARD EVAL
+        stats = {'restaurant': [0, 0, 0, 0], 'hotel': [0, 0, 0, 0], 'attraction': [0, 0, 0, 0], 'train': [0, 0, 0, 0],
+                 'taxi': [0, 0, 0],
+                 'hospital': [0, 0, 0], 'police': [0, 0, 0]}
+
+        match = 0
+        success = 0
+        # MATCH (between offered venue by generated dialogue and venue actually fitting to the criteria)
+        for domain in goal.keys():
+            match_stat = 0
+            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
+                goal_venues = self.db.queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
+                # if venue offered is not dict
+                if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: # yields false positive, does not match what is offered with real dialogue?
+                    match += 1
+                    match_stat = 1
+                # if venue offered is dict
+                elif len(venue_offered[domain]) > 0 and venue_offered[domain][0] in goal_venues: # actually checks the offered venue
+                    match += 1
+                    match_stat = 1
+                stats[domain][3] = fail_info_penalty[domain]
+            # other domains
+            else:
+                if domain + '_name]' in venue_offered[domain]: # yields false positive, in terms of occurence and correctness
+                    match += 1
+                    match_stat = 1
+
+            stats[domain][0] = match_stat
+            stats[domain][2] = 1
+
+
+        if soft_acc:
+            match = float(match)/len(goal.keys())
+        else:
+            # only count success if all domain has matches
+            if match == len(goal.keys()):
+                match = 1.0
+            else:
+                match = 0.0
+
+        # SUCCESS (whether the requestable info in realDialogue is generated by the system)
+        # if no match, then success is assumed to be 0
+        if match == 1.0:
+            for domain in domains_in_goal:
+                success_stat = 0
+                domain_success = 0
+                if len(real_requestables[domain]) == 0: # if there is no requestable, assume to be succesful. incorrect, cause does not count false positives. 
+                    success += 1
+                    success_stat = 1
+                    stats[domain][1] = success_stat
+                    continue
+                # if values in sentences are super set of requestables
+                for request in set(provided_requestables[domain]):
+                    if request in real_requestables[domain]:
+                        domain_success += 1
+
+                if domain_success >= len(real_requestables[domain]):
+                    success += 1
+                    success_stat = 1
+
+                stats[domain][1] = success_stat 
+            # final eval
+            if soft_acc:
+                success = float(success)/len(real_requestables)
+            else:
+                if success >= len(real_requestables):
+                    success = 1
+                else:
+                    success = 0
+
+        success -= sum([stats[domain][3] for domain in ['restaurant', 'hotel', 'attraction', 'train']])
+
+
+        # rint requests, 'DIFF', requests_real, 'SUCC', success
+        return success, match, stats
diff --git a/latent_dialog/main.py b/latent_dialog/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7debd8f9ff08db390eca0df618c696c2409fbf77
--- /dev/null
+++ b/latent_dialog/main.py
@@ -0,0 +1,1485 @@
+import os
+import sys
+import numpy as np
+import torch as th
+from torch.utils.tensorboard import SummaryWriter
+from torch import nn
+from tqdm import trange
+from collections import defaultdict, Counter
+from latent_dialog.enc2dec.base_modules import summary
+from latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN
+from datetime import datetime
+from latent_dialog.utils import get_detokenize, LONG, FLOAT
+from latent_dialog.corpora import EOS, PAD
+from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE
+from latent_dialog import evaluators
+from latent_dialog.record import record, record_task, UniquenessSentMetric, UniquenessWordMetric
+from sklearn.metrics import mean_squared_error as mse
+import logging
+import pdb
+import json
+from scipy.stats import entropy
+import time
+
+logger = logging.getLogger()
+
+class LossManager(object):
+    def __init__(self):
+        self.losses = defaultdict(list)
+        self.backward_losses = []
+
+    def add_loss(self, loss):
+        for key, val in loss.items():
+            # print('key = %s\nval = %s' % (key, val))
+            if val is not None and type(val) is not bool:
+                self.losses[key].append(val.item())
+
+    def pprint(self, name, window=None, prefix=None):
+        str_losses = []
+        for key, loss in self.losses.items():
+            if loss is None:
+                continue
+            aver_loss = np.average(loss) if window is None else np.average(loss[-window:])
+            if 'nll' in key:
+                str_losses.append('{} PPL {:.3f}'.format(key, np.exp(aver_loss)))
+            else:
+                str_losses.append('{} {:.3f}'.format(key, aver_loss))
+
+
+        if prefix:
+            return '{}: {} {}'.format(prefix, name, ' '.join(str_losses))
+        else:
+            return '{} {}'.format(name, ' '.join(str_losses))
+
+    def clear(self):
+        self.losses = defaultdict(list)
+        self.backward_losses = []
+
+    def add_backward_loss(self, loss):
+        self.backward_losses.append(loss.item())
+
+    def avg_loss(self):
+        return np.mean(self.backward_losses)
+
+class OfflineTaskReinforce(object):
+    def __init__(self, agent, corpus, sv_config, sys_model, rl_config, generate_func):
+        self.agent = agent
+        self.corpus = corpus
+        self.sv_config = sv_config
+        self.sys_model = sys_model
+        self.rl_config = rl_config
+        # training func for supervised learning
+        self.train_func = task_train_single_batch
+        self.record_func = record_task
+        self.validate_func = validate
+
+        # prepare data loader
+        train_dial, val_dial, test_dial = self.corpus.get_corpus()
+        self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config)
+        self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config)
+
+        # create log files
+        if self.rl_config.record_freq > 0:
+            self.learning_exp_file = open(os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w')
+            self.ppl_val_file = open(os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w')
+            self.rl_val_file = open(os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w')
+            self.ppl_test_file = open(os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w')
+            self.rl_test_file = open(os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w')
+        # evaluation
+        if "fail_info_penalty" in rl_config and rl_config.fail_info_penalty:
+            self.evaluator = evaluators.MultiWozEvaluatorwPenalty('SYS_WOZ', rl_config)
+        else:
+            self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', rl_config)
+        self.generate_func = generate_func
+
+    def run(self):
+        n = 0
+        best_valid_loss = np.inf
+        best_rewards = -1 * np.inf
+
+        # BEFORE RUN, RECORD INITIAL PERFORMANCE
+        test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True)
+        t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config,
+                                                              self.evaluator, None, verbose=False)#, temperature=self.rl_config.temperature)
+
+        self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1))
+        self.ppl_test_file.flush()
+        self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match))
+        self.rl_test_file.flush()
+
+        self.sys_model.train()
+        try:
+            for epoch_id in range(self.rl_config.nepoch):
+                self.train_data.epoch_init(self.sv_config, shuffle=True, verbose=epoch_id == 0, fix_batch=True)
+                while True:
+                    if n % self.rl_config.episode_repeat == 0:
+                        batch = self.train_data.next_batch()
+
+                    if batch is None:
+                        break
+
+                    n += 1
+                    if n % 50 == 0:
+                        print("Reinforcement Learning {}/{} episode".format(n, self.train_data.num_batch*self.rl_config.nepoch))
+                        self.learning_exp_file.write(
+                            '{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-50:])))
+                        self.learning_exp_file.flush()
+
+                    # reinforcement learning
+                    # make sure it's the same dialo
+                    assert len(set(batch['keys'])) == 1
+                    task_report, success, match = self.agent.run(batch, self.evaluator, max_words=self.rl_config.max_words, temp=self.rl_config.temperature)
+                    reward = float(success) # + float(match)
+                    # print(reward)
+                    stats = {'Match': match, 'Success': success}
+                    self.agent.update(reward, stats)
+
+                    # supervised learning
+                    if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0:
+                        self.train_func(self.sys_model, self.sl_train_data, self.sv_config)
+
+                    # record model performance in terms of several evaluation metrics
+                    if self.rl_config.record_freq > 0 and n % self.rl_config.record_freq == 0:
+                         self.agent.print_dialog(self.agent.dlg_history, reward, stats)
+                         print('-'*15, 'Recording start', '-'*15)
+                         # save train reward
+                         self.learning_exp_file.write('{}\t{}\n'.format(n, np.mean(self.agent.all_rewards[-self.rl_config.record_freq:])))
+                         self.learning_exp_file.flush()
+
+                         # PPL & reward on validation
+                         valid_loss = self.validate_func(self.sys_model, self.val_data, self.sv_config, use_py=True)
+                         v_success, v_match, v_bleu, v_f1 = self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, None, verbose=False) #, temperature=self.rl_config.temperature)
+                         self.ppl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(valid_loss), v_bleu, v_f1))
+                         self.ppl_val_file.flush()
+                         self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(n, (v_success + v_match), v_success, v_match))
+                         self.rl_val_file.flush()
+
+                         test_loss = self.validate_func(self.sys_model, self.test_data, self.sv_config, use_py=True)
+                         t_success, t_match, t_bleu, t_f1 = self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, None, verbose=False)#, temperature=self.rl_config.temperature)
+                         self.ppl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, np.exp(test_loss), t_bleu, t_f1))
+                         self.ppl_test_file.flush()
+                         self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match))
+                         self.rl_test_file.flush()
+
+                         # save model is needed
+                         if v_success+v_match > best_rewards:
+                             print("Model saved with success {} match {}".format(v_success, v_match))
+                             th.save(self.sys_model.state_dict(), self.rl_config.reward_best_model_path)
+                             best_rewards = v_success+v_match
+
+
+                         self.sys_model.train()
+                         print('-'*15, 'Recording end', '-'*15)
+        except KeyboardInterrupt:
+            print("RL training stopped from keyboard")
+
+        print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path))
+        self.sv_config.batch_size = 32
+        self.sys_model.load_state_dict(th.load(self.rl_config.reward_best_model_path))
+
+        validate(self.sys_model, self.val_data, self.sv_config, use_py=True)
+        validate(self.sys_model, self.test_data, self.sv_config, use_py=True)
+
+        with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f:
+            self.generate_func(self.sys_model, self.val_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f)#, temperature=self.rl_config.temperature)
+
+        with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f:
+            self.generate_func(self.sys_model, self.test_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f)#, temperature=self.rl_config.temperature)
+
+class OfflinePLAS(object):
+    def __init__(self, agent, corpus, sv_config, rl_config, generate_func, name="", vae_gen=None):
+        self.tb = SummaryWriter(comment=name)
+        logger.info(f"Run will be logged in tensorboard as {name}")
+
+        self.agent = agent
+        self.corpus = corpus
+        self.sv_config = sv_config
+        # self.sys_model = agent.model
+        self.rl_config = rl_config
+        # training func for supervised learning
+        self.train_func = task_train_single_batch
+        # self.record_func = record_task
+        self.validate_func = validate_offlinerl #policy evaluation
+        self.all_rewards = []
+        self.is_gauss = agent.is_gauss
+        self.train_vae = rl_config.train_vae
+
+
+        # prepare data loader
+        train_dial, val_dial, test_dial = self.corpus.get_corpus()
+        self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config)
+        self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config)
+        self.ae_val_data = BeliefDbDataLoadersAE('Val', val_dial, self.sv_config)
+
+        # create log files
+        if self.rl_config.record_freq > 0:
+            self.learning_exp_file = open(os.path.join(self.rl_config.record_path, 'offline-learning.tsv'), 'w')
+            self.ppl_val_file = open(os.path.join(self.rl_config.record_path, 'val-ppl.tsv'), 'w')
+            self.rl_val_file = open(os.path.join(self.rl_config.record_path, 'val-rl.tsv'), 'w')
+            self.ppl_test_file = open(os.path.join(self.rl_config.record_path, 'test-ppl.tsv'), 'w')
+            self.rl_test_file = open(os.path.join(self.rl_config.record_path, 'test-rl.tsv'), 'w')
+        # evaluation
+        self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', rl_config)
+        self.generate_func = generate_func
+        self.vae_generate_func = vae_gen
+
+        if "validate_with_critic" in rl_config:
+            self.validate_with_critic = rl_config.validate_with_critic
+        else:
+            self.validate_with_critic = False
+
+    def extract(self, data_feed, replay_buffer):
+        """
+        extract experiences from corpus
+        """ 
+        data_feed.epoch_init(self.sv_config, shuffle=False, verbose=False, fix_batch=True)
+        # self.sys_model.eval()
+        self.keys = []
+        delex_dialogues = self.evaluator.delex_dialogues
+        belief_state = th.load("/gpfs/project/lubis/LAVA_code/LAVA_dev/data/MultiWOZ_2.1_SetSUMBT_EnD2/setsumbt_belief_states.bin")
+
+        corpus_successes = 0
+        corpus_matches = 0
+        total = 0
+        
+        all_actions = []
+        pred_strs = []
+        true_strs = []
+        
+        n = 0 
+
+        while True:
+            batch = data_feed.next_batch()
+
+            n += 1
+            
+            if n % 500 == 0:
+                print("Processing batch {}/{}".format(n, data_feed.num_batch))
+
+            if batch is None:
+                break
+            
+            batch_size = len(batch['keys'])
+            key = batch['keys'][0]
+            self.keys.append(key)
+            assert len(set(batch['keys'])) == 1
+
+            rewards = np.zeros(batch_size)
+
+            states = []
+            actions = []
+            next_states = []
+            next_actions = []
+            dones = []
+
+    
+            # actions, pred_str, true_str= self._get_actions(batch)
+            # actions, _ = self._sample_actions(batch)
+            # get states, next state, action, and done from data
+            for turn_id, turn in enumerate(batch['contexts']):
+                # belief_state[key][slot_name][turn_id * 2] (the .bin file contains double the amount of turn, but turns 0 and 1 are identical, so are 2 and 3, 4 and 5, etc.)
+                state = {}
+                state['contexts'] = turn
+                state['bs'] = batch['bs'][turn_id]
+                state['db'] = batch['db'][turn_id]
+                state['context_lens'] = batch['context_lens'][turn_id]
+                state['keys'] = batch['keys'][turn_id]
+                state['goals'] = np.concatenate([batch['goals_list'][d][turn_id] for d in range(7)])
+
+                action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id].tolist(), do_pad=True))
+
+                try:
+                    next_state = {}
+                    next_state['contexts'] = batch['contexts'][turn_id + 1]
+                    next_state['bs'] = batch['bs'][turn_id + 1]
+                    next_state['db'] = batch['db'][turn_id + 1]
+                    next_state['context_lens'] = batch['context_lens'][turn_id + 1]
+                    next_state['keys'] = batch['keys'][turn_id + 1]
+                    next_state['goals'] = np.concatenate([batch['goals_list'][d][turn_id + 1] for d in range(7)])
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id + 1].tolist(), do_pad=True))
+                    # next_state['outputs'] = batch['outputs'][turn_id + 1]
+                    done = 0
+                except:
+                    next_state = {}
+                    next_state['contexts'] = np.zeros(batch['contexts'][turn_id].shape)
+                    next_state['bs'] = [0] * self.corpus.bs_size
+                    next_state['db'] = [0] * self.corpus.db_size
+                    next_state['context_lens'] = batch['context_lens'][turn_id]
+                    next_state['keys'] = batch['keys'][turn_id]
+                    next_state['goals'] = [0] * self.corpus.goal_size
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, [0], do_pad=True))
+                    # next_state['outputs'] = np.zeros(self.sv_config.max_utt_len)
+                    done = 1
+                
+                actions.append(action)
+                states.append(state)
+                next_states.append(next_state)
+                next_actions.append(next_action)
+                dones.append(done)
+
+            # compute reward of dialogue from corpus here
+            data = delex_dialogues[key]
+            goal, success, match, requestables, _ = self.evaluator._evaluateRealDialogue(data, key, soft=False)
+            corpus_successes += success
+            corpus_matches += match
+            total += 1
+
+            rewards[-1] = float(success) 
+            g = self.agent.cvae.np2var(np.array([rewards[-1]]), FLOAT).view(1, 1)
+
+            # self.all_rewards.append(reward)
+            # r = (rewards - np.mean(self.all_rewards)) / max(1e-4, np.std(self.all_rewards))
+            
+            # compute accumulated discounted reward
+            returns = []
+            for _ in rewards:
+                returns.insert(0, float(g))
+                g = g * self.rl_config.gamma
+
+
+            # add tuples to replay buffer
+            # assert(len(states) == len(rewards) and len(states) == len(actions) and len(states) == len(dones) and len(states)== len(next_states), len(next_actions)==len(next_states))
+            replay_buffer.add(states, actions, rewards, next_states, next_actions, dones, returns)
+                # for i in range(len(states)):
+                    # replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], next_action[i], dones[i], returns[i])
+
+    def _infoGain(self, P, Q):
+        M = [p + q for p, q in zip(P, Q)]
+        return 0.5 * (entropy(P, M, base=2) + entropy(Q, M, base=2))
+    
+    def run(self):
+        # get exps, step, learn, act, validate
+        n = 0
+        best_valid_loss = np.inf
+        best_rewards = -1 * np.inf
+        vae_epoch_id = 0
+
+        # BEFORE RUN, RECORD INITIAL PERFORMANCE
+        # print("Recoding initial performance")
+        self.agent.actor.eval()
+        self.agent.critic.eval()
+        # test_rewards, test_match = self.validate_func(self.agent, self.evaluator, self.test_data, self.rl_config, mode="test", use_py=True)
+        # t_success, t_match, t_bleu, t_f1 = self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, None, verbose=False)
+
+        # self.ppl_test_file.write('{}\t{}\t{}\t{}\t{}\n'.format(n, np.mean(test_rewards), np.mean(test_match), t_bleu, t_f1))
+        # self.ppl_test_file.flush()
+        # self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(n, (t_success + t_match), t_success, t_match))
+        # self.rl_test_file.flush()
+
+        try:
+            for epoch_id in trange(self.rl_config.nepoch, desc="PLAS_epoch"):
+                for n in trange(self.rl_config.nepisode, desc="PLAS_batch"):
+                    # VAE training
+                    if self.train_vae and n % self.rl_config.train_vae_freq == 0:
+                        if vae_epoch_id == 0:
+                            train_vae_nepisode = self.rl_config.train_vae_nepisode_init
+                            offset_idx = 0
+                        else:
+                            train_vae_nepisode = self.rl_config.train_vae_nepisode
+                            offset_idx =  self.rl_config.train_vae_nepisode_init + self.rl_config.train_vae_nepisode * (vae_epoch_id - 1)
+
+                        logger.info("=== VAE Training ==")
+
+                        logger.info("validating starting performance")
+                        vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True)
+                        # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False)
+                        self.tb.add_scalar("train_vae_success", vae_success, offset_idx + 1)
+                        self.tb.add_scalar("train_vae__match", vae_match, offset_idx + 1)
+                        self.tb.add_scalar("train_vae_bleu", vae_bleu, offset_idx + 1)
+                        self.tb.add_scalar("train_vae_f1", vae_f1, offset_idx + 1)
+
+                        self.agent.actor.eval()
+                        self.agent.critic.eval()
+                        self.agent.cvae.train()
+                        for i in trange(train_vae_nepisode, desc="VAE training"):
+                            vae_loss = self.agent.train_vae_model(i)
+                            if i % 2500 == 0:
+                                # logger.info(f"vae training batch {i}/{train_vae_nepisode}, vae_loss:{vae_loss}")
+
+                                # logger.info("validating ending performance")
+                                vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True)
+                                # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False)
+                                self.tb.add_scalar("train_vae_success", vae_success, i + offset_idx)
+                                self.tb.add_scalar("train_vae__match", vae_match, i + offset_idx)
+                                self.tb.add_scalar("train_vae_bleu", vae_bleu, i + offset_idx)
+                                self.tb.add_scalar("train_vae_f1", vae_f1, i + offset_idx)
+
+                            self.tb.add_scalar("vae_loss", vae_loss, (i+1) + (self.rl_config.nepisode/self.rl_config.train_vae_freq) * (epoch_id))
+
+                        vae_epoch_id += 1
+
+                    # reinforcement learning
+                    self.agent.actor.train()
+                    self.agent.critic.train()
+                    self.agent.cvae.eval()
+                    # tic = time.perf_counter()
+                    plas_report, losses = self.agent.train(verbose=n%10==0, 
+                            max_words=self.rl_config.max_words, 
+                            temp=self.rl_config.temperature, 
+                            debug=n%500==0, 
+                            n=(n+1) + (self.rl_config.nepisode) * (epoch_id))
+                    # toc = time.perf_counter() 
+                    # print(f"One batch of forward pass plas training in {toc - tic:0.4f} seconds")
+
+
+                    if n % 20 == 0:
+                        for k, v in losses.items():
+                            self.tb.add_scalar(k, v, (self.rl_config.nepisode) * (epoch_id) + (n+1))
+
+                    if n % 100 == 0 and n > 0:
+                        logger.info("PLAS {}/{} episode, {}/{} epoch".format(n, self.rl_config.nepisode, epoch_id, self.rl_config.nepoch))
+                        self.learning_exp_file.write(plas_report + "\n")
+                        self.learning_exp_file.flush()
+
+                    # supervised learning
+                    if self.rl_config.sv_train_freq > 0 and n % self.rl_config.sv_train_freq == 0:
+                        self.train_func(self.agent.actor, self.sl_train_data, self.sv_config)
+
+                   
+                    if n > 1 and n % self.rl_config.record_freq == 0:
+                        self.agent.actor.eval()
+                        self.agent.critic.eval()
+
+                        print('-'*15, 'Recording start', '-'*15)
+                        # save train reward
+
+                        # PPL & reward on validation
+                        
+                        print("Running validation on train set")
+                        tr_success, tr_match, tr_bleu, tr_f1, tr_Q = self.generate_func(self.agent.actor, self.agent.cvae, self.train_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic)
+                        self.tb.add_scalar("train_Q", tr_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+
+
+                        print("Running validation on valid set")
+                        # valid_rewards, valid_match = self.validate_func(self.agent, self.evaluator, self.val_data, self.rl_config, use_py=True)
+                        # tic = time.perf_counter()
+                        v_success, v_match, v_bleu, v_f1, v_Q = self.generate_func(self.agent.actor,  self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic)
+                        # toc = time.perf_counter() 
+                        # print(f"One batch of validation pass of plas training in {toc - tic:0.4f} seconds")
+                        val_critic_error = v_Q - v_success
+                        # self.ppl_val_file.write('{}\t{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, np.mean(valid_rewards), np.mean(valid_match), v_bleu, v_f1))
+                        # self.ppl_val_file.flush()
+                        # print(v_success, v_match, v_bleu, v_f1, v_Q, val_critic_error)
+                        self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, (v_success + v_match), v_success, v_match))
+                        self.rl_val_file.flush()
+                        self.tb.add_scalar("validation_success", v_success, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_match", v_match, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_bleu", v_bleu, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_f1", v_f1, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_critic_mse", val_critic_error, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+
+                        print("Running validation on test set")
+                        # test_rewards, test_match = self.validate_func(self.agent, self.evaluator, self.test_data, self.rl_config, mode="test", use_py=True)
+                        t_success, t_match, t_bleu, t_f1, t_Q = self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, None, verbose=False, critic=self.agent.critic)
+                        test_critic_error = t_Q - t_success
+                        # self.ppl_test_file.write('{}\t{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, np.mean(test_rewards), np.mean(test_match), t_bleu, t_f1))
+                        # self.ppl_test_file.flush()
+                        # print(t_success, t_match, t_bleu, t_f1, t_Q, test_critic_error)
+                        self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.rl_config.nepisode + n, (t_success + t_match), t_success, t_match))
+                        self.rl_test_file.flush()
+                        self.tb.add_scalar("test_success", t_success, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_match", t_match, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_bleu", t_bleu, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_f1", t_f1, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_critic_mse", test_critic_error, (n+1) + (self.rl_config.nepisode) * (epoch_id))
+
+
+                        # save model is needed
+                        if self.validate_with_critic:
+                            if v_Q > best_rewards:
+                                print("Model saved with estimated value of {}".format(np.mean(v_Q)))
+                                print("Corpus success {} and match {}".format(np.mean(v_success), np.mean(v_match)))
+                                th.save(self.agent.cvae.state_dict(), self.rl_config.reward_best_model_path)
+                                th.save(self.agent.actor.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".actor"))
+                                th.save(self.agent.critic.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".critic"))
+                                best_rewards = v_Q
+
+                        else:
+                            if v_success+v_match > best_rewards:
+                            # if np.mean(v_success) > best_rewards:
+                                print("Model saved with success {} and match {}".format(np.mean(v_success), np.mean(v_match)))
+                                th.save(self.agent.cvae.state_dict(), self.rl_config.reward_best_model_path)
+                                th.save(self.agent.actor.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".actor"))
+                                th.save(self.agent.critic.state_dict(), self.rl_config.reward_best_model_path.replace(".model", ".critic"))
+                                # best_rewards = np.mean(v_success)
+                                best_rewards = v_success + v_match
+                        
+                        # for name, weight in self.agent.named_parameters():
+                            # tb.add_histogram(name, weight, n + 1 )
+                            # tb.add_histogram(f'{name}.grad', weight.grad, n + 1)
+
+
+                        # self.sys_model.train()
+                        print('-'*15, 'Recording end', '-'*15)
+
+        except KeyboardInterrupt:
+            print("RL training stopped from keyboard")
+        self.tb.close()
+
+        print("$$$ Load {}-model".format(self.rl_config.reward_best_model_path))
+        self.sv_config.batch_size = 32
+        self.agent.cvae.load_state_dict(th.load(self.rl_config.reward_best_model_path))
+        self.agent.actor.load_state_dict(th.load(self.rl_config.reward_best_model_path.replace(".model", ".actor")))
+        self.agent.critic.load_state_dict(th.load(self.rl_config.reward_best_model_path.replace(".model", ".critic")))
+
+        # validate(self.sys_model, self.val_data, self.sv_config, use_py=True)
+        # validate(self.sys_model, self.test_data, self.sv_config, use_py=True)
+
+        with open(os.path.join(self.rl_config.record_path, 'train_file.txt'), 'w') as f:
+            self.generate_func(self.agent.actor, self.agent.cvae, self.train_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic)
+
+        with open(os.path.join(self.rl_config.record_path, 'valid_file.txt'), 'w') as f:
+            self.generate_func(self.agent.actor, self.agent.cvae, self.val_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic)
+
+        with open(os.path.join(self.rl_config.record_path, 'test_file.txt'), 'w') as f:
+            self.generate_func(self.agent.actor, self.agent.cvae, self.test_data, self.sv_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic)
+
+class OfflineCritic(object):
+    def __init__(self, agent, corpus, sv_config, critic_config, generate_func, name="", vae_gen=None, forward_only=False):
+        self.tb = SummaryWriter(comment=name)
+        logger.info(f"Run will be logged in tensorboard as {name}")
+
+        self.agent = agent
+        self.corpus = corpus
+        self.sv_config = sv_config
+        # self.sys_model = agent.model
+        self.critic_config = critic_config
+        # training func for supervised learning
+        self.train_func = task_train_single_batch
+        # self.record_func = record_task
+        self.validate_func = validate_offlinerl #policy evaluation
+        self.all_rewards = []
+        self.is_gauss = agent.is_gauss
+        self.train_vae = critic_config.train_vae
+
+        # prepare data loader
+        train_dial, val_dial, test_dial = self.corpus.get_corpus()
+        self.train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.sl_train_data = BeliefDbDataLoaders('Train', train_dial, self.sv_config)
+        self.val_data = BeliefDbDataLoaders('Val', val_dial, self.sv_config)
+        self.test_data = BeliefDbDataLoaders('Test', test_dial, self.sv_config)
+        self.ae_val_data = BeliefDbDataLoadersAE('Val', val_dial, self.sv_config)
+
+        # create log files
+        log_mode = "r" if forward_only else "w"
+        if self.critic_config.record_freq > 0:
+            # self.learning_exp_file = open(os.path.join(self.critic_config.record_path, 'offline-learning.tsv'), 'w')
+            # self.ppl_val_file = open(os.path.join(self.critic_config.record_path, 'val-ppl.tsv'), 'w')
+            self.rl_val_file = open(os.path.join(self.critic_config.record_path, 'val-critic.tsv'), log_mode)
+            # self.ppl_test_file = open(os.path.join(self.critic_config.record_path, 'test-ppl.tsv'), 'w')
+            self.rl_test_file = open(os.path.join(self.critic_config.record_path, 'test-critic.tsv'),  log_mode)
+        # evaluation
+        self.evaluator = evaluators.MultiWozEvaluator('SYS_WOZ', critic_config)
+        self.generate_func = generate_func
+        self.vae_generate_func = vae_gen
+
+    def extract(self, data_feed, replay_buffer):
+        """
+        extract experiences from corpus
+        """ 
+        data_feed.epoch_init(self.sv_config, shuffle=False, verbose=False, fix_batch=True)
+        # self.sys_model.eval()
+        self.keys = []
+        delex_dialogues = self.evaluator.delex_dialogues
+        belief_state = th.load("/gpfs/project/lubis/LAVA_code/LAVA_dev/data/MultiWOZ_2.1_SetSUMBT_EnD2/setsumbt_belief_states.bin")
+
+        corpus_successes = 0
+        corpus_matches = 0
+        total = 0
+        
+        all_actions = []
+        pred_strs = []
+        true_strs = []
+        
+        n = 0 
+
+        while True:
+            batch = data_feed.next_batch()
+
+            n += 1
+            
+            if n % 500 == 0:
+                print("Processing batch {}/{}".format(n, data_feed.num_batch))
+
+            if batch is None:
+                break
+            
+            batch_size = len(batch['keys'])
+            key = batch['keys'][0]
+            self.keys.append(key)
+            assert len(set(batch['keys'])) == 1
+
+            rewards = np.zeros(batch_size)
+
+            states = []
+            actions = []
+            next_states = []
+            next_actions = []
+            dones = []
+
+    
+            # actions, pred_str, true_str= self._get_actions(batch)
+            # actions, _ = self._sample_actions(batch)
+            # get states, next state, action, and done from data
+            for turn_id, turn in enumerate(batch['contexts']):
+                # for each turn, compute info gain wrt previous turn (if exist). 
+                # belief_state[key][slot_name][turn_id * 2] (the .bin file contains double the amount of turn, but turns 0 and 1 are identical, so are 2 and 3, 4 and 5, etc.)
+                state = {}
+                state['contexts'] = turn
+                state['bs'] = batch['bs'][turn_id]
+                state['db'] = batch['db'][turn_id]
+                state['context_lens'] = batch['context_lens'][turn_id]
+                state['keys'] = batch['keys'][turn_id]
+                state['goals'] = np.concatenate([batch['goals_list'][d][turn_id] for d in range(7)])
+
+                action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id].tolist(), do_pad=True))
+
+                try:
+                    next_state = {}
+                    next_state['contexts'] = batch['contexts'][turn_id + 1]
+                    next_state['bs'] = batch['bs'][turn_id + 1]
+                    next_state['db'] = batch['db'][turn_id + 1]
+                    next_state['context_lens'] = batch['context_lens'][turn_id + 1]
+                    next_state['keys'] = batch['keys'][turn_id + 1]
+                    next_state['goals'] = np.concatenate([batch['goals_list'][d][turn_id + 1] for d in range(7)])
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, batch['outputs'][turn_id + 1].tolist(), do_pad=True))
+                    # next_state['outputs'] = batch['outputs'][turn_id + 1]
+                    done = 0
+                except:
+                    next_state = {}
+                    next_state['contexts'] = np.zeros(batch['contexts'][turn_id].shape)
+                    next_state['bs'] = [0] * self.corpus.bs_size
+                    next_state['db'] = [0] * self.corpus.db_size
+                    next_state['context_lens'] = batch['context_lens'][turn_id]
+                    next_state['keys'] = batch['keys'][turn_id]
+                    next_state['goals'] = [0] * self.corpus.goal_size
+
+                    next_action = np.asarray(self.corpus.pad_to(self.sv_config.max_utt_len, [0], do_pad=True))
+                    # next_state['outputs'] = np.zeros(self.sv_config.max_utt_len)
+                    done = 1
+                
+                actions.append(action)
+                states.append(state)
+                next_states.append(next_state)
+                next_actions.append(next_action)
+                dones.append(done)
+
+
+            # compute reward of dialogue from corpus here
+            data = delex_dialogues[key]
+            goal, success, match, requestables, _ = self.evaluator._evaluateRealDialogue(data, key, soft=False)
+            corpus_successes += success
+            corpus_matches += match
+            total += 1
+
+            rewards[-1] = float(success) 
+            g = self.agent.cvae.np2var(np.array([rewards[-1]]), FLOAT).view(1, 1)
+
+            # compute accumulated discounted reward
+            returns = []
+            for _ in rewards:
+                returns.insert(0, float(g))
+                g = g * self.critic_config.gamma
+
+
+            # add tuples to replay buffer
+            # assert(len(states) == len(rewards) and len(states) == len(actions) and len(states) == len(dones) and len(states)== len(next_states), len(next_actions)==len(next_states))
+            replay_buffer.add(states, actions, rewards, next_states, next_actions, dones, returns)
+                # for i in range(len(states)):
+                    # replay_buffer.add(states[i], actions[i], rewards[i], next_states[i], next_action[i], dones[i], returns[i])
+
+    def _infoGain(self, P, Q):
+        M = [p + q for p, q in zip(P, Q)]
+        return 0.5 * (entropy(P, M, base=2) + entropy(Q, M, base=2))
+    
+    def run(self):
+        # get exps, step, learn, act, validate
+        n = 0
+        best_valid_loss = np.inf
+        vae_epoch_id = 0
+
+        try:
+            for epoch_id in trange(self.critic_config.nepoch, desc="Critic_epoch"):
+                for n in trange(self.critic_config.nepisode, desc="Critic_batch"):
+                    # VAE training
+                    if self.train_vae and n % self.critic_config.train_vae_freq == 0:
+                        if vae_epoch_id == 0:
+                            train_vae_nepisode = self.critic_config.train_vae_nepisode_init
+                            offset_idx = 0
+                        else:
+                            train_vae_nepisode = self.critic_config.train_vae_nepisode
+                            offset_idx =  self.critic_config.train_vae_nepisode_init + self.critic_config.train_vae_nepisode * (vae_epoch_id - 1)
+
+                        logger.info("=== VAE Training ==")
+
+                        logger.info("validating starting performance")
+                        vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True)
+                        # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False)
+                        self.tb.add_scalar("train_vae_success", vae_success, offset_idx + 1)
+                        self.tb.add_scalar("train_vae_match", vae_match, offset_idx + 1)
+                        self.tb.add_scalar("train_vae_bleu", vae_bleu, offset_idx + 1)
+                        self.tb.add_scalar("train_vae_f1", vae_f1, offset_idx + 1)
+
+                        self.agent.critic.eval()
+                        self.agent.cvae.train()
+                        for i in trange(train_vae_nepisode, desc="VAE training"):
+                            vae_loss = self.agent.train_vae_model(i)
+                            if i % 2500 == 0:
+                                vae_success, vae_match, vae_bleu, vae_f1 = self.vae_generate_func(self.agent.cvae, self.ae_val_data, self.sv_config, self.evaluator, None, verbose=False, aux_mt=True)
+                                # _success, _match, _bleu, _f1 = self.vae_generate_func(self.agent.cvae, self.val_data, self.sv_config, self.evaluator, None, verbose=False)
+                                self.tb.add_scalar("train_vae_success", vae_success, i + offset_idx)
+                                self.tb.add_scalar("train_vae_match", vae_match, i + offset_idx)
+                                self.tb.add_scalar("train_vae_bleu", vae_bleu, i + offset_idx)
+                                self.tb.add_scalar("train_vae_f1", vae_f1, i + offset_idx)
+
+                            self.tb.add_scalar("vae_loss", vae_loss, (i+1) + (self.critic_config.nepisode/self.critic_config.train_vae_freq) * (epoch_id))
+
+                        vae_epoch_id += 1
+
+                    # critic_training
+                    self.agent.critic.train()
+                    self.agent.cvae.eval()
+                    # tic = time.perf_counter()
+                    plas_report, losses = self.agent.train_critic(verbose=n%10==0, 
+                            max_words=self.critic_config.max_words, 
+                            temp=self.critic_config.temperature, 
+                            debug=n%500==0, 
+                            n=(n+1) + (self.critic_config.nepisode) * (epoch_id))
+                    # toc = time.perf_counter()
+                    # print(f"One batch forward pass in {toc - tic:0.4f} seconds")
+                       
+
+
+                    if n % 20 == 0:
+                        for k, v in losses.items():
+                            self.tb.add_scalar(k, v, (self.critic_config.nepisode) * (epoch_id) + (n+1))
+
+                  
+                    if n % self.critic_config.record_freq == 0:
+                        self.agent.critic.eval()
+
+                        print('-'*15, 'Recording start', '-'*15)
+                        # save train reward
+
+                        # PPL & reward on validation
+                        print("Running validation on valid set")
+                        # tic = time.perf_counter()
+                        if not self.agent.raw_response:
+                            v_success, v_match, v_bleu, v_f1, v_Q = self.generate_func(self.agent.cvae, self.val_data, self.sv_config,self.critic_config, self.evaluator, None, verbose=False, critic=self.agent.critic, actor = self.agent.actor)
+                            val_critic_error = v_Q - v_success
+                            # print(v_success, v_match, v_bleu, v_f1, v_Q, val_critic_error)
+                            self.rl_val_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_success, v_match, v_Q))
+                            self.tb.add_scalar("validation_critic_mse", val_critic_error, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        else:
+                            v_Q = self.generate_func(self.val_data, self.agent, None)
+                            self.rl_val_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_Q))
+                        # toc = time.perf_counter()
+                        # print(f"One validation pass in {toc - tic:0.4f} seconds")
+                        self.rl_val_file.flush()
+
+                        # self.tb.add_scalar("validation_success", v_success, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("validation_match", v_match, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("validation_bleu", v_bleu, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("validation_f1", v_f1, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+
+                        print("Running validation on test set")
+                        if not self.agent.raw_response:
+                            t_success, t_match, t_bleu, t_f1, t_Q = self.generate_func(self.agent.cvae, self.test_data, self.sv_config, self.critic_config, self.evaluator, None, verbose=False, critic=self.agent.critic, actor = self.agent.actor)
+                            test_critic_error = t_Q - t_success
+                            # print(t_success, t_match, t_bleu, t_f1, t_Q, test_critic_error)
+                            self.rl_test_file.write('{}\t{}\t{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_success, t_match, t_Q))
+                            self.tb.add_scalar("test_critic_mse", test_critic_error, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        else:
+                            t_Q = self.generate_func(self.test_data, self.agent, None)
+                            self.rl_test_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_Q))
+
+                        self.rl_test_file.flush()
+                        # self.tb.add_scalar("test_success", t_success, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("test_match", t_match, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("test_bleu", t_bleu, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        # self.tb.add_scalar("test_f1", t_f1, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+                        self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+
+                        # always save critic as we don't have a good validation metric
+                        print("Critic saved after {} batches".format(epoch_id * self.critic_config.nepisode + n))
+                        th.save(self.agent.critic.state_dict(), self.critic_config.critic_model_path)
+                        # best_valid_loss = test_critic_error
+                     
+
+                        # for name, weight in self.agent.named_parameters():
+                            # tb.add_histogram(name, weight, n + 1 )
+                            # tb.add_histogram(f'{name}.grad', weight.grad, n + 1)
+
+                        self.agent.critic.train()
+                        print('-'*15, 'Recording end', '-'*15)
+
+        except KeyboardInterrupt:
+            print("Critic training stopped from keyboard")
+        self.tb.close()
+
+        print("$$$ Load {}-model".format(self.critic_config.critic_model_path))
+        self.sv_config.batch_size = 32
+        self.agent.critic.load_state_dict(th.load(self.critic_config.critic_model_path))
+
+        if not self.agent.raw_response:
+
+            with open(os.path.join(self.critic_config.record_path, 'train_file.txt'), 'w') as f:
+                self.generate_func(self.agent.cvae, self.train_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor = self.agent.actor)
+
+            with open(os.path.join(self.critic_config.record_path, 'valid_file.txt'), 'w') as f:
+                self.generate_func(self.agent.cvae, self.val_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor = self.agent.actor)
+
+            with open(os.path.join(self.critic_config.record_path, 'test_file.txt'), 'w') as f:
+                self.generate_func(self.agent.cvae, self.test_data, self.sv_config, self.critic_config, self.evaluator, num_batch=None, dest_f=f, critic=self.agent.critic, actor=self.agent.actor)
+    
+    def run_behavior_policy(self):
+        # get exps, step, learn, act, validate
+        n = 0
+        best_valid_loss = np.inf
+        vae_epoch_id = 0
+
+        try:
+            for epoch_id in trange(self.critic_config.nepoch, desc="Critic_epoch"):
+                for n in trange(self.critic_config.nepisode, desc="Critic_batch"):
+                    # critic_training
+                    self.agent.critic.train()
+                    self.agent.cvae.eval()
+                    # tic = time.perf_counter()
+                    plas_report, losses = self.agent.train_critic(verbose=n%10==0, 
+                            max_words=self.critic_config.max_words, 
+                            temp=self.critic_config.temperature, 
+                            sl=not self.agent.corpus_response,
+                            debug=n%500==0, 
+                            n=(n+1) + (self.critic_config.nepisode) * (epoch_id))
+
+                    if n % 20 == 0:
+                        for k, v in losses.items():
+                            self.tb.add_scalar(k, v, (self.critic_config.nepisode) * (epoch_id) + (n+1))
+
+                    if n % self.critic_config.record_freq == 0:
+                        self.agent.critic.eval()
+
+                        print('-'*15, 'Recording start', '-'*15)
+                        # save train reward
+
+                        # PPL & reward on validation
+                        print("Running validation on valid set")
+                        v_Q = self.generate_func(self.val_data, self.agent, None)
+                        self.rl_val_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, v_Q))
+                        self.rl_val_file.flush()
+                        self.tb.add_scalar("validation_Q", v_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+
+                        print("Running validation on test set")
+                        t_Q = self.generate_func(self.test_data, self.agent, None)
+                        self.rl_test_file.write('{}\t{}\n'.format(epoch_id * self.critic_config.nepisode + n, t_Q))
+                        self.rl_test_file.flush()
+                        self.tb.add_scalar("test_Q", t_Q, (n+1) + (self.critic_config.nepisode) * (epoch_id))
+
+                        # always save critic as we don't have a good validation metric
+                        print("Critic saved after {} batches".format(epoch_id * self.critic_config.nepisode + n))
+                        th.save(self.agent.critic.state_dict(), self.critic_config.critic_model_path)
+                     
+                        self.agent.critic.train()
+                        print('-'*15, 'Recording end', '-'*15)
+
+        except KeyboardInterrupt:
+            print("Critic training stopped from keyboard")
+        self.tb.close()
+
+        print("$$$ Load {}-model".format(self.critic_config.critic_model_path))
+        self.sv_config.batch_size = 32
+        self.agent.critic.load_state_dict(th.load(self.critic_config.critic_model_path))
+
+
+def validate_rl(dialog_eval, ctx_gen, num_episode=200):
+    print("Validate on training goals for {} episode".format(num_episode))
+    reward_list = []
+    agree_list = []
+    sent_metric = UniquenessSentMetric()
+    word_metric = UniquenessWordMetric()
+    for _ in range(num_episode):
+        ctxs = ctx_gen.sample()
+        conv, agree, rewards = dialog_eval.run(ctxs)
+        true_reward = rewards[0] if agree else 0
+        reward_list.append(true_reward)
+        agree_list.append(float(agree if agree is not None else 0.0))
+        for turn in conv:
+            if turn[0] == 'Elder':
+                sent_metric.record(turn[1])
+                word_metric.record(turn[1])
+    results = {'sys_rew': np.average(reward_list),
+               'avg_agree': np.average(agree_list),
+               'sys_sent_unique': sent_metric.value(),
+               'sys_unique': word_metric.value()}
+    return results
+
+def train_single_batch(model, train_data, config):
+    batch_cnt = 0
+    optimizer = model.get_optimizer(config, verbose=False)
+    model.train()
+    
+    # decoding CE
+    train_data.epoch_init(config, shuffle=True, verbose=False)
+    for i in range(16):
+        batch = train_data.next_batch()
+        if batch is None:
+            train_data.epoch_init(config, shuffle=True, verbose=False)
+            batch = train_data.next_batch()
+        optimizer.zero_grad()
+        loss = model(batch, mode=TEACH_FORCE)
+        model.backward(loss, batch_cnt)
+        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
+        optimizer.step()
+
+def task_train_single_batch(model, train_data, config):
+    batch_cnt = 0
+    optimizer = model.get_optimizer(config, verbose=False)
+    model.train()
+
+    # decoding CE
+    train_data.epoch_init(config, shuffle=True, verbose=False)
+    for i in range(16):
+        batch = train_data.next_batch()
+        if batch is None:
+            train_data.epoch_init(config, shuffle=True, verbose=False)
+            batch = train_data.next_batch()
+        optimizer.zero_grad()
+        loss = model(batch, mode=TEACH_FORCE)
+        model.backward(loss, batch_cnt)
+        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
+        optimizer.step()
+
+def train(model, train_data, val_data, test_data, config, evaluator, gen=None):
+    patience = 10
+    valid_loss_threshold = np.inf
+    best_valid_loss = np.inf
+    batch_cnt = 0
+    optimizer = model.get_optimizer(config)
+    done_epoch = 0
+    best_epoch = 0
+    train_loss = LossManager()
+    model.train()
+    logger.info(summary(model, show_weights=False))
+    saved_models = []
+    last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5
+    num_noised_tokens = {}
+
+    logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info('***** Epoch 0/{} *****'.format(config.max_epoch))
+    while True:
+        train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch)
+        while True:
+            batch = train_data.next_batch()
+
+            if batch is None:
+                break
+ 
+            optimizer.zero_grad()
+            loss = model(batch, mode=TEACH_FORCE)
+            model.backward(loss, batch_cnt)
+            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
+            optimizer.step()
+            train_loss.add_loss(loss)
+    
+            if batch_cnt % config.print_step == 0:
+                # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+                logger.info(train_loss.pprint('Train',
+                                        window=config.print_step, 
+                                        prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w)))
+                sys.stdout.flush()
+    
+            if batch_cnt % config.ckpt_step == 0:
+                logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+                logger.info('==== Evaluating Model ====')
+                logger.info(train_loss.pprint('Train'))
+                done_epoch += 1
+                logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch))
+
+                # generation
+                if gen is not None:
+                    gen(model, train_data, config, evaluator, num_batch=config.preview_batch_num)
+
+                # validation
+                valid_loss = validate(model, val_data, config, batch_cnt)
+                _ = validate(model, test_data, config, batch_cnt)
+
+                # update early stopping stats
+                if valid_loss < best_valid_loss:
+                    if valid_loss <= valid_loss_threshold * config.improve_threshold:
+                        patience = max(patience, done_epoch*config.patient_increase)
+                        valid_loss_threshold = valid_loss
+                        logger.info('Update patience to {}'.format(patience))
+    
+                    if config.save_model:
+                        cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
+                        logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time))
+                        th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch)))
+                        best_epoch = done_epoch
+                        saved_models.append(done_epoch)
+                        if len(saved_models) > last_n_model:
+                            remove_model = saved_models[0]
+                            saved_models = saved_models[-last_n_model:]
+                            os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model)))
+    
+                    best_valid_loss = valid_loss
+
+                if train_data.noise_type is not None:
+                    num_noised_tokens[done_epoch-1] = (train_data.num_noised_tokens, 
+                                                    train_data.num_noised_tokens / train_data.num_examples, 
+                                                    train_data.num_examples,
+                                                    train_data.noised_tokens_dist)
+                    train_data.num_noised_tokens = 0
+                    train_data.noised_tokens_dist = defaultdict(int)
+                    train_data.num_examples = 0
+
+    
+                if done_epoch >= config.max_epoch \
+                        or config.early_stop and patience <= done_epoch:
+
+                    file_path = os.path.join(config.saved_path, "tokens_noised_count.json")
+                    with open(file_path, "w+") as json_file:
+                        json.dump(num_noised_tokens, json_file, indent=4)
+
+                    if done_epoch < config.max_epoch:
+                        logger.info('!!!!! Early stop due to run out of patience !!!!!')
+                    logger.info('Best validation loss = %f' % (best_valid_loss, ))
+                    return best_epoch
+    
+                # exit eval model
+                model.train()
+                train_loss.clear()
+                logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch))
+                sys.stdout.flush()
+
+            batch_cnt += 1
+
+def mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=None):
+    patience = 10
+    valid_loss_threshold = np.inf
+    best_valid_loss = np.inf
+    batch_cnt = 0
+    optimizer = model.get_optimizer(config)
+    done_epoch = 0
+    best_epoch = 0
+    train_loss = LossManager()
+    model.train()
+    logger.info(summary(model, show_weights=False))
+    saved_models = []
+    last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5
+
+    logger.info('***** Training Begins at {} *****'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info('***** Epoch 0/{} *****'.format(config.max_epoch))
+    while True:
+        train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch)
+        while True:
+            batch = train_data.next_batch()
+            if batch is None:
+                break
+    
+            optimizer.zero_grad()
+            loss = model(batch, mode=TEACH_FORCE)
+            model.backward(loss, batch_cnt)
+            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
+            optimizer.step()
+            batch_cnt += 1
+            train_loss.add_loss(loss)
+    
+            if batch_cnt % config.print_step == 0:
+                # print('Print step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+                logger.info(train_loss.pprint('Train',
+                                        window=config.print_step, 
+                                        prefix='{}/{}-({:.3f})'.format(batch_cnt%config.ckpt_step, config.ckpt_step, model.kl_w)))
+                sys.stdout.flush()
+
+            
+    
+            if batch_cnt % config.ckpt_step == 0:
+                logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+                logger.info('==== Evaluating Model ====')
+                logger.info(train_loss.pprint('Train'))
+                done_epoch += 1
+                logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch))
+
+                # generation
+                if gen is not None:
+                    gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num)
+                    gen(model, aux_val_data, config, evaluator, num_batch=config.preview_batch_num, aux_mt=True)
+
+                # validation
+                valid_loss = validate(model, val_data, config, batch_cnt)
+                _ = validate(model, test_data, config, batch_cnt)
+
+                # update early stopping stats
+                if valid_loss < best_valid_loss:
+                    if valid_loss <= valid_loss_threshold * config.improve_threshold:
+                        patience = max(patience, done_epoch*config.patient_increase)
+                        valid_loss_threshold = valid_loss
+                        logger.info('Update patience to {}'.format(patience))
+    
+                    if config.save_model:
+                        cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
+                        logger.info('!!Model Saved with loss = {},at {}.'.format(valid_loss, cur_time))
+                        th.save(model.state_dict(), os.path.join(config.saved_path, '{}-model'.format(done_epoch)))
+                        best_epoch = done_epoch
+                        saved_models.append(done_epoch)
+                        if len(saved_models) > last_n_model:
+                            remove_model = saved_models[0]
+                            saved_models = saved_models[-last_n_model:]
+                            os.remove(os.path.join(config.saved_path, "{}-model".format(remove_model)))
+    
+                    best_valid_loss = valid_loss
+
+                if not model.shared_train:
+                    if done_epoch % config.aux_train_freq == 0:
+                        model.train()
+                        train_aux(model, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=gen)
+             
+    
+                if done_epoch >= config.max_epoch \
+                        or config.early_stop and patience <= done_epoch:
+                    if done_epoch < config.max_epoch:
+                        logger.info('!!!!! Early stop due to run out of patience !!!!!')
+                    print('Best validation loss = %f' % (best_valid_loss, ))
+                    return best_epoch
+    
+                
+               
+                # exit eval model
+                model.train()
+                train_loss.clear()
+
+                logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.max_epoch))
+                sys.stdout.flush()
+                
+def train_aux(model, train_data, val_data, test_data, config, evaluator, gen=None):
+    patience = 10
+    valid_loss_threshold = np.inf
+    best_valid_loss = np.inf
+    batch_cnt = 0
+    optimizer = model.get_optimizer(config)
+    done_epoch = 0
+    best_epoch = 0
+    train_loss = LossManager()
+    model.train()
+    # logger.info(summary(model, show_weights=False))
+    saved_models = []
+    last_n_model = config.last_n_model if hasattr(config, 'last_n_model') else 5
+
+    logger.info('+++++ Aux Training Begins at {} +++++'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info('+++++ Epoch 0/{} +++++'.format(config.aux_max_epoch))
+    while True:
+        train_data.epoch_init(config, shuffle=True, verbose=done_epoch==0, fix_batch=config.fix_train_batch)
+        while True:
+            batch = train_data.next_batch()
+            if batch is None:
+                break
+
+            optimizer.zero_grad()
+            loss = model.forward_aux(batch, mode=TEACH_FORCE)
+            model.backward(loss, batch_cnt)
+            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
+            optimizer.step()
+            batch_cnt += 1
+            train_loss.add_loss(loss)
+
+            if batch_cnt % config.ckpt_step == 0:
+                done_epoch += 1
+                logger.info('done epoch {} -> {}'.format(done_epoch-1, done_epoch))
+    
+                logger.info('Checkpoint step at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+                logger.info('++++ Evaluating Model ++++')
+                logger.info(train_loss.pprint('Aux train'))
+
+                # generation
+                if gen is not None:
+                    gen(model, val_data, config, evaluator, num_batch=config.preview_batch_num, aux_mt=True)
+
+                # validation
+                valid_loss = aux_validate(model, val_data, config, batch_cnt)
+                _ = aux_validate(model, test_data, config, batch_cnt)
+
+                # update early stopping stats
+                if valid_loss < best_valid_loss:
+                    if valid_loss <= valid_loss_threshold * config.improve_threshold:
+                        patience = max(patience, done_epoch*config.patient_increase)
+                        valid_loss_threshold = valid_loss
+                        logger.info('Update patience to {}'.format(patience))
+    
+                    if config.save_model:
+                        cur_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
+                        logger.info('!!New best model with loss = {},at {}.'.format(valid_loss, cur_time))
+                        th.save(model.state_dict(), os.path.join(config.saved_path, 'aux-{}-model'.format(done_epoch)))
+                        best_epoch = done_epoch
+                        saved_models.append(done_epoch)
+                        if len(saved_models) > last_n_model:
+                            remove_model = saved_models[0]
+                            saved_models = saved_models[-last_n_model:]
+                            os.remove(os.path.join(config.saved_path, "aux-{}-model".format(remove_model)))
+
+                    best_valid_loss = valid_loss
+
+                if done_epoch >= config.aux_max_epoch \
+                        or config.early_stop and patience <= done_epoch:
+                    if done_epoch < config.aux_max_epoch:
+                        logger.info('!!!!! Early stop due to run out of patience !!!!!')
+                    print('Best validation loss = %f' % (best_valid_loss, ))
+                    return best_epoch
+
+                # exit eval model
+                model.train()
+                train_loss.clear()
+                # logger.info('\n***** Epoch {}/{} *****'.format(done_epoch, config.aux_max_epoch))
+                sys.stdout.flush()
+
+def validate(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        if use_py is not None:
+            # loss = model(batch, mode=TEACH_FORCE, use_py=use_py)
+            loss = model(batch, mode=TEACH_FORCE)
+        else:
+            loss = model(batch, mode=TEACH_FORCE)
+
+        losses.add_loss(loss)
+        losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt))
+
+    valid_loss = losses.avg_loss()
+    # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info(losses.pprint(val_data.name))
+    logger.info('Total valid loss = {}'.format(valid_loss))
+    sys.stdout.flush()
+    return valid_loss
+
+def validate_offlinerl(agent, evaluator, val_data, config, mode="valid", batch_cnt=None, use_py=None):
+
+    """
+    get the true reward and q-reward of the dataset
+    """
+    agent.model.eval()
+    agent.critic.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False, fix_batch=True)
+    criterion = th.nn.MSELoss()
+    valid_rewards = []
+    valid_match = []
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        
+        with th.no_grad():
+            actions, task_report, success, match = agent.run(batch, evaluator)
+
+        valid_rewards.append(success)
+        valid_match.append(success)
+
+    return valid_rewards, valid_match
+
+def debug(model, val_data, config, n_z=1, batch_cnt=None, use_py=None):
+    model.train()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+
+    de_tknize = get_detokenize()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+
+        batch_size = batch['bs'].shape[0] 
+        all_preds = defaultdict(list)
+        true_str = []
+        zs, joint_logpz = model.sample_z(batch, n_z=n_z)
+        for i in range(n_z):
+            # true_labels = batch['outputs'].data.numpy() # (batch_size, output_seq_len)
+
+            logprobs, outputs = model.decode_z(zs[i], config.batch_size, max_words=config.max_dec_len)
+            # move from GPU to CPU
+
+            for b_id in range(len(outputs)):
+                pred_str = get_sent_list(model.vocab, de_tknize, outputs[b_id]) 
+                all_preds[b_id].append(pred_str)
+                if i == 0:
+                    true_str.append(get_sent(model.vocab, de_tknize, batch['outputs'], b_id)) 
+
+        for i in range(batch_size):
+            print('True: {}'.format(true_str[i], ))
+            for n in range(n_z):
+                print('Pred{}: {}'.format(n, all_preds[i][n]))
+            print('='*30)
+
+def validate_mt(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        loss = model(batch, mode=TEACH_FORCE)
+        losses.add_loss(loss)
+        losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt))
+
+    valid_loss = losses.avg_loss()
+    # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info(losses.pprint(val_data.name))
+    logger.info('Total valid loss = {}'.format(valid_loss))
+    sys.stdout.flush()
+    return valid_loss
+
+def aux_validate(model, val_data, config, batch_cnt=None, use_py=None):
+    model.eval()
+    val_data.epoch_init(config, shuffle=False, verbose=False)
+    losses = LossManager()
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        loss = model.forward_aux(batch, mode=TEACH_FORCE)
+
+        losses.add_loss(loss)
+        losses.add_backward_loss(model.model_sel_loss(loss, batch_cnt))
+
+    valid_loss = losses.avg_loss()
+    # print('Validation finished at {}'.format(datetime.now().strftime("%Y-%m-%d %H-%M-%S")))
+    logger.info(losses.pprint(val_data.name))
+    logger.info('Total aux valid loss = {}'.format(valid_loss))
+    sys.stdout.flush()
+    return valid_loss
+
+def generate(model, data, config, evaluator, num_batch, dest_f=None):
+    
+    def write(msg):
+        if msg is None or msg == '':
+            return
+        if dest_f is None:
+            print(msg)
+        else:
+            dest_f.write(msg + '\n')
+
+    model.eval()
+    de_tknize = get_detokenize()
+    data.epoch_init(config, shuffle=num_batch is not None, verbose=False)
+    evaluator.initialize()
+    logger.info('Generation: {} batches'.format(data.num_batch
+                                          if num_batch is None
+                                          else num_batch))
+    batch_cnt = 0
+    print_cnt = 0
+    while True:
+        batch_cnt += 1
+        batch = data.next_batch()
+        if batch is None or (num_batch is not None and data.ptr > num_batch):
+            break
+        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+
+        # move from GPU to CPU
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len)
+        true_labels = labels.data.numpy() # (batch_size, output_seq_len)
+
+        # get attention if possible
+        if config.dec_use_attn:
+            pred_attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]]
+            pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(0, 1) # (batch_size, max_dec_len, max_ctx_len)
+        else:
+            pred_attns = None
+        # get context
+        ctx = batch.get('contexts') # (batch_size, max_ctx_len, max_utt_len)
+        ctx_len = batch.get('context_lens') # (batch_size, )
+
+        for b_id in range(pred_labels.shape[0]):
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) 
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            prev_ctx = ''
+            if ctx is not None:
+                ctx_str = []
+                for t_id in range(ctx_len[b_id]):
+                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id, stop_eos=False)
+                    # print('temp_str = %s' % (temp_str, ))
+                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
+                    ctx_str.append(temp_str)
+                ctx_str = '|'.join(ctx_str)[-200::]
+                prev_ctx = 'Source context: {}'.format(ctx_str)
+
+            evaluator.add_example(true_str, pred_str)
+
+            if num_batch is None or batch_cnt < 2:
+                print_cnt += 1
+                write('prev_ctx = %s' % (prev_ctx, ))
+                write('True: {}'.format(true_str, ))
+                write('Pred: {}'.format(pred_str, ))
+                write('='*30)
+                if num_batch is not None and print_cnt > 10:
+                    break
+
+    write(evaluator.get_report())
+    # write(evaluator.get_groundtruth_report())
+    write('Generation Done')
+
+def get_sent(vocab, de_tknize, data, b_id, stop_eos=True, stop_pad=True):
+    ws = []
+    for t_id in range(data.shape[1]):
+        w = vocab[int(data[b_id, t_id])]
+        if (stop_eos and w == EOS) or (stop_pad and w == PAD):
+            break
+        if w != PAD:
+            ws.append(w)
+
+    return de_tknize(ws)
+
+def get_sent_list(vocab, de_tknize, data, stop_eos=True, stop_pad=True):
+    ws = []
+    for t_id in range(len(data)):
+        w = vocab[data[t_id]]
+        if (stop_eos and w == EOS) or (stop_pad and w == PAD):
+            break
+        if w != PAD:
+            ws.append(w)
+
+    return de_tknize(ws)
+
+def most_frequent(List):
+    occ_count = Counter(List)
+    return occ_count.most_common(1)[0][0]
+
+def generate_with_name(model, data, config):
+    model.eval()
+    de_tknize = get_detokenize()
+    data.epoch_init(config, shuffle=False, verbose=False)
+    logger.info('Generation With Name: {} batches.'.format(data.num_batch))
+
+    from collections import defaultdict
+    res = defaultdict(dict)
+    while True:
+        batch = data.next_batch()
+        if batch is None:
+            break
+        keys, outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+        
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1) # (batch_size, max_dec_len)
+        true_labels = labels.cpu().data.numpy() # (batch_size, output_seq_len)
+
+        for b_id in range(pred_labels.shape[0]):
+            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id) 
+            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
+            dlg_name, dlg_turn = keys[b_id]
+            res[dlg_name][dlg_turn] = {'pred': pred_str, 'true': true_str}
+
+    return res
diff --git a/latent_dialog/metric.py b/latent_dialog/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..433c19e369fd7c8274cd1629a053137278f7079b
--- /dev/null
+++ b/latent_dialog/metric.py
@@ -0,0 +1,151 @@
+import time
+from collections import OrderedDict
+
+
+class NumericMetric(object):
+    """Base class for a numeric metric."""
+    def __init__(self):
+        self.k = 0
+        self.n = 0
+
+    def reset(self):
+        pass
+
+    def record(self, k, n=1):
+        self.k += k
+        self.n += n
+
+    def value(self):
+        self.n = max(1, self.n)
+        return 1.0 * self.k / self.n
+
+
+class AverageMetric(NumericMetric):
+    """Average."""
+    def show(self):
+        return '%.2f' % (1. * self.value())
+
+
+class PercentageMetric(NumericMetric):
+    """Percentage."""
+    def show(self):
+        return '%2.2f%%' % (100. * self.value())
+
+
+class TimeMetric(object):
+    """Time based metric."""
+    def __init__(self):
+        self.t = 0
+        self.n = 0
+
+    def reset(self):
+        self.last_t = time.time()
+
+    def record(self, n=1):
+        self.t += time.time() - self.last_t
+        self.n += 1
+
+    def value(self):
+        self.n = max(1, self.n)
+        return 1.0 * self.t / self.n
+
+    def show(self):
+        return '%.3fs' % (1. * self.value())
+
+
+class UniquenessMetric(object):
+    """Metric that evaluates the number of unique sentences."""
+    def __init__(self):
+        self.seen = set()
+
+    def reset(self):
+        pass
+
+    def record(self, sen):
+        self.seen.add(' '.join(sen))
+
+    def value(self):
+        return len(self.seen)
+
+    def show(self):
+        return str(self.value())
+
+
+class TextMetric(object):
+    """Text based metric."""
+    def __init__(self, text):
+        self.text = text
+        self.k = 0
+        self.n = 0
+
+    def reset(self):
+        pass
+
+    def value(self):
+        self.n = max(1, self.n)
+        return 1. * self.k / self.n
+
+    def show(self):
+        return '%.2f' % (1. * self.value())
+
+
+class NGramMetric(TextMetric):
+    """Metric that evaluates n grams."""
+    def __init__(self, text, ngram=-1):
+        super(NGramMetric, self).__init__(text)
+        self.ngram = ngram
+
+    def record(self, sen):
+        n = len(sen) if self.ngram == -1 else self.ngram
+        for i in range(len(sen) - n + 1):
+            self.n += 1
+            target = ' '.join(sen[i:i + n])
+            if self.text.find(target) != -1:
+                self.k += 1
+
+
+class MetricsContainer(object):
+    """A container that stores and updates several metrics."""
+    def __init__(self):
+        self.metrics = OrderedDict()
+
+    def _register(self, name, ty, *args, **kwargs):
+        name = name.lower()
+        assert name not in self.metrics
+        self.metrics[name] = ty(*args, **kwargs)
+
+    def register_average(self, name, *args, **kwargs):
+        self._register(name, AverageMetric, *args, **kwargs)
+
+    def register_time(self, name, *args, **kwargs):
+        self._register(name, TimeMetric, *args, **kwargs)
+
+    def register_percentage(self, name, *args, **kwargs):
+        self._register(name, PercentageMetric, *args, **kwargs)
+
+    def register_ngram(self, name, *args, **kwargs):
+        self._register(name, NGramMetric, *args, **kwargs)
+
+    def register_uniqueness(self, name, *args, **kwargs):
+        self._register(name, UniquenessMetric, *args, **kwargs)
+
+    def record(self, name, *args, **kwargs):
+        name = name.lower()
+        assert name in self.metrics
+        self.metrics[name].record(*args, **kwargs)
+
+    def reset(self):
+        for m in self.metrics.values():
+            m.reset()
+
+    def value(self, name):
+        return self.metrics[name].value()
+
+    def show(self):
+        return ' '.join(['%s=%s' % (k, v.show()) for k, v in self.metrics.iteritems()])
+
+    def dict(self):
+        d = OrderedDict()
+        for k, v in self.metrics.items():
+            d[k] = v.show()
+        return d
diff --git a/latent_dialog/models_task.py b/latent_dialog/models_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08c7e1bdd8f597a7c7409d24351e16f0dfaf983
--- /dev/null
+++ b/latent_dialog/models_task.py
@@ -0,0 +1,3382 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from latent_dialog.base_models import BaseModel, frange_cycle_linear
+from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS
+from latent_dialog.utils import INT, FLOAT, LONG, Pack, cast_type
+from latent_dialog.enc2dec.encoders import RnnUttEncoder
+from latent_dialog.enc2dec.decoders import DecoderRNN, GEN, TEACH_FORCE
+from latent_dialog.criterions import NLLEntropy, CatKLLoss, Entropy, NormKLLoss, GaussianEntropy
+from latent_dialog import nn_lib
+import numpy as np
+import pdb
+import json
+
+
+class SysPerfectBD2Word(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysPerfectBD2Word, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        self.policy = nn.Sequential(nn.Linear(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                              config.dec_cell_size), nn.Tanh(), nn.Dropout(config.dropout))
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=self.utt_encoder.output_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            attn_context = enc_outs
+        else:
+            attn_context = None
+
+        # create decoder initial states
+        dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0)
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            # h_dec_init_state = utt_summary.squeeze(1).unsqueeze(0)
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            return ret_dict, labels
+        if return_latent:
+            return Pack(nll=self.nll(dec_outputs, labels),
+                        latent_action=dec_init_state)
+        else:
+            return Pack(nll=self.nll(dec_outputs, labels))
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            attn_context = enc_outs
+        else:
+            attn_context = None
+
+        # create decoder initial states
+        dec_init_state = self.policy(th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)).unsqueeze(0)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=temp)
+        return logprobs, outs
+
+class SysPerfectBD2Cat(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysPerfectBD2Cat, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.k_size = config.k_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        self.contextual_posterior = config.contextual_posterior
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        if "policy_dropout" in config and config.policy_dropout:
+            if "policy_dropout_rate" in config:
+                self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                  config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval)
+            else:
+                self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                  config.y_size, config.k_size, is_lstm=False)
+
+        else:
+            self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                              config.y_size, config.k_size, is_lstm=False)
+        self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False)
+        self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu)
+        if not self.simple_posterior:
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                                   config.y_size, config.k_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+        if "state_for_decoding" not in self.config:
+            self.state_for_decoding = False
+        else:
+            self.state_for_decoding = self.config.state_for_decoding
+
+        if self.state_for_decoding:
+            dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size
+        else:
+            dec_hidden_size = config.dec_cell_size
+
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=dec_hidden_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty":
+            req_tokens = []
+            for d in REQ_TOKENS.keys():
+                req_tokens.extend(REQ_TOKENS[d])
+            nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.vocab]))
+            print("req tokens assigned with special weights")
+            if config.use_gpu:
+                nll_weight = nll_weight.cuda()
+            self.nll.set_weight(nll_weight)
+
+        self.cat_kl_loss = CatKLLoss()
+        self.entropy_loss = Entropy()
+        self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size))
+        self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0))
+        if "kl_annealing" in self.config and config.kl_annealing=="cyclical":
+            self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10)    
+        else:
+            self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+
+        if self.use_gpu:
+            self.log_uniform_y = self.log_uniform_y.cuda()
+            self.eye = self.eye.cuda()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if isinstance(self.beta, float):
+            beta = self.beta
+        else:
+            if batch_cnt == None:
+                beta = self.beta[-1]
+            else:
+                beta = self.beta[int(batch_cnt)]
+
+
+        if self.simple_posterior or "kl_annealing" in self.config:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        if self.config.use_mi:
+            total_loss += (loss.b_pr * beta)
+
+        if self.config.use_diversity:
+            total_loss += loss.diversity
+
+        return total_loss
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            log_py = self.log_uniform_y
+        else:
+            logits_py, log_py = self.c2z(enc_last) # p(z|c)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or (use_py is not None and use_py is True):
+                sample_y = self.gumbel_connector(logits_py, hard=True)
+            else:
+                sample_y = self.gumbel_connector(logits_qy, hard=False)
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+
+            result['diversity'] = th.mean(p)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True)
+            return result
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_py, log_qy = self.c2z(enc_last)
+        else:
+            logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        # if np.random.rand() < epsilon: # greedy exploration
+            # print("randomly sampling latent")
+            # idx = th.multinomial(th.cuda.FloatTensor(qy.shape).uniform_(), 1)
+        # else: # normal latent sampling
+        idx = th.multinomial(qy, 1).detach()
+        
+        logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+        sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+        sample_y.scatter_(1, idx, 1.0)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+        return logprobs, outs, joint_logpz, sample_y
+    
+    def sample_z(self, data_feed, n_z=1, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_py, log_qy = self.c2z(enc_last)
+        else:
+            logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        zs = []
+        logpzs = []
+        for i in range(n_z):
+            idx = th.multinomial(qy, 1).detach()
+            logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+            joint_logpz = th.sum(logprob_sample_z, dim=1)
+            sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+            sample_y.scatter_(1, idx, 1.0)
+
+            zs.append(sample_y)
+            logpzs.append(joint_logpz)
+
+        
+        return th.stack(zs), th.stack(logpzs)
+
+    def categorical_logprob(logits_py, sample_z):
+        pdb.set_trace()
+        py = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_py = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        idx = th.multinomial(py, 1).detach()
+        logprob_sample_z = log_py.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        return joint_logpz
+
+    def encode_state(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        return enc_last
+
+    def get_z_via_rg(self, data_feed, hard=False):
+        enc_last = self.encode_state(data_feed)
+        logits_qy, log_qy = self.c2z(enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+        return aux_sample_z, log_qy, logits_qy
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+ 
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+           z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+           attn_context = []
+           temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+           for z_id in range(self.y_size):
+               attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+           attn_context = th.cat(attn_context, dim=1)
+           dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+           dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+           attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+            # create decoder initial states
+            enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+
+        #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0)
+        #attn_context = self.np2var(attn_context, FLOAT)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # has to be forward_rl because we don't have the golden target
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                temp=temp)
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+class SysEncodedBD2Cat(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysEncodedBD2Cat, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.k_size = config.k_size
+        self.y_size = config.y_size
+        self.config = config
+        self.simple_posterior = config.simple_posterior
+        self.contextual_posterior = config.contextual_posterior
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        if config.use_metadata_for_decoding:
+            self.metadata_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                             embedding_dim=int(config.embed_size / 2),
+                                             feat_size=0,
+                                             goal_nhid=0,
+                                             rnn_cell=config.utt_rnn_cell,
+                                             utt_cell_size=int(config.dec_cell_size / 2),
+                                             num_layers=config.num_layers,
+                                             input_dropout_p=config.dropout,
+                                             output_dropout_p=config.dropout,
+                                             bidirectional=config.bi_utt_cell,
+                                             variable_lengths=False,
+                                             use_attn=config.enc_use_attn,
+                                             embedding=self.embedding)
+
+        if "policy_dropout" in config and config.policy_dropout:
+            if "policy_dropout_rate" in config:
+                self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size,
+                                  config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval)
+            else:
+                self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size,
+                                  config.y_size, config.k_size, is_lstm=False)
+
+        else:
+            self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size,
+                                              config.y_size, config.k_size, is_lstm=False)
+        self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False)
+        self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu)
+        if not self.simple_posterior:
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size * 2,
+                                                   config.y_size, config.k_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+        if "state_for_decoding" not in self.config:
+            self.state_for_decoding = False
+        else:
+            self.state_for_decoding = self.config.state_for_decoding
+
+        dec_hidden_size = config.dec_cell_size
+        if config.use_metadata_for_decoding:
+            if "metadata_to_decoder" not in config or config.metadata_to_decoder == "concat":
+                dec_hidden_size += self.metadata_encoder.output_size
+        if self.state_for_decoding:
+            dec_hidden_size += self.utt_encoder.output_size
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=dec_hidden_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty":
+            req_tokens = []
+            for d in REQ_TOKENS.keys():
+                req_tokens.extend(REQ_TOKENS[d])
+            nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.vocab]))
+            print("req tokens assigned with special weights")
+            if config.use_gpu:
+                nll_weight = nll_weight.cuda()
+            self.nll.set_weight(nll_weight)
+
+        self.cat_kl_loss = CatKLLoss()
+        self.entropy_loss = Entropy()
+        self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size))
+        self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0))
+
+        if "kl_annealing" in self.config and config.kl_annealing=="cyclical":
+            self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10)    
+        else:
+            self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+
+
+        if self.use_gpu:
+            self.log_uniform_y = self.log_uniform_y.cuda()
+            self.eye = self.eye.cuda()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if isinstance(self.beta, float):
+            beta = self.beta
+        else:
+            if batch_cnt == None:
+                beta = self.beta[-1]
+            else:
+                beta = self.beta[int(batch_cnt % self.config.n_iter)]
+               
+        if self.simple_posterior or "kl_annealing" in self.config:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        if self.config.use_mi:
+            total_loss += (loss.b_pr * beta)
+
+        if self.config.use_diversity:
+            total_loss += loss.diversity
+
+        return total_loss
+
+    def extract_short_ctx(self, data_feed):
+        utts = []
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        context = data_feed['contexts']
+        bs = data_feed['bs']
+        db = data_feed['db']
+        if not isinstance(bs, list):
+            bs = data_feed['bs'].tolist()
+            db = data_feed['db'].tolist()
+
+        for b_id in range(len(context)):
+            utt = []
+            for t_id in range(ctx_lens[b_id]):
+                utt.extend(context[b_id][t_id])
+            try:
+                utt.extend(bs[b_id] + db[b_id])
+            except:
+                pdb.set_trace()
+            utts.append(self.pad_to(self.config.max_utt_len, utt, do_pad=True))
+        return np.array(utts)
+    
+    def extract_metadata(self, data_feed):
+        utts = []
+        bs = data_feed['bs']
+        db = data_feed['db']
+        if not isinstance(bs, list):
+            bs = data_feed['bs'].tolist()
+            db = data_feed['db'].tolist()
+
+        for b_id in range(len(bs)):
+            utt = []
+            if "metadata_db_only" in self.config and self.config.metadata_db_only:
+                utt.extend(db[b_id])
+            else:
+                utt.extend(bs[b_id] + db[b_id])
+            utts.append(self.pad_to(self.config.max_metadata_len, utt, do_pad=True))
+        return np.array(utts)
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        enc_last = utt_summary.unsqueeze(1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            log_py = self.log_uniform_y
+        else:
+            logits_py, log_py = self.c2z(enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                logits_qy, log_qy = self.xc2z(th.cat([enc_last.squeeze(), x_h.squeeze(1)], dim=1))
+            else:
+                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or (use_py is not None and use_py is True):
+                sample_y = self.gumbel_connector(logits_py, hard=True)
+            else:
+                sample_y = self.gumbel_connector(logits_qy, hard=False)
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+        
+        if self.config.use_metadata_for_decoding:
+            metadata = self.np2var(self.extract_metadata(data_feed), LONG) 
+            metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1))
+            if "metadata_to_decoder" in self.config:
+                if self.config.metadata_to_decoder == "add":
+                    dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1)
+                elif self.config.metadata_to_decoder == "avg":
+                    dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0)
+                else:
+                    dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2)
+            else:
+                dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2)
+
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2)
+        
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,   # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) # averaged over all samples
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+
+            result['diversity'] = th.mean(p)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True)
+            return result
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        enc_last = utt_summary.unsqueeze(1)
+        # create decoder initial states
+        logits_py, log_qy = self.c2z(enc_last)
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+        idx = th.multinomial(qy, 1).detach()
+        
+        logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+        sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+        sample_y.scatter_(1, idx, 1.0)
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+        
+        if self.config.use_metadata_for_decoding:
+            metadata = self.np2var(self.extract_metadata(data_feed), LONG) 
+            metadata_summary, _, metadata_enc_outs = self.metadata_encoder(metadata.unsqueeze(1))
+            if "metadata_to_decoder" in self.config:
+                if self.config.metadata_to_decoder == "add":
+                    dec_init_state = dec_init_state + metadata_summary.view(1, batch_size, -1)
+                elif self.config.metadata_to_decoder == "avg":
+                    dec_init_state = th.mean(th.stack((dec_init_state, metadata_summary.view(1, batch_size, -1))), dim=0)
+                else:
+                    dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2)
+            else:
+                dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2)
+        
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2)
+ 
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+        return logprobs, outs, joint_logpz, sample_y
+    
+    def sample_z(self, data_feed, n_z=1, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed), LONG) # contains bs and db
+        # metadata = self.np2var(self.extract_metadata(data_feed), LONG) 
+        # out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1))
+
+
+        # create decoder initial states
+        enc_last = utt_summary.unsqueeze(1)
+        if self.simple_posterior:
+            logits_py, log_qy = self.c2z(enc_last)
+        else:
+            logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        zs = []
+        logpzs = []
+        for i in range(n_z):
+            idx = th.multinomial(qy, 1).detach()
+            logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+            joint_logpz = th.sum(logprob_sample_z, dim=1)
+            sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+            sample_y.scatter_(1, idx, 1.0)
+
+            zs.append(sample_y)
+            logpzs.append(joint_logpz)
+
+        
+        return th.stack(zs), th.stack(logpzs)
+
+    def decode_z(self, sample_y, batch_size, max_words=None, temp=1.0, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        # pack attention context
+        metadata = self.np2var(self.extract_metadata(data_feed), LONG) 
+        metadata_summary, _, metadata_enc_outs = self.utt_encoder(metadata.unsqueeze(1))
+
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+           z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+           attn_context = []
+           temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+           for z_id in range(self.y_size):
+               attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+           attn_context = th.cat(attn_context, dim=1)
+           dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+           dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+           attn_context = None
+
+        dec_init_state = th.cat((dec_init_state, metadata_summary.view(1, batch_size, -1)), dim=2)
+
+        if self.config.use_metadata_for_decoding:
+            raise NotImplementedError
+
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, th.transpose(enc_last.squeeze(1), 1, 0)], dim=2)
+ 
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # has to be forward_rl because we don't have the golden target
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                temp=temp)
+        return logprobs, outs
+
+class SysAECat(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysAECat, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        # self.act_size = corpus.act_size
+        self.k_size = config.k_size
+        self.y_size = config.y_size
+        self.simple_posterior = True # minimize kl to uninformed prior instead of dist conditioned by context
+        self.contextual_posterior = False # does not use context cause AE task
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+        
+        if "ae_zero_padding" in self.config and self.config.ae_zero_padding:
+            # self.use_metadata = self.config.use_metadata
+            self.ae_zero_padding = self.config.ae_zero_padding
+            c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size
+        else:
+            # self.use_metadata = False
+            self.ae_zero_padding = False
+            c2z_input_size = self.utt_encoder.output_size
+
+
+        if "policy_dropout" in config and config.policy_dropout:
+            self.c2z = nn_lib.Hidden2DiscretewDropout(c2z_input_size,
+                                              config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval)
+        else:
+            self.c2z = nn_lib.Hidden2Discrete(c2z_input_size,
+                                              config.y_size, config.k_size, is_lstm=False)
+
+        self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False)
+        self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu)
+        # if not self.simple_posterior: #q(z|x,c)
+            # if self.contextual_posterior:
+                # # x, c, BS, and DB
+                # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size,
+                                                   # config.y_size, config.k_size, is_lstm=False)
+            # else:
+                # self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty":
+            req_tokens = []
+            for d in REQ_TOKENS.keys():
+                req_tokens.extend(REQ_TOKENS[d])
+            nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.vocab]))
+            print("req tokens assigned with special weights")
+            if config.use_gpu:
+                nll_weight = nll_weight.cuda()
+            self.nll.set_weight(nll_weight)
+
+
+        self.cat_kl_loss = CatKLLoss()
+        self.entropy_loss = Entropy()
+        self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size))
+        self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0))
+        self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+        if self.use_gpu:
+            self.log_uniform_y = self.log_uniform_y.cuda()
+            self.eye = self.eye.cuda()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.simple_posterior:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += self.beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        if self.config.use_mi:
+            total_loss += (loss.b_pr * self.beta)
+
+        if self.config.use_diversity:
+            total_loss += loss.diversity
+
+        return total_loss
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+        
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        if self.ae_zero_padding:
+            enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+        else:
+            enc_last = utt_summary.squeeze(1)
+
+
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            log_py = self.log_uniform_y
+        # else:
+            # logits_py, log_py = self.c2z(enc_last)
+            # # encode response and use posterior to find q(z|x, c)
+            # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            # if self.contextual_posterior:
+                # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            # else:
+                # logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # # use prior at inference time, otherwise use posterior
+            # if mode == GEN or (use_py is not None and use_py is True):
+                # sample_y = self.gumbel_connector(logits_py, hard=False)
+            # else:
+                # sample_y = self.gumbel_connector(logits_qy, hard=True)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+
+            result['diversity'] = th.mean(p)
+            result['nll'] = self.nll(dec_outputs, labels)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True)
+            return result
+
+    def get_z_via_vae(self, data_feed, hard=False):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.utt_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        logits_qy, log_qy = self.c2z(aux_enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+        return aux_sample_z, logits_qy, log_qy
+
+class SysMTCat(BaseModel):
+    def __init__(self, corpus, config): 
+        super(SysMTCat, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        # self.act_size = corpus.act_size
+        self.k_size = config.k_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context
+        self.contextual_posterior = config.contextual_posterior # does not use context cause AE task
+        if "shared_train" in config:
+            self.shared_train = config.shared_train
+        else:
+            self.shared_train = False
+
+        if "use_aux_kl" in config:
+            self.use_aux_kl = config.use_aux_kl
+        else:
+            self.use_aux_kl = False
+
+        self.embedding = None
+        self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+
+        self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                          config.y_size, config.k_size, is_lstm=False)
+        self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False)
+        self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu)
+        
+        if not self.simple_posterior: #q(z|x,c)
+            if self.contextual_posterior:
+                # x, c, BS, and DB
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size,
+                                                   config.y_size, config.k_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        self.cat_kl_loss = CatKLLoss()
+        self.entropy_loss = Entropy()
+        self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size))
+        self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0))
+        self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+        self.aux_pi_beta = self.config.aux_pi_beta if hasattr(self.config, 'aux_pi_beta') else 1.0
+        if self.use_gpu:
+            self.log_uniform_y = self.log_uniform_y.cuda()
+            self.eye = self.eye.cuda()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.shared_train:
+            if "selective_fine_tune" in self.config and self.config.selective_fine_tune:
+                total_loss = loss.nll + self.config.beta * loss.aux_pi_kl
+            else:
+                total_loss = loss.nll + loss.ae_nll + self.config.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl 
+        else:
+            if self.simple_posterior:
+                total_loss = loss.nll
+                if self.config.use_pr > 0.0:
+                    total_loss += self.config.beta * loss.pi_kl
+            else:
+                total_loss = loss.nll + loss.pi_kl
+
+
+        return total_loss
+    
+    def encode_state(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        return enc_last
+
+    def encode_action(self, data_feed):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = aux_utt_summary.squeeze(1)
+
+        return aux_enc_last
+
+    def get_z_via_vae(self, data_feed, hard=False):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        logits_qy, log_qy = self.c2z(aux_enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+        return aux_sample_z
+
+    def get_z_via_rg(self, data_feed, hard=False):
+        enc_last = self.encode_state(data_feed)
+        logits_qy, log_qy = self.c2z(enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+        return aux_sample_z, log_qy, logits_qy
+
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        
+        # if data_feed:
+            # ctx_lens = data_feed['context_lens']  # (batch_size, )
+            # short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            # bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            # db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+ 
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+           z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+           attn_context = []
+           temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+           for z_id in range(self.y_size):
+               attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+           attn_context = th.cat(attn_context, dim=1)
+           dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+           dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+           attn_context = None
+
+        # decode
+        # if self.state_for_decoding:
+            # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+            # # create decoder initial states
+            # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+            # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+
+        #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0)
+        #attn_context = self.np2var(attn_context, FLOAT)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # has to be forward_rl because we don't have the golden target
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                temp=temp)
+        return logprobs, outs
+
+    def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        
+        # how to use z, alone or in combination with bs and db
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=False)
+            sample_y_discrete = self.gumbel_connector(logits_qy, hard=True)
+            log_py = self.log_uniform_y
+        # else:
+            # logits_py, log_py = self.c2z(enc_last)
+            # # encode response and use posterior to find q(z|x, c)
+            # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            # if self.contextual_posterior:
+                # logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            # else:
+                # logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # # use prior at inference time, otherwise use posterior
+            # if mode == GEN or (use_py is not None and use_py is True):
+                # sample_y = self.gumbel_connector(logits_py, hard=False)
+            # else:
+                # sample_y = self.gumbel_connector(logits_qy, hard=True)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+            result['diversity'] = th.mean(p)
+            result['nll'] = self.nll(dec_outputs, labels)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            return result
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        short_target_utts = self.np2var(data_feed['outputs'], LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            if self.shared_train:
+                aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last)
+                aux_sample_y = self.gumbel_connector(aux_logits_qy, hard=mode==GEN)
+
+            log_py = self.log_uniform_y
+        else:
+            logits_py, log_py = self.c2z(enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or (use_py is not None and use_py is True):
+                sample_y = self.gumbel_connector(logits_py, hard=False)
+            else:
+                sample_y = self.gumbel_connector(logits_qy, hard=True)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+
+        if self.shared_train:
+            if self.config.dec_use_attn:
+                aux_z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+                aux_attn_context = []
+                aux_temp_sample_y = aux_sample_y.view(-1, self.config.y_size, self.config.k_size)
+                for z_id in range(self.y_size):
+                    aux_attn_context.append(th.mm(aux_temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+                aux_attn_context = th.cat(aux_attn_context, dim=1)
+                aux_dec_init_state = th.sum(aux_attn_context, dim=1).unsqueeze(0)
+            else:
+                aux_dec_init_state = self.z_embedding(aux_sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+                aux_attn_context = None
+            if self.config.dec_rnn_cell == 'lstm':
+                aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state])
+
+
+
+        # decode
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            if self.shared_train:
+                ae_dec_outputs, ae_dec_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=aux_dec_init_state,  # tuple: (h, c)
+                                                               attn_context=aux_attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+                result['ae_nll'] = self.nll(ae_dec_outputs, labels)
+                aux_pi_kl = self.cat_kl_loss(log_qy, aux_log_qy, batch_size, unit_average=True)
+                aux_kl = self.cat_kl_loss(aux_log_qy, log_py, batch_size, unit_average=True)
+                result['aux_pi_kl'] = aux_pi_kl
+                result['aux_kl'] = aux_kl
+
+
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+            result['diversity'] = th.mean(p)
+            result['nll'] = self.nll(dec_outputs, labels)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            return result
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+    
+class SysActZCat(BaseModel):
+    def __init__(self, corpus, config): 
+        super(SysActZCat, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        # self.act_size = corpus.act_size
+        self.k_size = config.k_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior # minimize kl to uninformed prior instead of dist conditioned by context
+        self.contextual_posterior = config.contextual_posterior # does not use context cause AE task
+
+        if "use_aux_kl" in config:
+            self.use_aux_kl = config.use_aux_kl
+        else:
+            self.use_aux_kl = False
+        
+        if "use_aux_c2z" in config:
+            self.use_aux_c2z = config.use_aux_c2z
+        else:
+            self.use_aux_c2z = False
+
+
+
+        self.embedding = None
+        self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        # if "policy_dropout" in config and config.policy_dropout:
+            # self.c2z = nn_lib.Hidden2DiscretewDropout(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                              # config.y_size, config.k_size, is_lstm=False, p_dropout=config.policy_dropout_rate, dropout_on_eval=config.dropout_on_eval)
+        # else:
+        self.c2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                          config.y_size, config.k_size, is_lstm=False)
+        if self.use_aux_c2z:
+                self.aux_c2z = nn_lib.Hidden2Discrete(self.aux_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+
+        self.z_embedding = nn.Linear(self.y_size * self.k_size, config.dec_cell_size, bias=False)
+        self.gumbel_connector = nn_lib.GumbelConnector(config.use_gpu)
+        
+        if not self.simple_posterior: #q(z|x,c)
+            if self.contextual_posterior:
+                # x, c, BS, and DB
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size,
+                                                   config.y_size, config.k_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Discrete(self.utt_encoder.output_size, config.y_size, config.k_size, is_lstm=False)
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty":
+            req_tokens = []
+            for d in REQ_TOKENS.keys():
+                req_tokens.extend(REQ_TOKENS[d])
+            nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.vocab]))
+            print("req tokens assigned with special weights")
+            if config.use_gpu:
+                nll_weight = nll_weight.cuda()
+            self.nll.set_weight(nll_weight)
+
+        self.cat_kl_loss = CatKLLoss()
+        self.entropy_loss = Entropy()
+        self.log_uniform_y = Variable(th.log(th.ones(1) / config.k_size))
+        self.eye = Variable(th.eye(self.config.y_size).unsqueeze(0))
+        self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+        if self.use_gpu:
+            self.log_uniform_y = self.log_uniform_y.cuda()
+            self.eye = self.eye.cuda()
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.simple_posterior:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += self.beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        if self.config.use_mi:
+            total_loss += (loss.b_pr * self.beta)
+
+        if self.config.use_diversity:
+            total_loss += loss.diversity
+
+        return total_loss
+    
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        short_target_utts = self.np2var(data_feed['outputs'], LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            if self.use_aux_c2z:
+                aux_logits_qy, aux_log_qy = self.aux_c2z(aux_utt_summary.squeeze(1))
+            else:
+                aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last)
+
+            aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            log_py = aux_log_qy
+        else: 
+            logits_py, log_py = self.c2z(enc_last)
+            aux_logits_qy, aux_log_qy = self.c2z(aux_enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))
+
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or (use_py is not None and use_py is True):
+                sample_y = self.gumbel_connector(logits_py, hard=True)
+            else:
+                sample_y = self.gumbel_connector(logits_qy, hard=False)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               # (batch_size, response_size-1)
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               # (batch_size, max_ctx_len, ctx_cell_size)
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+            result['diversity'] = th.mean(p)
+            result['nll'] = self.nll(dec_outputs, labels)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True)
+            return result
+
+    def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_qy, log_qy = self.c2z(enc_last)
+            sample_y = self.gumbel_connector(logits_qy, hard=mode==GEN)
+            log_py = self.log_uniform_y
+        # else:
+            # p_mu, p_logvar = self.c2z(enc_last)
+            # # encode response and use posterior to find q(z|x, c)
+            # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            # if self.contextual_posterior:
+                # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            # else:
+                # q_mu, q_logvar = self.xc2z(x_h.squeeze(1))
+
+            # # use prior at inference time, otherwise use posterior
+            # if mode == GEN or use_py:
+                # sample_z = self.gauss_connector(p_mu, p_logvar)
+            # else:
+                # sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_y
+            ret_dict['log_qy'] = log_qy
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            # regularization qy to be uniform
+            avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
+            avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
+            b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
+            mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
+            pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)
+            q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)  # b
+            p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)
+
+            result['pi_kl'] = pi_kl
+            result['diversity'] = th.mean(p)
+            result['nll'] = self.nll(dec_outputs, labels)
+            result['b_pr'] = b_pr
+            result['mi'] = mi
+            result['pi_entropy'] = self.entropy_loss(log_qy, unit_average=True)
+            return result
+
+    def get_z_via_vae(self, data_feed, hard=False):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        logits_qy, log_qy = self.c2z(aux_enc_last)
+        aux_sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+
+        return aux_sample_z, logits_qy, log_qy
+
+    def encode_state(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        return enc_last
+
+    def encode_action(self, data_feed):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = aux_utt_summary.squeeze(1)
+
+        return aux_enc_last
+
+    def categorical_logprob(self, logits_py, sample_z, temp=0.1):
+        py = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_py = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        idx = th.multinomial(sample_z, 1).detach()
+        logprob_sample_z = log_py.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        return joint_logpz
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_py, log_qy = self.c2z(enc_last)
+        else:
+            logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+        idx = th.multinomial(qy, 1).detach()
+        logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+        sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+        sample_y.scatter_(1, idx, 1.0)
+
+        # pack attention context
+        if self.config.dec_use_attn:
+            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+            attn_context = []
+            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+            for z_id in range(self.y_size):
+                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+            attn_context = th.cat(attn_context, dim=1)
+            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+            attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                  temp=0.1)
+        return logprobs, outs, joint_logpz, sample_y
+    
+    def sample_z(self, data_feed, n_z=1, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        if self.simple_posterior:
+            logits_py, log_qy = self.c2z(enc_last)
+        else:
+            logits_py, log_qy = self.c2z(enc_last)
+
+        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size, vocab_size, )
+        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size, vocab_size, )
+
+        zs = []
+        logpzs = []
+        for i in range(n_z):
+            idx = th.multinomial(qy, 1).detach()
+            logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+            joint_logpz = th.sum(logprob_sample_z, dim=1)
+            sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+            sample_y.scatter_(1, idx, 1.0)
+
+            zs.append(sample_y)
+            logpzs.append(joint_logpz)
+
+        
+        return th.stack(zs), th.stack(logpzs)
+
+    def get_z_via_rg(self, data_feed, hard=False):
+        enc_last = self.encode_state(data_feed)
+        logits_qy, log_qy = self.c2z(enc_last)
+        sample_z = self.gumbel_connector(logits_qy, hard=hard)
+        
+        return sample_z, log_qy, logits_qy
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        # pack attention context
+
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        if self.config.dec_use_attn:
+           z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
+           attn_context = []
+           temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
+           for z_id in range(self.y_size):
+               attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
+           attn_context = th.cat(attn_context, dim=1)
+           dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
+        else:
+           dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
+           attn_context = None
+
+        
+        # decode
+
+        #dec_init_state = self.np2var(dec_init_state, FLOAT).unsqueeze(0)
+        #attn_context = self.np2var(attn_context, FLOAT)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # has to be forward_rl because we don't have the golden target
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                temp=temp)
+
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+class SysPerfectBD2Gauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysPerfectBD2Gauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        if "contextual posterior" in config: 
+            self.contextual_posterior = config.contextual_posterior
+        else:
+            self.contextual_posterior = True # default value is true, i.e. q(z|x,c)
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                          config.y_size, is_lstm=False)
+        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
+        self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
+        if not self.simple_posterior:
+            # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                               # config.y_size, is_lstm=False)
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                                   config.y_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
+
+        if "state_for_decoding" not in self.config:
+            self.state_for_decoding = False
+        else:
+            self.state_for_decoding = self.config.state_for_decoding
+
+        if self.state_for_decoding:
+            dec_hidden_size = config.dec_cell_size + self.utt_encoder.output_size + self.db_size + self.bs_size
+        else:
+            dec_hidden_size = config.dec_cell_size
+
+
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=dec_hidden_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+
+        self.gauss_kl = NormKLLoss(unit_average=True)
+        self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.simple_posterior:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += self.config.beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        return total_loss
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            p_mu, p_logvar = self.zero, self.zero
+        else:
+            p_mu, p_logvar = self.c2z(enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                q_mu, q_logvar = self.xc2z(x_h.squeeze(1))
+
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or use_py:
+                sample_z = self.gauss_connector(p_mu, p_logvar)
+            else:
+                sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            result['nll'] = self.nll(dec_outputs, labels)
+            return result
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2*np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var)
+        return logprob
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        p_mu, p_logvar = self.c2z(enc_last)
+
+        sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach()
+        logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+        return logprobs, outs, joint_logpz, sample_z
+
+    def get_z_via_rg(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        q_mu, q_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(q_mu, q_logvar)
+        
+        return sample_z, q_mu, q_logvar
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+ 
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        dec_init_state = self.z_embedding(sample_y.unsqueeze(0))
+        if (dec_init_state != dec_init_state).any():
+            pdb.set_trace()
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+            # create decoder initial states
+            enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+class SysAEGauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysAEGauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        # self.act_size = corpus.act_size
+        self.y_size = config.y_size
+        self.simple_posterior = True
+        self.contextual_posterior = False
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        # if "use_metadata" in self.config and self.config.use_metadata:
+        if "ae_zero_padding" in self.config and self.config.ae_zero_padding:
+            # self.use_metadata = self.config.use_metadata
+            self.ae_zero_padding = self.config.ae_zero_padding
+            c2z_input_size = self.utt_encoder.output_size + self.db_size + self.bs_size
+        else:
+            # self.use_metadata = False
+            self.ae_zero_padding = False
+            c2z_input_size = self.utt_encoder.output_size
+
+        self.c2z = nn_lib.Hidden2Gaussian(c2z_input_size,
+                                          config.y_size, is_lstm=False)
+
+        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
+       
+        self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
+        if not self.simple_posterior:
+            # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                               # config.y_size, is_lstm=False)
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                                   config.y_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
+
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+
+        self.gauss_kl = NormKLLoss(unit_average=True)
+        self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
+
+
+        if "kl_annealing" in self.config and config.kl_annealing=="cyclical":
+            if "n_iter" not in self.config:
+                config['n_iter'] = config.ckpt_step  * config.max_epoch
+            self.beta = frange_cycle_linear(config.n_iter, start=self.config.beta_start, stop=self.config.beta_end, n_cycle=10)    
+        else:
+            self.beta = self.config.beta if hasattr(self.config, 'beta') else 0.0
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if isinstance(self.beta, float):
+            beta = self.beta
+        else:
+            if batch_cnt == None:
+                beta = self.beta[-1]
+            else:
+                beta = self.beta[int(batch_cnt)]
+
+
+        if self.simple_posterior or "kl_annealing" in self.config:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        return total_loss
+
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+        # print(short_ctx_utts[0])
+        # print(out_utts[0])
+
+
+        # create decoder initial states
+        # if self.use_metadata:
+        if self.ae_zero_padding:
+            enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+        else:
+            enc_last = utt_summary.squeeze(1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            p_mu, p_logvar = self.zero, self.zero
+        # else:
+            # p_mu, p_logvar = self.c2z(enc_last)
+            # # encode response and use posterior to find q(z|x, c)
+            # x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            # if self.contextual_posterior:
+                # q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            # else:
+                # q_mu, q_logvar = self.xc2z(x_h.squeeze(1))
+
+            # # use prior at inference time, otherwise use posterior
+            # if mode == GEN or use_py:
+                # sample_z = self.gauss_connector(p_mu, p_logvar)
+            # else:
+                # sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            # print(labels[0])
+            # print("========")
+            # pdb.set_trace()
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            result['nll'] = self.nll(dec_outputs, labels)
+            return result
+
+    def get_z_via_vae(self, responses):
+        batch_size = responses.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.utt_encoder(responses.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+        aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+        
+        return aux_sample_z, aux_q_mu, aux_q_logvar
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2*np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var)
+        return logprob
+
+class SysMTGauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysMTGauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        self.contextual_posterior = config.contextual_posterior
+        if "shared_train" in config:
+            self.shared_train = config.shared_train
+        else:
+            self.shared_train = False
+
+        if "use_aux_kl" in config:
+            self.use_aux_kl = config.use_aux_kl
+        else:
+            self.use_aux_kl = False
+
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+        
+        self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+        self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                          config.y_size, is_lstm=False)
+        # if self.shared_train:
+            # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size,
+                                          # config.y_size, is_lstm=False)
+
+        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
+        self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
+        if not self.simple_posterior:
+            # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                               # config.y_size, is_lstm=False)
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                                   config.y_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
+
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        if "state_for_decoding" not in config:
+            self.state_for_decoding = False
+        else:
+            self.state_for_decoding = config.state_for_decoding
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        self.entropy_loss = GaussianEntropy()
+
+        self.gauss_kl = NormKLLoss(unit_average=True)
+        self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
+
+        self.aux_pi_beta = self.config.aux_pi_beta if hasattr(self.config, 'aux_pi_beta') else 1.0
+
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.shared_train:
+            if "selective_fine_tune" in self.config and self.config.selective_fine_tune:
+                total_loss = loss.nll + self.config.beta * loss.aux_pi_kl
+            else:
+                total_loss = loss.nll + loss.ae_nll + self.aux_pi_beta * loss.aux_pi_kl + self.config.beta * loss.aux_kl 
+        else:
+            if self.simple_posterior:
+                total_loss = loss.nll
+                if self.config.use_pr > 0.0:
+                    total_loss += self.config.beta * loss.pi_kl
+            else:
+                total_loss = loss.nll + loss.pi_kl
+
+
+        return total_loss
+    
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        short_target_utts = self.np2var(data_feed['outputs'], LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1))
+        
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # aux_enc_last = aux_utt_summary.squeeze(1)
+        aux_enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), aux_utt_summary.squeeze(1)], dim=1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            # logprob_sample_z = self.gaussian_logprob(q_mu, self.zero, sample_z)
+            # joint_logpz = th.sum(logprob_sample_z, dim=1)
+            # pdb.set_trace()
+
+            if self.shared_train:
+                # aux_q_mu, aux_q_logvar = self.aux_c2z(aux_enc_last)
+                aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+                aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+                p_mu, p_logvar = self.zero, self.zero
+        else:
+            p_mu, p_logvar = self.c2z(enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                q_mu, q_logvar = self.xc2z(x_h.squeeze(1))
+
+            aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+            
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or use_py:
+                sample_z = self.gauss_connector(p_mu, p_logvar)
+            else:
+                sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        if self.shared_train:
+            aux_dec_init_state = self.z_embedding(aux_sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+            if self.shared_train:
+                aux_dec_init_state = tuple([aux_dec_init_state, aux_dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            if self.shared_train:
+                ae_dec_outputs, ae_hidden_state, ae_ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=aux_dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+                result['ae_nll'] = self.nll(ae_dec_outputs, labels)
+                aux_pi_kl = self.gauss_kl(q_mu, q_logvar, aux_q_mu, aux_q_logvar)
+                aux_kl = self.gauss_kl(aux_q_mu, aux_q_logvar, p_mu, p_logvar)
+                result['aux_pi_kl'] = aux_pi_kl
+                result['aux_kl'] = aux_kl
+                # result['aux_entropy'] = self.entropy_loss(aux_q_mu, aux_q_logvar)
+
+
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            # result['pi_entropy'] = self.entropy_loss(q_mu, q_logvar)
+            result['nll'] = self.nll(dec_outputs, labels)
+            return result
+    
+    def encode_state(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        return enc_last
+
+    def encode_action(self, data_feed):
+        batch_size = data_feed.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(data_feed.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = aux_utt_summary.squeeze(1)
+
+        return aux_enc_last
+            
+    def get_z_via_vae(self, responses):
+        batch_size = responses.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+        aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+        
+        return aux_sample_z, aux_q_mu, aux_q_logvar
+
+    def get_z_via_rg(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        q_mu, q_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(q_mu, q_logvar)
+        
+        return sample_z, q_mu, q_logvar
+
+    def gaussian_prob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+
+        den = th.sqrt(2 * np.pi * var)
+        po = - (th.pow((sample_z - mu), 2) / (2 * var))
+
+        prob = 1 / den * th.exp(po)
+
+        return prob
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2*np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var)
+        return logprob
+
+        return logprobs, outs, joint_logpz, sample_z
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        p_mu, p_logvar = self.c2z(enc_last)
+
+        # sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach()
+        sample_z = self.gauss_connector(p_mu, p_logvar)
+        logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+        return logprobs, outs, joint_logpz, sample_z
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+ 
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        dec_init_state = self.z_embedding(sample_y.unsqueeze(0))
+        if (dec_init_state != dec_init_state).any():
+            pdb.set_trace()
+        attn_context = None
+
+        # decode
+        # if self.state_for_decoding:
+            # if not data_feed:
+                # raise ValueError
+            # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+            # # create decoder initial states
+            # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+            # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+
+        return logprobs, outs
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
+
+    def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False, sample_z=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            p_mu, p_logvar = self.zero, self.zero
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            return result
+
+class SysActZGauss(BaseModel):
+    def __init__(self, corpus, config):
+        super(SysActZGauss, self).__init__(config)
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.y_size = config.y_size
+        self.simple_posterior = config.simple_posterior
+        self.contextual_posterior = config.contextual_posterior
+
+        self.embedding = None
+        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+        
+        self.aux_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
+                                         embedding_dim=config.embed_size,
+                                         feat_size=0,
+                                         goal_nhid=0,
+                                         rnn_cell=config.utt_rnn_cell,
+                                         utt_cell_size=config.utt_cell_size,
+                                         num_layers=config.num_layers,
+                                         input_dropout_p=config.dropout,
+                                         output_dropout_p=config.dropout,
+                                         bidirectional=config.bi_utt_cell,
+                                         variable_lengths=False,
+                                         use_attn=config.enc_use_attn,
+                                         embedding=self.embedding)
+
+
+        self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size,
+                                          config.y_size, is_lstm=False)
+        # self.aux_c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size,
+                                          # config.y_size, is_lstm=False)
+
+        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
+        self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
+        if not self.simple_posterior:
+            # self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                               # config.y_size, is_lstm=False)
+            if self.contextual_posterior:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
+                                                   config.y_size, is_lstm=False)
+            else:
+                self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
+
+
+        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
+                                  rnn_cell=config.dec_rnn_cell,
+                                  input_size=config.embed_size,
+                                  hidden_size=config.dec_cell_size,
+                                  num_layers=config.num_layers,
+                                  output_dropout_p=config.dropout,
+                                  bidirectional=False,
+                                  vocab_size=self.vocab_size,
+                                  use_attn=config.dec_use_attn,
+                                  ctx_cell_size=config.dec_cell_size,
+                                  attn_mode=config.dec_attn_mode,
+                                  sys_id=self.bos_id,
+                                  eos_id=self.eos_id,
+                                  use_gpu=config.use_gpu,
+                                  max_dec_len=config.max_dec_len,
+                                  embedding=self.embedding)
+
+        self.nll = NLLEntropy(self.pad_id, config.avg_type)
+        if config.avg_type == "weighted" and config.nll_weight=="no_match_penalty":
+            req_tokens = []
+            for d in REQ_TOKENS.keys():
+                req_tokens.extend(REQ_TOKENS[d])
+            nll_weight = Variable(th.FloatTensor([10. if token in req_tokens  else 1. for token in self.vocab]))
+            print("req tokens assigned with special weights")
+            if config.use_gpu:
+                nll_weight = nll_weight.cuda()
+            self.nll.set_weight(nll_weight)
+
+        if "state_for_decoding" not in config:
+            self.state_for_decoding = False
+        else:
+            self.state_for_decoding = config.state_for_decoding
+
+
+        self.gauss_kl = NormKLLoss(unit_average=True)
+        self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
+    
+    def valid_loss(self, loss, batch_cnt=None):
+        if self.simple_posterior:
+            total_loss = loss.nll
+            if self.config.use_pr > 0.0:
+                total_loss += self.config.beta * loss.pi_kl
+        else:
+            total_loss = loss.nll + loss.pi_kl
+
+        if self.config.use_mi:
+            total_loss += (loss.b_pr * self.beta)
+
+        if self.config.use_diversity:
+            total_loss += loss.diversity
+
+        if "match_z" in self.config and self.config.match_z:
+            total_loss += loss.z_mse
+
+        return total_loss
+    
+    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        short_target_utts = self.np2var(data_feed['outputs'], LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(short_target_utts.unsqueeze(1))
+        
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        aux_enc_last = th.cat([bs_label, db_label, aux_utt_summary.squeeze(1)], dim=1)
+        # aux_enc_last = aux_utt_summary.squeeze(1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            # p_mu, p_logvar = self.aux_c2z(aux_enc_last)
+            p_mu, p_logvar = self.c2z(aux_enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            aux_sample_z = self.gauss_connector(p_mu, p_logvar)
+        else:
+            p_mu, p_logvar = self.c2z(enc_last)
+            # encode response and use posterior to find q(z|x, c)
+            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
+            if self.contextual_posterior:
+                q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
+            else:
+                q_mu, q_logvar = self.xc2z(x_h.squeeze(1))
+
+            aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+            
+            # use prior at inference time, otherwise use posterior
+            if mode == GEN or use_py:
+                sample_z = self.gauss_connector(p_mu, p_logvar)
+            else:
+                sample_z = self.gauss_connector(q_mu, q_logvar)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            z_mse = F.mse_loss(aux_sample_z, sample_z)
+            result['pi_kl'] = pi_kl
+            result['z_mse'] = z_mse
+            # result['nll'] = self.nll(dec_outputs, labels)
+            return result
+
+    def forward_rl(self, data_feed, max_words, temp=0.1):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        # create decoder initial states
+        p_mu, p_logvar = self.c2z(enc_last)
+
+        # sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach()
+        sample_z = self.gauss_connector(p_mu, p_logvar)
+        logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z)
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.state_for_decoding:
+            dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+        return logprobs, outs, joint_logpz, sample_z
+
+    def forward_aux(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        # act_label = self.np2var(data_feed['act'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        batch_size = len(ctx_lens)
+
+        utt_summary, _, enc_outs = self.aux_encoder(short_ctx_utts.unsqueeze(1))
+
+        # get decoder inputs
+        dec_inputs = out_utts[:, :-1]
+        labels = out_utts[:, 1:].contiguous()
+
+        # create decoder initial states
+        enc_last = th.cat([th.zeros_like(bs_label), th.zeros_like(db_label), utt_summary.squeeze(1)], dim=1)
+
+        # create decoder initial states
+        if self.simple_posterior:
+            q_mu, q_logvar = self.c2z(enc_last)
+            sample_z = self.gauss_connector(q_mu, q_logvar)
+            p_mu, p_logvar = self.zero, self.zero
+
+        # pack attention context
+        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
+        attn_context = None
+
+        # decode
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        dec_outputs, dec_hidden_state, ret_dict = self.decoder(batch_size=batch_size,
+                                                               dec_inputs=dec_inputs,
+                                                               dec_init_state=dec_init_state,  # tuple: (h, c)
+                                                               attn_context=attn_context,
+                                                               mode=mode,
+                                                               gen_type=gen_type,
+                                                               beam_size=self.config.beam_size)  # (batch_size, goal_nhid)
+        if mode == GEN:
+            ret_dict['sample_z'] = sample_z
+            ret_dict['q_mu'] = q_mu
+            ret_dict['q_logvar'] = q_logvar
+            return ret_dict, labels
+
+        else:
+            result = Pack(nll=self.nll(dec_outputs, labels))
+            pi_kl = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
+            result['pi_kl'] = pi_kl
+            result['nll'] = self.nll(dec_outputs, labels)
+            return result
+    
+    def get_z_via_vae(self, responses):
+        batch_size = responses.shape[0]
+        aux_utt_summary, _, aux_enc_outs = self.aux_encoder(responses.unsqueeze(1))
+        
+        # create decoder initial states
+        aux_enc_last = th.cat([self.np2var(np.zeros([batch_size, self.bs_size]), LONG), self.np2var(np.zeros([batch_size, self.db_size]), LONG), aux_utt_summary.squeeze(1)], dim=1)
+
+        aux_q_mu, aux_q_logvar = self.c2z(aux_enc_last)
+        aux_sample_z = self.gauss_connector(aux_q_mu, aux_q_logvar)
+        
+        return aux_sample_z, aux_q_mu, aux_q_logvar
+
+    def get_z_via_rg(self, data_feed):
+        ctx_lens = data_feed['context_lens']  # (batch_size, )
+        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+        q_mu, q_logvar = self.c2z(enc_last)
+
+        sample_z = self.gauss_connector(q_mu, q_logvar)
+        
+        return sample_z, q_mu, q_logvar
+
+    def decode_z(self, sample_y, batch_size, data_feed=None, max_words=None, temp=0.1, gen_type='greedy'):
+        """
+        generate response from latent var
+        """
+        
+        if data_feed:
+            ctx_lens = data_feed['context_lens']  # (batch_size, )
+            short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
+            bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+            db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
+ 
+        # pack attention context
+        if isinstance(sample_y, np.ndarray):
+            sample_y = self.np2var(sample_y, FLOAT)
+
+        dec_init_state = self.z_embedding(sample_y.unsqueeze(0))
+        if (dec_init_state != dec_init_state).any():
+            pdb.set_trace()
+        attn_context = None
+
+        # decode
+        # if self.state_for_decoding:
+            # if not data_feed:
+                # raise ValueError
+            # utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+            # # create decoder initial states
+            # enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+            # dec_init_state = th.cat([dec_init_state, enc_last.unsqueeze(0)], dim=2)
+
+        if self.config.dec_rnn_cell == 'lstm':
+            dec_init_state = tuple([dec_init_state, dec_init_state])
+
+        # decode
+        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
+                                                 dec_init_state=dec_init_state,
+                                                 attn_context=attn_context,
+                                                 vocab=self.vocab,
+                                                 max_words=max_words,
+                                                 temp=0.1)
+
+        return logprobs, outs
+
+    def gaussian_logprob(self, mu, logvar, sample_z):
+        var = th.exp(logvar)
+        constant = float(-0.5 * np.log(2*np.pi))
+        logprob = constant - 0.5 * logvar - th.pow((mu-sample_z), 2) / (2.0*var)
+        return logprob
+
+        return logprobs, outs, joint_logpz, sample_z
+
+    def pad_to(self, max_len, tokens, do_pad):
+        if len(tokens) >= max_len:
+            # print("cutting off, ", tokens)
+            return tokens[: max_len-1] + [tokens[-1]]
+        elif do_pad:
+            return tokens + [0] * (max_len - len(tokens))
+        else:
+            return tokens
diff --git a/latent_dialog/nn_lib.py b/latent_dialog/nn_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20d21760b3606e8d7aeaa78381458fc2f2b0883
--- /dev/null
+++ b/latent_dialog/nn_lib.py
@@ -0,0 +1,225 @@
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.autograd import Variable
+from latent_dialog.utils import cast_type, FLOAT
+
+
+class IdentityConnector(nn.Module):
+    def __init(self):
+        super(IdentityConnector, self).__init__()
+
+    def forward(self, hidden_state):
+        return hidden_state
+
+
+class Bi2UniConnector(nn.Module):
+    def __init__(self, rnn_cell, num_layer, hidden_size, output_size):
+        super(Bi2UniConnector, self).__init__()
+        if rnn_cell == 'lstm':
+            self.fch = nn.Linear(hidden_size*2*num_layer, output_size)
+            self.fcc = nn.Linear(hidden_size*2*num_layer, output_size)
+        else:
+            self.fc = nn.Linear(hidden_size*2*num_layer, output_size)
+
+        self.rnn_cell = rnn_cell
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+
+    def forward(self, hidden_state):
+        """
+        :param hidden_state: [num_layer, batch_size, feat_size]
+        :param inputs: [batch_size, feat_size]
+        :return: 
+        """
+        if self.rnn_cell == 'lstm':
+            h, c = hidden_state
+            num_layer = h.size()[0]
+            flat_h = h.transpose(0, 1).contiguous()
+            flat_c = c.transpose(0, 1).contiguous()
+            new_h = self.fch(flat_h.view(-1, self.hidden_size*num_layer))
+            new_c = self.fch(flat_c.view(-1, self.hidden_size*num_layer))
+            return (new_h.view(1, -1, self.output_size),
+                    new_c.view(1, -1, self.output_size))
+        else:
+            # FIXME fatal error here!
+            num_layer = hidden_state.size()[0]
+            new_s = self.fc(hidden_state.view(-1, self.hidden_size*num_layer))
+            new_s = new_s.view(1, -1, self.output_size)
+            return new_s
+
+
+class Hidden2Gaussian(nn.Module):
+    def __init__(self, input_size, output_size, is_lstm=False, has_bias=True):
+        super(Hidden2Gaussian, self).__init__()
+        if is_lstm:
+            self.mu_h = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar_h = nn.Linear(input_size, output_size, bias=has_bias)
+
+            self.mu_c = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar_c = nn.Linear(input_size, output_size, bias=has_bias)
+        else:
+            self.mu = nn.Linear(input_size, output_size, bias=has_bias)
+            self.logvar = nn.Linear(input_size, output_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.is_lstm:
+            h, c= inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+
+            mu_h, mu_c = self.mu_h(h), self.mu_c(c)
+            logvar_h, logvar_c = self.logvar_h(h), self.logvar_c(c)
+            return mu_h+mu_c, logvar_h+logvar_c
+        else:
+            # if inputs.dim() == 3:
+            #    inputs = inputs.squeeze(0)
+            mu = self.mu(inputs)
+            logvar = self.logvar(inputs)
+            return mu, logvar
+
+
+class Hidden2Discrete(nn.Module):
+    def __init__(self, input_size, y_size, k_size, is_lstm=False, has_bias=True):
+        super(Hidden2Discrete, self).__init__()
+        self.y_size = y_size
+        self.k_size = k_size
+        latent_size = self.k_size*self.y_size
+        if is_lstm:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+            self.p_c = nn.Linear(input_size, latent_size, bias=has_bias)
+        else:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.is_lstm:
+            h, c= inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+            logits = self.p_h(h) + self.p_c(c)
+        else:
+            logits = self.p_h(inputs)
+        logits = logits.view(-1, self.k_size)
+        log_qy = F.log_softmax(logits, dim=1)
+        return logits, log_qy
+
+
+class Hidden2DiscretewDropout(nn.Module):
+    def __init__(self, input_size, y_size, k_size, is_lstm=False, has_bias=True, p_dropout=0.1, dropout_on_eval=False):
+        super(Hidden2DiscretewDropout, self).__init__()
+        self.y_size = y_size
+        self.k_size = k_size
+        latent_size = self.k_size*self.y_size
+        self.dropout_on_eval = dropout_on_eval
+        self.p_dropout = p_dropout
+        if not dropout_on_eval:
+            self.dropout = nn.Dropout(p_dropout)
+        if is_lstm:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+            self.p_c = nn.Linear(input_size, latent_size, bias=has_bias)
+        else:
+            self.p_h = nn.Linear(input_size, latent_size, bias=has_bias)
+
+        self.is_lstm = is_lstm
+
+    def forward(self, inputs):
+        """
+        :param inputs: batch_size x input_size
+        :return:
+        """
+        if self.dropout_on_eval:
+            drop = nn.Dropout(self.p_dropout)
+        else:
+            drop = self.dropout
+        if self.is_lstm:
+
+            h, c= inputs
+            if h.dim() == 3:
+                h = h.squeeze(0)
+                c = c.squeeze(0)
+            logits = drop(self.p_h(h)) + drop(self.p_c(c))
+        else:
+            logits = drop(self.p_h(inputs))
+        logits = logits.view(-1, self.k_size)
+        log_qy = F.log_softmax(logits, dim=1)
+        return logits, log_qy
+
+
+
+class GaussianConnector(nn.Module):
+    def __init__(self, use_gpu):
+        super(GaussianConnector, self).__init__()
+        self.use_gpu = use_gpu
+
+    def forward(self, mu, logvar):
+        """
+        Sample a sample from a multivariate Gaussian distribution with a diagonal covariance matrix using the
+        reparametrization trick.
+        TODO: this should be better be a instance method in a Gaussian class.
+        :param mu: a tensor of size [batch_size, variable_dim]. Batch_size can be None to support dynamic batching
+        :param logvar: a tensor of size [batch_size, variable_dim]. Batch_size can be None.
+        :return:
+        """
+        epsilon = th.randn(logvar.size())
+        epsilon = cast_type(Variable(epsilon), FLOAT, self.use_gpu)
+        std = th.exp(0.5 * logvar)
+        z = mu + std * epsilon
+        return z
+
+
+class GumbelConnector(nn.Module):
+    def __init__(self, use_gpu):
+        super(GumbelConnector, self).__init__()
+        self.use_gpu = use_gpu
+
+    def sample_gumbel(self, logits, use_gpu, eps=1e-20):
+        u = th.rand(logits.size())
+        sample = Variable(-th.log(-th.log(u + eps) + eps))
+        sample = cast_type(sample, FLOAT, use_gpu)
+        return sample
+
+    def gumbel_softmax_sample(self, logits, temperature, use_gpu):
+        """ Draw a sample from the Gumbel-Softmax distribution"""
+        eps = self.sample_gumbel(logits, use_gpu)
+        y = logits + eps
+        return F.softmax(y / temperature, dim=y.dim()-1)
+
+    def forward(self, logits, temperature=1.0, hard=False,
+                return_max_id=False):
+        """
+        :param logits: [batch_size, n_class] unnormalized log-prob
+        :param temperature: non-negative scalar
+        :param hard: if True take argmax
+        :param return_max_id
+        :return: [batch_size, n_class] sample from gumbel softmax
+        """
+        y = self.gumbel_softmax_sample(logits, temperature, self.use_gpu)
+        _, y_hard = th.max(y, dim=1, keepdim=True)
+        if hard:
+            y_onehot = cast_type(Variable(th.zeros(y.size())), FLOAT, self.use_gpu)
+            y_onehot.scatter_(1, y_hard, 1.0)
+            y = y_onehot
+        if return_max_id:
+            return y, y_hard
+        else:
+            return y
+
+
+
diff --git a/latent_dialog/normalizer/__init__.py b/latent_dialog/normalizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc b/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a89295ad2294f7cd75fb4b3d46cb68e167d6985b
Binary files /dev/null and b/latent_dialog/normalizer/__pycache__/__init__.cpython-36.pyc differ
diff --git a/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc b/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3eff08c2b9522413e90ddfd5120f0bbd5389e298
Binary files /dev/null and b/latent_dialog/normalizer/__pycache__/delexicalize.cpython-36.pyc differ
diff --git a/latent_dialog/normalizer/delexicalize.py b/latent_dialog/normalizer/delexicalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..df4636b24d0083bffb9b6d1303dc05d447720970
--- /dev/null
+++ b/latent_dialog/normalizer/delexicalize.py
@@ -0,0 +1,283 @@
+import re
+import os
+# import simplejson as json
+import json
+
+
+digitpat = re.compile('\d+')
+timepat = re.compile("\d{1,2}[:]\d{1,2}")
+pricepat = re.compile("\d{1,3}[.]\d{1,2}")
+
+CUR_PATH = os.path.join(os.path.dirname(__file__))
+fin = open(os.path.join(CUR_PATH, 'mapping.pair'), 'r')
+replacements = []
+for line in fin.readlines():
+    tok_from, tok_to = line.replace('\n', '').split('\t')
+    replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))
+
+# FORMAT
+# domain_value
+# restaurant_postcode
+# restaurant_address
+# taxi_car8
+# taxi_number
+# train_id etc..
+
+def insertSpace(token, text):
+    sidx = 0
+    while True:
+        sidx = text.find(token, sidx)
+        if sidx == -1:
+            break
+        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
+                re.match('[0-9]', text[sidx + 1]):
+            sidx += 1
+            continue
+        if text[sidx - 1] != ' ':
+            text = text[:sidx] + ' ' + text[sidx:]
+            sidx += 1
+        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
+            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
+        sidx += 1
+    return text
+
+
+def normalize(text):
+    # lower case every word
+    text = text.lower()
+
+    # replace white spaces in front and end
+    text = re.sub(r'^\s*|\s*$', '', text)
+
+    # hotel domain pfb30
+    text = re.sub(r"b&b", "bed and breakfast", text)
+    text = re.sub(r"b and b", "bed and breakfast", text)
+
+    # normalize phone number
+    ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text)
+    if ms:
+        sidx = 0
+        for m in ms:
+            sidx = text.find(m[0], sidx)
+            if text[sidx - 1] == '(':
+                sidx -= 1
+            eidx = text.find(m[-1], sidx) + len(m[-1])
+            text = text.replace(text[sidx:eidx], ''.join(m))
+
+    # normalize postcode
+    ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
+                    text)
+    if ms:
+        sidx = 0
+        for m in ms:
+            sidx = text.find(m, sidx)
+            eidx = sidx + len(m)
+            text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:]
+
+    # weird unicode bug
+    text = re.sub(u"(\u2018|\u2019)", "'", text)
+
+    # replace time and and price
+    text = re.sub(timepat, ' [value_time] ', text)
+    text = re.sub(pricepat, ' [value_price] ', text)
+
+    # replace st.
+    text = text.replace(';', ',')
+    text = re.sub('$\/', '', text)
+    text = text.replace('/', ' and ')
+
+    # replace other special characters
+    text = text.replace('-', ' ')
+    text = re.sub('[\":\<>@\(\)]', '', text)
+
+    # insert white space before and after tokens:
+    for token in ['?', '.', ',', '!']:
+        text = insertSpace(token, text)
+
+    # insert white space for 's
+    text = insertSpace('\'s', text)
+
+    # replace it's, does't, you'd ... etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+    for fromx, tox in replacements:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # concatenate numbers
+    tmp = text
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
+
+
+def prepareSlotValuesIndependent():
+    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
+    requestables = ['phone', 'address', 'postcode', 'reference', 'id']
+    dic = []
+    dic_area = []
+    dic_food = []
+    dic_price = []
+
+    # read databases
+    for domain in domains:
+        try:
+            fin = file(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + domain + '_db.json'))
+            db_json = json.load(fin)
+            fin.close()
+
+            for ent in db_json:
+                for key, val in ent.items():
+                    if val == '?' or val == 'free':
+                        pass
+                    elif key == 'address':
+                        dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        if "road" in val:
+                            val = val.replace("road", "rd")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "rd" in val:
+                            val = val.replace("rd", "road")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "st" in val:
+                            val = val.replace("st", "street")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                        elif "street" in val:
+                            val = val.replace("street", "st")
+                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
+                    elif key == 'name':
+                        dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        if "b & b" in val:
+                            val = val.replace("b & b", "bed and breakfast")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "bed and breakfast" in val:
+                            val = val.replace("bed and breakfast", "b & b")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "hotel" in val and 'gonville' not in val:
+                            val = val.replace("hotel", "")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                        elif "restaurant" in val:
+                            val = val.replace("restaurant", "")
+                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
+                    elif key == 'postcode':
+                        dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']'))
+                    elif key == 'phone':
+                        dic.append((val, '[' + domain + '_' + 'phone' + ']'))
+                    elif key == 'trainID':
+                        dic.append((normalize(val), '[' + domain + '_' + 'id' + ']'))
+                    elif key == 'department':
+                        dic.append((normalize(val), '[' + domain + '_' + 'department' + ']'))
+
+                    # NORMAL DELEX
+                    elif key == 'area':
+                        dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']'))
+                    elif key == 'food':
+                        dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']'))
+                    elif key == 'pricerange':
+                        dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']'))
+                    else:
+                        pass
+                    # TODO car type?
+        except:
+            pass
+
+        if domain == 'hospital':
+            dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']'))
+            dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']'))
+            dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']'))
+            dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']'))
+            dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']'))
+            dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']'))
+            dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']'))
+
+        elif domain == 'police':
+            dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']'))
+            dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']'))
+            dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']'))
+            dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']'))
+            dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']'))
+
+    # add at the end places from trains
+    fin = open(os.path.join(CUR_PATH.replace('latent_dialog/normalizer', ''), 'data/norm-multi-woz/' + 'train' + '_db.json'))
+    db_json = json.load(fin)
+    fin.close()
+
+    for ent in db_json:
+        for key, val in ent.items():
+            if key == 'departure' or key == 'destination':
+                dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']'))
+
+    # add specific values:
+    for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']:
+        dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']'))
+
+    # more general values add at the end
+    dic.extend(dic_area)
+    dic.extend(dic_food)
+    dic.extend(dic_price)
+
+    return dic
+
+
+def delexicalise(utt, dictionary):
+    for key, val in dictionary:
+        utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+        utt = utt[1:-1]  # why this?
+
+    return utt
+
+
+def delexicaliseReferenceNumber(sent, metadata):
+    """Based on the belief state, we can find reference number that
+    during data gathering was created randomly."""
+    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']  # , 'police']
+    if metadata:
+        for domain in domains:
+            if metadata[domain]['book']['booked']:
+                for slot in metadata[domain]['book']['booked'][0]:
+                    if slot == 'reference':
+                        val = '[' + domain + '_' + slot + ']'
+                    else:
+                        val = '[' + domain + '_' + slot + ']'
+                    key = normalize(metadata[domain]['book']['booked'][0][slot])
+                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+
+                    # try reference with hashtag
+                    key = normalize("#" + metadata[domain]['book']['booked'][0][slot])
+                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+
+                    # try reference with ref#
+                    key = normalize("ref#" + metadata[domain]['book']['booked'][0][slot])
+                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
+    return sent
+
+
+def delexicalse_num(sent):
+    # changes to numbers only here
+    digitpat = re.compile('\d+')
+    sent = re.sub(digitpat, '[value_count]', sent)
+    return sent
+
+
+def e2e_delecalise(utt, dictionary, metadata):
+    utt = normalize(utt)
+    utt = delexicalise(utt, dictionary)
+    utt = delexicaliseReferenceNumber(utt, metadata)
+    return delexicalse_num(utt)
+
+
+if __name__ == '__main__':
+    prepareSlotValuesIndependent()
diff --git a/latent_dialog/normalizer/mapping.pair b/latent_dialog/normalizer/mapping.pair
new file mode 100644
index 0000000000000000000000000000000000000000..34df41d01e93ce27039e721e1ffb55bf9267e5a2
--- /dev/null
+++ b/latent_dialog/normalizer/mapping.pair
@@ -0,0 +1,83 @@
+it's	it is
+don't	do not
+doesn't	does not
+didn't	did not
+you'd	you would
+you're	you are
+you'll	you will
+i'm	i am
+they're	they are
+that's	that is
+what's	what is
+couldn't	could not
+i've	i have
+we've	we have
+can't	cannot
+i'd	i would
+i'd	i would
+aren't	are not
+isn't	is not
+wasn't	was not
+weren't	were not
+won't	will not
+there's	there is
+there're	there are
+. .	.
+restaurants	restaurant -s
+hotels	hotel -s
+laptops	laptop -s
+cheaper	cheap -er
+dinners	dinner -s
+lunches	lunch -s
+breakfasts	breakfast -s
+expensively	expensive -ly
+moderately	moderate -ly
+cheaply	cheap -ly
+prices	price -s
+places	place -s
+venues	venue -s
+ranges	range -s
+meals	meal -s
+locations	location -s
+areas	area -s
+policies	policy -s
+children	child -s
+kids	kid -s
+kidfriendly	kid friendly
+cards	card -s
+upmarket	expensive
+inpricey	cheap
+inches	inch -s
+uses	use -s
+dimensions	dimension -s
+driverange	drive range
+includes	include -s
+computers	computer -s
+machines	machine -s
+families	family -s
+ratings	rating -s
+constraints	constraint -s
+pricerange	price range
+batteryrating	battery rating
+requirements	requirement -s
+drives	drive -s
+specifications	specification -s
+weightrange	weight range
+harddrive	hard drive
+batterylife	battery life
+businesses	business -s
+hours	hour -s
+one	1
+two	2
+three	3
+four	4
+five	5
+six	6
+seven	7
+eight	8
+nine	9
+ten	10
+eleven	11
+twelve	12
+anywhere	any where
+good bye	goodbye
diff --git a/latent_dialog/offlinerl_utils.py b/latent_dialog/offlinerl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de121dd23145faac8230dd52db96528a5934d217
--- /dev/null
+++ b/latent_dialog/offlinerl_utils.py
@@ -0,0 +1,652 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+#
+# Copyright © 2021 lubis <lubis@hilbert50>
+#
+# Distributed under terms of the MIT license.
+
+"""
+
+"""
+
+import torch as th
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import numpy as np
+import pdb
+import random
+import copy
+from collections import namedtuple, deque
+from torch.autograd import Variable
+from latent_dialog.enc2dec.encoders import RnnUttEncoder
+from latent_dialog.utils import get_detokenize, cast_type, extract_short_ctx, np2var, LONG, FLOAT
+from latent_dialog.corpora import SYS, EOS, PAD, BOS, DOMAIN_REQ_TOKEN, ACTIVE_BS_IDX, NO_MATCH_DB_IDX, REQ_TOKENS
+import dill
+import time
+
+class Actor(nn.Module):
+    def __init__(self, model, corpus, config):
+        super(Actor, self).__init__()
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.config = config
+
+        self.use_gpu = config.use_gpu
+
+        self.embedding = None
+        self.is_stochastic = config.is_stochastic
+        self.y_size = config.y_size
+        if 'k_size' in config:
+            self.k_size = config.k_size
+            self.is_gauss = False
+        else:
+            self.max_action = config.max_action if "max_action" in config else None
+            self.is_gauss = True
+
+        self.utt_encoder = copy.deepcopy(model.utt_encoder)
+        self.c2z = copy.deepcopy(model.c2z)
+        if not self.is_gauss:
+            self.gumbel_connector = copy.deepcopy(model.gumbel_connector)
+        else:
+            self.gauss_connector = copy.deepcopy(model.gauss_connector)
+            self.gaussian_logprob = model.gaussian_logprob
+            self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
+
+        
+    def forward(self, data_feed, hard=False):
+        short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        if self.is_gauss:
+            q_mu, q_logvar = self.c2z(enc_last)
+            if self.is_stochastic:
+                sample_z = self.gauss_connector(q_mu, q_logvar)
+            else:
+                sample_z = q_mu
+            logprob_sample_z = self.gaussian_logprob(q_mu, q_logvar, sample_z)
+        else:
+            logits_qy, log_qy = self.c2z(enc_last)
+            qy = F.softmax(logits_qy / 1.0, dim=1)  # (batch_size, vocab_size, )
+            log_qy = F.log_softmax(logits_qy, dim=1)  # (batch_size, vocab_size, )
+
+            if self.is_stochastic:
+                idx = th.multinomial(qy, 1).detach()
+                soft_z = self.gumbel_connector(logits_qy, hard=False)
+            else:
+                idx = th.argmax(th.exp(log_qy), dim=1, keepdim=True)
+                soft_z = th.exp(log_qy)
+            sample_z = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
+            sample_z.scatter_(1, idx, 1.0)
+            logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
+
+        joint_logpz = th.sum(logprob_sample_z, dim=1)
+        return joint_logpz, sample_z
+
+class DeterministicGaussianActor(nn.Module):
+    def __init__(self, model, corpus, config):
+        super(DeterministicGaussianActor, self).__init__()
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.config = config
+
+        self.use_gpu = config.use_gpu
+
+        self.embedding = None
+        self.y_size = config.y_size
+        self.max_action = config.max_action if "max_action" in config else None
+        self.is_gauss = True
+
+        self.utt_encoder = copy.deepcopy(model.utt_encoder)
+
+        self.policy = copy.deepcopy(model.c2z)
+
+    def forward(self, data_feed):
+        short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        mu, logvar = self.policy(enc_last)
+        z = mu
+        if self.max_action is not None:
+            z =  self.max_action * th.tanh(z)
+
+        return z, mu, logvar
+
+class StochasticGaussianActor(nn.Module):
+    def __init__(self, model, corpus, config):
+        super(StochasticGaussianActor, self).__init__()
+        self.vocab = corpus.vocab
+        self.vocab_dict = corpus.vocab_dict
+        self.vocab_size = len(self.vocab)
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.bos_id = self.vocab_dict[BOS]
+        self.eos_id = self.vocab_dict[EOS]
+        self.pad_id = self.vocab_dict[PAD]
+        self.config = config
+
+        self.use_gpu = config.use_gpu
+
+        self.embedding = None
+        self.y_size = config.y_size
+        self.max_action = config.max_action if "max_action" in config else None
+        self.is_gauss = True
+
+        self.utt_encoder = copy.deepcopy(model.utt_encoder)
+        self.policy = copy.deepcopy(model.c2z)
+        self.gauss_connector = copy.deepcopy(model.gauss_connector)
+
+    def forward(self, data_feed, n_z=1):
+        short_ctx_utts = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))
+        # create decoder initial states
+        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
+
+        q_mu, q_logvar = self.policy(enc_last)
+        if n_z > 1:
+            z = [self.gauss_connector(q_mu, q_logvar) for _ in range(n_z)]
+        else:
+            z = self.gauss_connector(q_mu, q_logvar)
+
+        return z, q_mu, q_logvar
+
+class RecurrentCritic(nn.Module):
+    def __init__(self,cvae, corpus, config, args):
+        super(RecurrentLatentCritic, self).__init__()
+
+        self.embedding = None
+        self.word_plas = args.word_plas
+        self.state_dim = cvae.utt_encoder.output_size
+        if self.word_plas:
+            self.action_dim = cvae.aux_encoder.output_size
+        else:
+            self.action_dim = config.y_size #TODO adjust for categorical
+        
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim
+        self.goal_to_critic = args.goal_to_critic
+        if self.goal_to_critic:
+            raise NotImplementedError
+
+        self.use_gpu = config.use_gpu
+
+        self.state_encoder = copy.deepcopy(cvae.utt_encoder)
+        if self.word_plas:
+            self.action_encoder = copy.deepcopy(cvae.aux_encoder)
+        else:
+            self.action_encoder = None
+
+        self.q11 = nn.Linear(self.input_dim, 500)
+        self.q12 = nn.Linear(500, 300)
+        self.q13 = nn.Linear(300, 100)
+        self.q14 = nn.Linear(100, 20)
+        self.q15 = nn.Linear(20, 1)
+
+        self.q21 = nn.Linear(self.input_dim, 500)
+        self.q22 = nn.Linear(500, 300)
+        self.q23 = nn.Linear(300, 100)
+        self.q24 = nn.Linear(100, 20)
+        self.q25 = nn.Linear(20, 1)
+
+    def forward(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1))
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+
+        q1 = self.q11(sa)
+        q1 = F.relu(self.q12(q1))
+        q1 = F.relu(self.q13(q1))
+        q1 = F.relu(self.q14(q1))
+        q1 = self.q15(q1)
+
+
+        q2 = self.q21(sa)
+        q2 = F.relu(self.q22(q2))
+        q2 = F.relu(self.q23(q2))
+        q2 = F.relu(self.q24(q2))
+        q2 = self.q25(q2)
+
+        return q1, q2
+
+    def q1(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1))
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(0)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+        
+        q1 = self.q11(sa)
+        q1 = F.relu(self.q12(q1))
+        q1 = F.relu(self.q13(q1))
+        q1 = F.relu(self.q14(q1))
+        q1 = self.q15(q1)
+
+        return q1
+
+
+        if self.goal_to_critic:
+            try:
+                goals = np2var(data_feed['goals'], FLOAT, self.use_gpu)
+            except KeyError:
+                goals = []
+                for turn_id in range(len(ctx_summary)):
+                    goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)]))
+                goals = np2var(np.asarray(goals), FLOAT, self.use_gpu)
+
+        #OPTION 1 add goal to encoder for each time step
+        if self.goal_to_critic and self.add_goal=="early":
+            sa = th.cat([sa, goals], dim = 1)
+
+        output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1)))
+
+        #OPTION 2 add goal combined with hidden state to predict final score
+        if self.goal_to_critic and self.add_goal=="late":
+            output = th.cat([output, goals.unsqueeze(1)], dim = 2)
+
+        q1 = self.q11(output.squeeze(1))
+
+        if self.activation_function == "relu":
+            q1 = F.relu(q1)
+        elif self.activation_function == "sigmoid":
+            q1 = th.sigmoid(q1)
+        elif self.activation_function == "tanh":
+            q1 = F.tanh(q1)
+
+        return q1
+    
+class SingleHierarchicalRecurrentCritic(nn.Module):
+    def __init__(self, cvae, corpus, config, args):
+        super(SingleHierarchicalRecurrentCritic, self).__init__()
+
+        self.hidden_size = 500
+        self.args = args
+        if "model_path" in args:
+            _path = args.model_path
+        else:
+            _path = args.sv_model_path
+
+        if "gauss" in _path:
+            self.is_gauss = True
+        else:
+            self.is_gauss = False
+        self.embedding = None
+        self.word_plas = args.word_plas
+        self.state_dim = cvae.utt_encoder.output_size
+        if self.word_plas:
+            try:
+                self.action_dim = cvae.aux_encoder.output_size
+            except:
+                self.action_dim = cvae.utt_encoder.output_size
+        else:
+            if self.is_gauss:
+                self.action_dim = config.y_size 
+            else:
+                if args.embed_z_for_critic:
+                    self.embed_z = True
+                    self.action_dim = config.dec_cell_size
+                else:
+                    self.embed_z = False
+                    self.action_dim = config.y_size * config.k_size
+
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim
+
+        self.goal_to_critic = args.goal_to_critic
+        self.add_goal = args.add_goal
+        if self.goal_to_critic:
+            self.goal_size = corpus.goal_size
+            if self.add_goal == "early":
+                self.input_dim += self.goal_size
+
+
+        self.use_gpu = config.use_gpu
+
+        self.state_encoder = copy.deepcopy(cvae.utt_encoder)
+        if self.word_plas:
+            try:
+                self.action_encoder = copy.deepcopy(cvae.aux_encoder)
+            except:
+                self.action_encoder = copy.deepcopy(cvae.utt_encoder)
+        else:
+            self.action_encoder = None
+
+        self.dialogue_encoder = nn.LSTM(
+                input_size = self.input_dim,
+                hidden_size = self.hidden_size,
+                dropout=0.1
+                )
+
+        if self.add_goal=="late":
+            self.q11 = nn.Linear(self.hidden_size + self.goal_size, 1)
+        else:
+            self.q11 = nn.Linear(self.hidden_size, 1)
+        self.activation_function = args.critic_actf if "critic_actf" in args else "none"
+
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False)
+        else:
+            self.d = th.nn.Dropout(p=0.0, inplace=False)
+
+
+    def forward(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1)) # takes about 0.006seconds
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+
+        if self.goal_to_critic:
+            try:
+                goals = np2var(data_feed['goals'], FLOAT, self.use_gpu)
+            except KeyError:
+                goals = []
+                for turn_id in range(len(ctx_summary)):
+                    goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)]))
+                goals = np2var(np.asarray(goals), FLOAT, self.use_gpu)
+
+        #OPTION 1 add goal to encoder for each time step
+        if self.goal_to_critic and self.add_goal=="early":
+            sa = th.cat([sa, goals], dim = 1)
+
+        output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1)))
+
+        #OPTION 2 add goal combined with hidden state to predict final score
+        if self.goal_to_critic and self.add_goal=="late":
+            output = th.cat([output, goals.unsqueeze(1)], dim = 2)
+
+        q1 = self.q11(output.squeeze(1))
+
+        if self.activation_function == "relu":
+            q1 = F.relu(q1)
+        elif self.activation_function == "sigmoid":
+            q1 = th.sigmoid(q1)
+        elif self.activation_function == "tanh":
+            q1 = F.tanh(q1)
+
+        return q1
+    
+    def forward_target(self, data_feed, act, corpus_act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        q1s =[]
+        for i in range(bs_label.shape[0]):
+            if self.word_plas:
+                if i > 0:
+                    corpus_resp_summary, _, _ = self.action_encoder(corpus_act[:i].unsqueeze(1))
+                else:
+                    corpus_resp_summary = th.Tensor([])
+                    if self.args.use_gpu:
+                        corpus_resp_summary = corpus_resp_summary.cuda()
+                actor_resp_summary, _, _ = self.action_encoder(act[i].unsqueeze(1))
+                sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_resp_summary[:i], actor_resp_summary], dim=0).squeeze(1)], dim=1)
+            else:
+                sa = th.cat([ctx_summary[:i+1].squeeze(1), bs_label[:i+1], db_label[:i+1], th.cat([corpus_act[:i], act[i].unsqueeze(0)], dim=0)], dim=1)
+
+            if self.goal_to_critic:
+                try:
+                    goals = np2var(data_feed['goals'][:i+1], FLOAT, self.use_gpu)
+                except KeyError:
+                    goals = []
+                    for turn_id in range(i+1):
+                        goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)]))
+                    goals = np2var(np.asarray(goals), FLOAT, self.use_gpu)
+
+            #OPTION 1 add goal to encoder for each time step
+            if self.goal_to_critic and self.add_goal=="early":
+                sa = th.cat([sa, goals], dim = 1)
+
+            output, (hn, cn) = self.dialogue_encoder(self.d(sa.unsqueeze(1)))
+
+            #OPTION 2 add goal combined with hidden state to predict final score
+            if self.goal_to_critic and self.add_goal=="late":
+                output = th.cat([output, goals.unsqueeze(1)], dim = 2)
+
+            q1 = self.q11(output.squeeze(1))
+
+            if self.activation_function == "relu":
+                q1 = F.relu(q1)
+            elif self.activation_function == "sigmoid":
+                q1 = th.sigmoid(q1)
+            elif self.activation_function == "tanh":
+                q1 = F.tanh(q1)
+
+            q1s.append(q1[-1])
+
+        return th.cat(q1s, dim=0).unsqueeze(1)
+
+class SingleRecurrentCritic(nn.Module):
+    def __init__(self, cvae, corpus, config, args):
+        super(SingleRecurrentCritic, self).__init__()
+
+        if "gauss" in args.saved_path:
+            self.is_gauss = True
+        else:
+            self.is_gauss = False
+        self.embedding = None
+        self.word_plas = args.word_plas
+        self.state_dim = cvae.utt_encoder.output_size
+        if self.word_plas:
+            self.action_dim = cvae.aux_encoder.output_size
+        else:
+            if self.is_gauss:
+                self.action_dim = config.y_size 
+            else:
+                if args.embed_z_for_critic:
+                    self.action_dim = config.dec_cell_size
+                else:
+                    self.action_dim = config.y_size * config.k_size
+
+        self.bs_size = corpus.bs_size
+        self.db_size = corpus.db_size
+        self.input_dim = self.state_dim + self.bs_size + self.db_size + self.action_dim
+
+        self.goal_to_critic = args.goal_to_critic
+        if self.goal_to_critic:
+            self.goal_size = corpus.goal_size
+            self.input_dim += self.goal_size
+
+
+        self.use_gpu = config.use_gpu
+
+        self.state_encoder = copy.deepcopy(cvae.utt_encoder)
+        if self.word_plas:
+            self.action_encoder = copy.deepcopy(cvae.aux_encoder)
+        else:
+            self.action_encoder = None
+
+        self.q11 = nn.Linear(self.input_dim, 1)
+        self.activation_function = args.critic_actf if "critic_actf" in args else "none"
+
+        self.critic_dropout = args.critic_dropout
+        if self.critic_dropout:
+            self.d = th.nn.Dropout(p=args.critic_dropout_rate, inplace=False)
+        else:
+            self.d = th.nn.Dropout(p=0.0, inplace=False)
+
+    def forward(self, data_feed, act):
+        ctx = np2var(extract_short_ctx(data_feed['contexts'], data_feed['context_lens']), LONG, self.use_gpu)
+        bs_label = np2var(data_feed['bs'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+        db_label = np2var(data_feed['db'], FLOAT, self.use_gpu)  # (batch_size, max_ctx_len, max_utt_len)
+
+        ctx_summary, _, _ = self.state_encoder(ctx.unsqueeze(1))
+        if self.word_plas:
+            resp_summary, _, _ = self.action_encoder(act.unsqueeze(1))
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, resp_summary.squeeze(1)], dim=1)
+        else:
+            sa = th.cat([ctx_summary.squeeze(1), bs_label, db_label, act], dim=1)
+
+        if self.goal_to_critic:
+            try:
+                goals = np2var(data_feed['goals'], FLOAT, self.use_gpu)
+            except KeyError:
+                goals = []
+                for turn_id in range(len(ctx_summary)):
+                    goals.append(np.concatenate([data_feed['goals_list'][d][turn_id] for d in range(7)]))
+                goals = np2var(np.asarray(goals), FLOAT, self.use_gpu)
+            sa = th.cat([sa, goals], dim = 1)
+
+        q1 = self.q11(self.d(sa))
+
+        if self.activation_function == "relu":
+            q1 = F.relu(q1)
+        elif self.activation_function == "sigmoid":
+            q1 = th.sigmoid(q1)
+
+        return q1
+
+class ReplayBuffer(object):
+    """
+    Buffer to store experiences, to be used in off-policy learning
+    """
+    def __init__(self, config): 
+
+        self.batch_size = config.batch_size
+        self.fix_episode = config.fix_episode
+        
+        self.experiences = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "next_action", "done", "Return"])
+        self.memory = deque()
+        self.seed = random.seed(config.random_seed)
+
+    def add(self, states, actions, rewards, next_states, next_actions, dones, Returns):
+        if self.fix_episode:
+            self._add_episode(states, actions, rewards, next_states, next_actions, dones, Returns)
+        else:
+            for i in range(len(states)):
+                self._add(states[i], actions[i], rewards[i], next_states[i], next_actions[i], dones[i], Returns[i])
+
+    def _add(self, state, action, reward, next_state, next_action, done, Return):
+        e = self.experiences(state, action, reward, next_state, next_action, done, Return)
+        self.memory.append(e)
+
+    def _add_episode(self, states, actions, rewards, next_states, next_actions, dones, Returns):
+        ep = []
+        for s, a, r, s_, a_, d, R in zip(states, actions, rewards, next_states, next_actions, dones, Returns):
+            ep.append(self.experiences(s, a, r, s_, a_, d, R))
+        self.memory.append(ep)
+
+
+    def sample(self):
+        if self.fix_episode:
+            return self._sample_episode()
+        else:
+            return self._sample()
+
+    def _sample(self):
+        experiences = random.sample(self.memory, k = self.batch_size)
+
+        states = {}
+        states['contexts'] = np.asarray([e.state['contexts'] for e in experiences])
+        states['bs'] = np.asarray([e.state['bs'] for e in experiences])
+        states['db'] = np.asarray([e.state['db'] for e in experiences])
+        states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) 
+        states['goals'] = np.asarray([e.state['goals'] for e in experiences]) 
+
+        actions = np.asarray([e.action for e in experiences if e is not None])
+        rewards = np.asarray([e.reward for e in experiences if e is not None])
+        
+        next_states = {}
+        next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences])
+        next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences])
+        next_states['db'] = np.asarray([e.next_state['db'] for e in experiences])
+        next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) 
+        next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences])
+        
+        next_actions = np.asarray([e.next_action for e in experiences if e is not None])
+
+        dones = np.asarray([e.done for e in experiences if e is not None])
+        returns = np.asarray([e.Return for e in experiences if e is not None])
+
+        return (states, actions, rewards, next_states, next_actions, dones, returns)
+    
+    def _sample_episode(self):
+        episodes = random.sample(self.memory, k = 1)
+
+        for experiences in episodes:
+            states = {}
+            states['contexts'] = np.asarray([e.state['contexts'] for e in experiences])
+            states['bs'] = np.asarray([e.state['bs'] for e in experiences])
+            states['db'] = np.asarray([e.state['db'] for e in experiences])
+            states['keys'] = [e.state['keys'] for e in experiences]
+            states['context_lens'] = np.asarray([e.state['context_lens'] for e in experiences]) 
+            states['goals'] = np.asarray([e.state['goals'] for e in experiences]) 
+
+            actions = np.asarray([e.action for e in experiences if e is not None])
+            rewards = np.asarray([e.reward for e in experiences if e is not None])
+            
+            next_states = {}
+            next_states['contexts'] = np.asarray([e.next_state['contexts'] for e in experiences])
+            next_states['bs'] = np.asarray([e.next_state['bs'] for e in experiences])
+            next_states['db'] = np.asarray([e.next_state['db'] for e in experiences])
+            next_states['keys'] = [e.next_state['keys'] for e in experiences]
+            next_states['context_lens'] = np.asarray([e.next_state['context_lens'] for e in experiences]) 
+            next_states['goals'] = np.asarray([e.next_state['goals'] for e in experiences])
+            
+            next_actions = np.asarray([e.next_action for e in experiences if e is not None])
+
+            dones = np.asarray([e.done for e in experiences if e is not None])
+            returns = np.asarray([e.Return for e in experiences if e is not None])
+
+        return (states, actions, rewards, next_states, next_actions, dones, returns)
+
+    def __len__(self):
+        return len(self.memory)
+
+    def save(self, path):
+        with open(path, 'wb') as f:
+            dill.dump(self.memory, f)
+
+    def load(self, path):
+        with open(path, 'rb') as f:
+            self.memory = dill.load(f)
+
+    def load_add(self, path):
+        with open(path, 'rb') as f:
+            self.memory += dill.load(f)
+
diff --git a/latent_dialog/record.py b/latent_dialog/record.py
new file mode 100644
index 0000000000000000000000000000000000000000..613d34ef78e3c0e7ffcc9977ecfa50bf2c347e21
--- /dev/null
+++ b/latent_dialog/record.py
@@ -0,0 +1,169 @@
+import numpy as np
+from latent_dialog.enc2dec.decoders import TEACH_FORCE, GEN, DecoderRNN, GEN_VALID
+from collections import Counter
+
+
+class UniquenessSentMetric(object):
+    """Metric that evaluates the number of unique sentences."""
+    def __init__(self):
+        self.seen = set()
+        self.all_sents = []
+
+    def record(self, sen):
+        self.seen.add(' '.join(sen))
+        self.all_sents.append(' '.join(sen))
+
+    def value(self):
+        return len(self.seen)
+
+    def top_n(self, n):
+        return Counter(self.all_sents).most_common(n)
+
+
+class UniquenessWordMetric(object):
+    """Metric that evaluates the number of unique sentences."""
+    def __init__(self):
+        self.seen = set()
+
+    def record(self, word_list):
+        self.seen.update(word_list)
+
+    def value(self):
+        return len(self.seen)
+
+
+def record_task(n_epsd, model, val_data, config, ppl_f, dialog, ctx_gen_eval, rl_f):
+    record_ppl(n_epsd, model, val_data, config, ppl_f)
+    record_rl_task(n_epsd, dialog, ctx_gen_eval, rl_f)
+
+
+def record(n_epsd, model, val_data, sv_config, lm_model, ppl_f, dialog, ctx_gen_eval, rl_f):
+    record_ppl_with_lm(n_epsd, model, val_data, sv_config, lm_model, ppl_f)
+    record_rl(n_epsd, dialog, ctx_gen_eval, rl_f)
+
+
+def record_ppl_with_lm(n_epsd, model, data, config, lm_model, ppl_f):
+    model.eval()
+    loss_list = []
+    data.epoch_init(config, shuffle=False, verbose=True)
+    while True:
+        batch = data.next_batch()
+        if batch is None:
+            break
+        for i in range(1):
+            loss = model(batch, mode=TEACH_FORCE, use_py=True)
+            loss_list.append(loss.nll.item())
+
+    # USE LM to test generation performance
+    data.epoch_init(config, shuffle=False, verbose=False)
+    gen_loss_list = []
+    # first generate
+    while True:
+        batch = data.next_batch()
+        if batch is None:
+            break
+
+        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
+        # move from GPU to CPU
+        labels = labels.cpu()
+        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
+        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
+        # clean up the pred labels
+        clean_pred_labels = np.zeros((pred_labels.shape[0], pred_labels.shape[1]+1))
+        clean_pred_labels[:, 0] = model.sys_id
+        for b_id in range(pred_labels.shape[0]):
+            for t_id in range(pred_labels.shape[1]):
+                token = pred_labels[b_id, t_id]
+                clean_pred_labels[b_id, t_id + 1] = token
+                if token in [model.eos_id] or t_id == pred_labels.shape[1]-1:
+                    break
+
+        pred_out_lens = np.sum(np.sign(clean_pred_labels), axis=1)
+        max_pred_lens = np.max(pred_out_lens)
+        clean_pred_labels = clean_pred_labels[:, 0:int(max_pred_lens)]
+        batch['outputs'] = clean_pred_labels
+        batch['output_lens'] = pred_out_lens
+        loss = lm_model(batch, mode=TEACH_FORCE)
+        gen_loss_list.append(loss.nll.item())
+
+    avg_loss = np.average(loss_list)
+    avg_ppl = np.exp(avg_loss)
+    gen_avg_loss = np.average(gen_loss_list)
+    gen_avg_ppl = np.exp(gen_avg_loss)
+
+    ppl_f.write('{}\t{}\t{}\n'.format(n_epsd, avg_ppl, gen_avg_ppl))
+    ppl_f.flush()
+    model.train()
+
+
+def record_ppl(n_epsd, model, val_data, config, ppl_f):
+    model.eval()
+    loss_list = []
+    val_data.epoch_init(config, shuffle=False, verbose=True)
+    while True:
+        batch = val_data.next_batch()
+        if batch is None:
+            break
+        loss = model(batch, mode=TEACH_FORCE, use_py=True)
+        loss_list.append(loss.nll.item())
+    aver_loss = np.average(loss_list)
+    aver_ppl = np.exp(aver_loss)
+    ppl_f.write('{}\t{}\n'.format(n_epsd, aver_ppl))
+    ppl_f.flush()
+    model.train()
+
+
+def record_rl(n_epsd, dialog, ctx_gen, rl_f):
+    conv_list = []
+    reward_list = []
+    agree_list = []
+    sent_metric = UniquenessSentMetric()
+    word_metric = UniquenessWordMetric()
+
+    for ctxs in ctx_gen.ctxs:
+        conv, agree, rewards = dialog.run(ctxs)
+        true_reward = rewards[0] if agree else 0
+        reward_list.append(true_reward)
+        conv_list.append(conv)
+        agree_list.append(float(agree) if agree is not None else 0.0)
+        for turn in conv:
+            if turn[0] == 'Elder':
+                sent_metric.record(turn[1])
+                word_metric.record(turn[1])
+
+    # json.dump(conv_list, text_f, indent=4)
+    aver_reward = np.average(reward_list)
+    aver_agree = np.average(agree_list)
+    unique_sent_num = sent_metric.value()
+    unique_word_num = word_metric.value()
+    print(sent_metric.top_n(10))
+
+    rl_f.write('{}\t{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, aver_agree, unique_sent_num, unique_word_num))
+    rl_f.flush()
+
+
+def record_rl_task(n_epsd, dialog, goal_gen, rl_f):
+    conv_list = []
+    reward_list = []
+    sent_metric = UniquenessSentMetric()
+    word_metric = UniquenessWordMetric()
+    print("Begin RL testing")
+    cnt = 0
+    for g_key, goal in goal_gen.iter(1):
+        cnt += 1
+        conv, success = dialog.run(g_key, goal)
+        true_reward = success
+        reward_list.append(true_reward)
+        conv_list.append(conv)
+        for turn in conv:
+            if turn[0] == 'Elder':
+                sent_metric.record(turn[1])
+                word_metric.record(turn[1])
+
+    # json.dump(conv_list, text_f, indent=4)
+    aver_reward = np.average(reward_list)
+    unique_sent_num = sent_metric.value()
+    unique_word_num = word_metric.value()
+    rl_f.write('{}\t{}\t{}\t{}\n'.format(n_epsd, aver_reward, unique_sent_num, unique_word_num))
+    rl_f.flush()
+    print("End RL testing")
\ No newline at end of file
diff --git a/latent_dialog/utils.py b/latent_dialog/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6244f87538e4399b70d497dd4f18c28b878c0b91
--- /dev/null
+++ b/latent_dialog/utils.py
@@ -0,0 +1,140 @@
+import os
+import numpy as np
+import random
+import torch as th
+from torch.autograd import Variable
+from nltk import RegexpTokenizer
+from nltk.tokenize.treebank import TreebankWordDetokenizer
+import logging
+import sys
+from collections import defaultdict
+
+INT = 0
+LONG = 1
+FLOAT = 2
+
+
+class Pack(dict):
+    def __getattr__(self, name):
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def add(self, **kwargs):
+        for k, v in kwargs.items():
+            self[k] = v
+
+    def copy(self):
+        pack = Pack()
+        for k, v in self.items():
+            if type(v) is list:
+                pack[k] = list(v)
+            else:
+                pack[k] = v
+        return pack
+
+    @staticmethod
+    def msg_from_dict(dictionary, tokenize, speaker2id, bos_id, eos_id, include_domain=False):
+        pack = Pack()
+        for k, v in dictionary.items():
+            pack[k] = v
+        pack['speaker'] = speaker2id[pack.speaker]
+        pack['conf'] = dictionary.get('conf', 1.0)
+        utt = pack['utt']
+        if 'QUERY' in utt or "RET" in utt:
+            utt = str(utt)
+            # utt = utt.translate(None, ''.join([':', '"', "{", "}", "]", "["]))
+            utt = utt.translate(str.maketrans('', '', ''.join([':', '"', "{", "}", "]", "["])))
+            utt = str(utt)
+        if include_domain:
+            pack['utt'] = [bos_id, pack['speaker'], pack['domain']] + tokenize(utt) + [eos_id]
+        else:
+            pack['utt'] = [bos_id, pack['speaker']] + tokenize(utt) + [eos_id]
+        return pack
+
+def get_tokenize():
+    return RegexpTokenizer(r'\w+|#\w+|<\w+>|%\w+|[^\w\s]+').tokenize
+
+def get_detokenize():
+    return lambda x: TreebankWordDetokenizer().detokenize(x)
+
+def cast_type(var, dtype, use_gpu):
+    if use_gpu:
+        if dtype == INT:
+            var = var.type(th.cuda.IntTensor)
+        elif dtype == LONG:
+            var = var.type(th.cuda.LongTensor)
+        elif dtype == FLOAT:
+            var = var.type(th.cuda.FloatTensor)
+        else:
+            raise ValueError('Unknown dtype')
+    else:
+        if dtype == INT:
+            var = var.type(th.IntTensor)
+        elif dtype == LONG:
+            var = var.type(th.LongTensor)
+        elif dtype == FLOAT:
+            var = var.type(th.FloatTensor)
+        else:
+            raise ValueError('Unknown dtype')
+    return var
+
+def read_lines(file_name):
+    """Reads all the lines from the file."""
+    assert os.path.exists(file_name), 'file does not exists %s' % file_name
+    lines = []
+    with open(file_name, 'r') as f:
+        for line in f:
+            lines.append(line.strip())
+    return lines
+
+def set_seed(seed):
+    """Sets random seed everywhere."""
+    random.seed(seed)
+    np.random.seed(seed)
+    th.manual_seed(seed)
+    if th.cuda.is_available():
+        th.cuda.manual_seed(seed)
+        th.backends.cudnn.enabled = False
+        th.backends.cudnn.benchmark = False
+        th.backends.cudnn.deterministic = True
+
+def prepare_dirs_loggers(config, script=""):
+    logFormatter = logging.Formatter("%(message)s")
+    rootLogger = logging.getLogger()
+    rootLogger.setLevel(logging.DEBUG)
+
+    consoleHandler = logging.StreamHandler(sys.stdout)
+    consoleHandler.setLevel(logging.DEBUG)
+    consoleHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(consoleHandler)
+
+    # if hasattr(config, 'forward_only') and config.forward_only:
+    if 'forward_only' in config and config.forward_only:
+        return
+
+    fileHandler = logging.FileHandler(os.path.join(config.saved_path,'session.log'))
+    fileHandler.setLevel(logging.DEBUG)
+    fileHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(fileHandler)
+
+def get_chat_tokenize():
+    return nltk.RegexpTokenizer(r'\w+|<sil>|[^\w\s]+').tokenize
+
+class missingdict(defaultdict):
+    def __missing__(self, key):
+        return self.default_factory()
+
+def extract_short_ctx(context, context_lens, backward_size=1):
+    utts = []
+    for b_id in range(context.shape[0]):
+        utts.append(context[b_id, context_lens[b_id]-1])
+    return np.array(utts)
+
+def np2var(inputs, dtype, use_gpu):
+    if inputs is None:
+        return None
+    return cast_type(Variable(th.from_numpy(inputs)), 
+                     dtype, 
+                     use_gpu)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e4ea15340c736d5766f80318768080feffabfa68
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+dill==0.3.4
+matplotlib==3.5.1
+nltk==3.5
+numpy==1.19.5
+requests==2.28.1
+scikit_learn==1.1.2
+scipy==1.9.1
+tabulate==0.8.9
+torch==1.8.2+cu111
+tqdm==4.57.0