Skip to content
Snippets Groups Projects
Select Git revision
  • 97a04f8cbfa408e6e802488a357f8f9b26f01b29
  • master default protected
2 results

MovingParticles.mch

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    critic_json.py 13.22 KiB
    import time
    import os
    import sys
    import random
    sys.path.append('../')
    import json
    import torch as th
    import pdb
    from tqdm import trange
    from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
    from latent_dialog.corpora import NormMultiWozCorpus
    from latent_dialog.models_task import *
    from latent_dialog.agent_task import LatentCriticAgent, CriticAgent
    from latent_dialog.main import OfflineCritic
    from latent_dialog.evaluators import MultiWozEvaluator
    from latent_dialog.data_loaders import BeliefDbDataLoaders, BeliefDbDataLoadersAE
    from experiments_woz.dialog_utils import task_generate_critic, task_generate, task_run_critic
    from argparse import ArgumentParser
    
    
    def main(seed, pretrained_folder, pretrained_model_id, response_path):
        """Train an offline critic on top of a pretrained dialog system model.

        Steps, in order: build the critic configuration and experiment
        directory, load the pretrained model's config and the corpus, write the
        critic config to disk, restore the pretrained system model (and a VAE
        when ``raw_response`` is set and ``word_plas`` is not), fill the
        train/valid/test replay buffers (loaded from cached ``.dill`` files when
        they exist, otherwise extracted and saved), write the output of
        ``task_run_critic`` on the test data to ``test_performance_start.txt``,
        call ``OfflineCritic.run()``, and write ``test_performance_end.txt``.

        Args:
            seed: Value stored as ``random_seed`` in the critic config. When
                ``None``, the seed is taken from the pretrained model's config
                (``seed``, falling back to ``random_seed``).
            pretrained_folder: Folder below ``sys_config_log_model`` holding the
                pretrained model. Substrings of this name ("rl", "actz", "mt",
                "gauss") select the config location, the checkpoint file-name
                separator and the system model class.
            pretrained_model_id: Checkpoint identifier; the checkpoint file is
                ``<id>.model`` for "rl" folders and ``<id>-model`` otherwise.
            response_path: Stored in the critic config as ``response_path``;
                the name of its parent directory is used to build the
                experiment directory.
        """
        start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        print('[START]', start_time, '='*30)
        # RL configuration
        env = 'gpu'

        # [-2:-1] keeps only the parent directory name of response_path.
        exp_dir = os.path.join('sys_config_log_model', "/".join(response_path.split("/")[-2:-1]).replace(".json", ""), "critic-"+start_time)

        is_rl = "rl" in pretrained_folder
        is_gauss = "gauss" in pretrained_folder

        if is_rl:
            # For "rl" folders the config is read from the parent of the given folder.
            join_fmt = "."
            config_path = os.path.join('sys_config_log_model', "/".join(pretrained_folder.split("/")[:-1]), "config.json")
        else:
            join_fmt = "-"
            config_path = os.path.join('sys_config_log_model', pretrained_folder, "config.json")

        # create exp folder (exist_ok avoids the check-then-create race)
        os.makedirs(exp_dir, exist_ok=True)

        critic_config = Pack(
            config_path = config_path,
            model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}{}model'.format(pretrained_model_id, join_fmt)), # used for encoder initialization for critic
            vae_config_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/config.json", # needed if raw_response=True and word_plas=False
            vae_model_path = "sys_config_log_model/2021-11-25-19-11-40-sl_gauss_ae/98-model",
            actor_path = None,
            critic_config_path = os.path.join(exp_dir, 'critic_config.json'),
            critic_model_path = os.path.join(exp_dir, 'critic_model'),
            saved_path = exp_dir,
            # ppl_best_model_path = os.path.join(exp_dir, 'ppl_best.model'),
            # reward_best_model_path = os.path.join(exp_dir, 'reward_best.model'),
            record_path = exp_dir,
            record_freq = 500,
            sv_train_freq= 0,  # TODO pay attention to main.py, cuz it is also controlled there
            use_gpu = env == 'gpu',
            nepoch = 1,
            nepisode = 1000,
            word_plas=True,
            raw_response=True,
            response_path=response_path,
            fix_episode=True,
            train_with_pseudotraj=False,
            train_with_full_data=False,
            reward_type="default", #default, turnPenalty, or infoGain
            infoGain_threshhold = 0.2, # only if reward_type is infoGain
            add_match_to_reward=False,
            soft_success=False,
            goal_to_critic=True,
            add_goal="early", #early or late. only when goal_to_critic=True and fix_episode=True
            critic_kl_loss=False,
            critic_kl_alpha=0.1,
            critic_dropout=True, # use with regularize_critic=False
            critic_dropout_rate=0.3, # only if dropout is true
            critic_dropout_agg="min", #avg or min
            critic_sample=False,
            critic_transformer=False,
            critic_actf="sigmoid", #relu or sigmoid or tanh or none
            embed_z_for_critic=True, #for categorical action only, when word_plas=False
            # critic_maxq=1, # only when actf=tanh or sigmoid, use with reward_type other than default
            critic_loss="mse", #mse or huber
            critic_rl_lr = 0.01,
            decay_critic_lr=False,
            train_vae=False,
            train_vae_freq=10001, # if larger than nepisode. only train once at the beginning
            train_vae_nepisode=0, # nepisode for the rest of VAE training
            train_vae_nepisode_init=10000, # nepisode for first VAE training
            weighted_vae_nll=True,
            rl_lr = 0.01,
            max_words = 50,
            temperature=1.0,
            momentum = 0.0,
            nesterov = False,
            gamma = 0.99,
            tau=0.005,
            lmbda=0.5,
            beta=0.001,
            batch_size=16,
            rl_clip = 1.0,
            n_z=10,
            random_seed = seed,
            policy_dropout=False,
            dropout_on_eval=False,
            fail_info_penalty=False
        )

        prepare_dirs_loggers(critic_config)

        # list config keys that are being compared for tensorboard naming
        tb_keys = ["critic_rl_lr", "reward_type", "critic_actf", "train_with_pseudotraj"]
        tensorboard_name = exp_dir.replace("sys_config_log_model/", "") + "-critic-" + "-".join([f"{k}={critic_config[k]}" for k in tb_keys])

        # load previous supervised learning configuration and corpus
        with open(critic_config.config_path) as f:
            config = Pack(json.load(f))
        config['dropout'] = 0.0
        config['use_gpu'] = critic_config.use_gpu
        config['policy_dropout'] = critic_config.policy_dropout
        config['dropout_on_eval'] = critic_config.dropout_on_eval
        # assert config.train_path == critic_config.train_path

        # set random seed
        if critic_config.random_seed is None:
            # The pretrained config may store the seed under either name.
            try:
                critic_config.random_seed = config.seed
            except (AttributeError, KeyError):
                critic_config.random_seed = config.random_seed
        set_seed(critic_config.random_seed)

        splits = ("train", "valid", "test")

        try:
            corpus = NormMultiWozCorpus(config)
        except FileNotFoundError:
            # Retry with the "/home/lubis" prefix stripped from the data paths.
            for split in splits:
                path_key = f"{split}_path"
                config[path_key] = config[path_key].replace("/home/lubis", "")
            corpus = NormMultiWozCorpus(config)

        for split in splits:
            critic_config[f"{split}_path"] = config[f"{split}_path"]

        # Replay-buffer cache files live next to the data files; the file name
        # encodes every option that changes the buffer content.
        if critic_config.reward_type == "default":
            memory_suffix = ".dill"
        else:
            memory_suffix = f"_{critic_config.reward_type}-{critic_config.infoGain_threshhold}.dill"

        memory_tags = []
        if critic_config.fix_episode:
            memory_tags.append("-ep.dill")
        if critic_config.soft_success:
            memory_tags.append("-soft.dill")
        if critic_config.add_match_to_reward:
            memory_tags.append("-wMatch.dill")

        for split in splits:
            memory_path = config[f"{split}_path"].replace(".json", memory_suffix)
            for tag in memory_tags:
                memory_path = memory_path.replace(".dill", tag)
            critic_config[f"{split}_memory_path"] = memory_path

        critic_config['y_size'] = config['y_size']

        # save configuration
        with open(critic_config.critic_config_path, 'w') as f:
            json.dump(critic_config, f, indent=4)

        # Pick the system model class from the folder name. "rl" folders always
        # use the SysPerfectBD2* classes, regardless of "actz"/"mt" in the name.
        if not is_rl and "actz" in pretrained_folder:
            sys_model_cls = SysActZGauss if is_gauss else SysActZCat
        elif not is_rl and "mt" in pretrained_folder:
            sys_model_cls = SysMTGauss if is_gauss else SysMTCat
        else:
            sys_model_cls = SysPerfectBD2Gauss if is_gauss else SysPerfectBD2Cat
        sys_model = sys_model_cls(corpus, config)

        # The VAE config is always loaded: it is also passed to the data loaders below.
        with open(critic_config.vae_config_path) as f:
            vae_config = Pack(json.load(f))
        if critic_config.raw_response and not critic_config.word_plas:
            if "gauss" in critic_config.vae_model_path:
                vae_model = SysAEGauss(corpus, vae_config)
            else:
                vae_model = SysAECat(corpus, vae_config)
            # map_location keeps the loaded tensors on CPU storage.
            vae_model_dict = th.load(critic_config.vae_model_path, map_location=lambda storage, location: storage)
            vae_model.load_state_dict(vae_model_dict)
        else:
            vae_model = None

        if config.use_gpu:
            sys_model.cuda()
            if vae_model is not None:
                vae_model.cuda()

        sys_model_dict = th.load(critic_config.model_path, map_location=lambda storage, location: storage)
        sys_model.load_state_dict(sys_model_dict)

        sys_model.eval()
        evaluator = MultiWozEvaluator('SysWoz', config)

        if critic_config.word_plas:
            agent = CriticAgent(sys_model, corpus, critic_config, evaluator, name='System')
        else:
            agent = LatentCriticAgent(sys_model, corpus, critic_config, evaluator, name='System', vae=vae_model)

        # Named "trainer" so the local does not shadow this function's own name.
        trainer = OfflineCritic(agent, corpus, config, critic_config, task_run_critic, name=tensorboard_name, vae_gen=task_generate)
        # save sys model
        # th.save(sys_model.state_dict(), critic_config.rl_model_path)

        # initialize train buffer
        if os.path.isfile(critic_config.train_memory_path):
            print("Loading replay buffer for training from {}".format(critic_config.train_memory_path))
            agent.train_buffer.load(critic_config.train_memory_path)
            print(len(agent.train_buffer))
            if "train_with_full_data" in critic_config and critic_config.train_with_full_data:
                print(f"adding buffer {critic_config.valid_memory_path} to training buffer")
                agent.train_buffer.load_add(critic_config.valid_memory_path)
                print(len(agent.train_buffer))
                print(f"adding buffer {critic_config.test_memory_path} to training buffer")
                agent.train_buffer.load_add(critic_config.test_memory_path)
                print(len(agent.train_buffer))
        else:
            print("Extracting experiences from training data")
            trainer.extract(trainer.train_data, trainer.agent.train_buffer)
            print("Saving experiences to {}".format(critic_config.train_memory_path))
            agent.train_buffer.save(critic_config.train_memory_path)

        # initialize valid buffer
        if os.path.isfile(critic_config.valid_memory_path):
            print("Loading replay buffer for validation from {}".format(critic_config.valid_memory_path))
            agent.valid_buffer.load(critic_config.valid_memory_path)
        else:
            print("Extracting experiences from valid data")
            trainer.extract(trainer.val_data, trainer.agent.valid_buffer)
            print("Saving experiences to {}".format(critic_config.valid_memory_path))
            agent.valid_buffer.save(critic_config.valid_memory_path)

        # initialize test buffer
        if os.path.isfile(critic_config.test_memory_path):
            print("Loading replay buffer for test from {}".format(critic_config.test_memory_path))
            agent.test_buffer.load(critic_config.test_memory_path)
        else:
            print("Extracting experiences from test data")
            trainer.extract(trainer.test_data, trainer.agent.test_buffer)
            print("Saving experiences to {}".format(critic_config.test_memory_path))
            agent.test_buffer.save(critic_config.test_memory_path)

        if critic_config.use_gpu:
            agent.critic.cuda()
            agent.critic_target.cuda()

        #check system performance
        train_dial, val_dial, test_dial = corpus.get_corpus()
        # Only test_data is used below; the train/val loaders are still built in
        # case their construction has side effects.
        # NOTE(review): confirm and drop them if they are not needed.
        train_data = BeliefDbDataLoaders('Train', train_dial, vae_config)
        val_data = BeliefDbDataLoaders('Val', val_dial, vae_config)
        test_data = BeliefDbDataLoaders('Test', test_dial, vae_config)

        with open(os.path.join(exp_dir, "test_performance_start.txt"), "w") as f:
            task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)
        #train critic
        trainer.run()

        with open(os.path.join(exp_dir, "test_performance_end.txt"), "w") as f:
            task_run_critic(test_data, agent, None, evaluator=evaluator, outfile=f)

        end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        print('[END]', end_time, '='*30)
    
    
    if __name__ == '__main__':
        # Command-line entry point: choose the pretrained encoder that matches
        # the system whose responses are stored in --infile, then train the critic.
        parser = ArgumentParser()

        parser.add_argument("--infile", type=str, default="../data/augpt/test-predictions.json")
        args = parser.parse_args()

        # pick corresponding encoder
        if "hdsa" in args.infile or "HDSA" in args.infile:
            # MWOZ 2.0
            folder = "2020-05-12-14-51-49-actz_cat/rl-2020-05-18-10-50-48"
            id_ = "reward_best"
        elif "augpt" in args.infile:
            #MWOZ 2.1
            folder = "2021-11-25-11-52-47-mt_gauss"
            id_ = "29"
        else:
            # Without this branch `folder`/`id_` would be unbound and the call
            # below would fail with a NameError; parser.error() prints the usage
            # message and exits with status 2.
            parser.error(
                "cannot infer the pretrained encoder from --infile {!r}: "
                "the path must contain 'hdsa', 'HDSA' or 'augpt'".format(args.infile)
            )

        # seed=None lets main() fall back to the seed stored in the pretrained config
        main(None, folder, id_, args.infile)