"""Multi-task training of SysMTGauss on MultiWOZ 2.1: the model is trained
jointly on response generation and an autoencoding (AE) auxiliary task,
optionally warm-started from a pretrained AE checkpoint."""
import time
import os
import json
import sys
import logging
from pathlib import Path

sys.path.append((Path(__file__).parent / '../').resolve().as_posix())

import torch as th

from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
import latent_dialog.corpora as corpora
from latent_dialog.data_loaders import BeliefDbDataLoadersAE, BeliefDbDataLoaders
from latent_dialog.evaluators import MultiWozEvaluator
from latent_dialog.models_task import SysMTGauss
from latent_dialog.main import train, validate, mt_train
import latent_dialog.domain as domain
from experiments_woz.dialog_utils import task_generate


# def main(seed):
def main(seed, pretrained_folder, pretrained_model_id):
    domain_name = 'object_division'
    domain_info = domain.get_domain(domain_name)

    if pretrained_folder is not None:
        # reuse the data splits recorded in the pretrained AE's config
        ae_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json')
        ae_config = Pack(json.load(open(ae_config_path)))
        ae_model_path = os.path.join('sys_config_log_model', pretrained_folder,
                                     '{}-model'.format(pretrained_model_id))
        train_path = ae_config.train_path
        valid_path = ae_config.valid_path
        test_path = ae_config.test_path
    else:
        ae_model_path = None
        ae_config_path = None
        train_path = '../data/data_2.1/train_dials.json'
        valid_path = '../data/data_2.1/val_dials.json'
        test_path = '../data/data_2.1/test_dials.json'

    if seed is None:
        # only valid when a pretrained AE config was loaded above
        seed = ae_config.seed

    base_path = Path(__file__).parent
    config = Pack(
        seed=seed,
        ae_model_path=ae_model_path,
        ae_config_path=ae_config_path,
        train_path=train_path,
        valid_path=valid_path,
        test_path=test_path,
        dact_path=(base_path / '../data/norm-multi-woz/damd_dialog_acts.json').resolve().as_posix(),
        ae_zero_pad=True,
        max_vocab_size=1000,
        last_n_model=1,
        max_utt_len=50,
        max_dec_len=50,
        backward_size=2,
        batch_size=128,
        use_gpu=True,
        op='adam',
        init_lr=0.0005,
        l2_norm=1e-05,
        momentum=0.0,
        grad_clip=5.0,
        dropout=0.5,
        max_epoch=50,
        # aux_train_freq=10,
        # aux_max_epoch=1,  # i.e. every aux_train_freq epochs, train the aux task for aux_max_epoch epochs
        shared_train=True,
        embed_size=256,
        num_layers=1,
        utt_rnn_cell='gru',
        utt_cell_size=300,
        bi_utt_cell=True,
        enc_use_attn=True,
        dec_use_attn=False,
        dec_rnn_cell='lstm',
        dec_cell_size=300,
        dec_attn_mode='cat',
        y_size=200,
        beta=0.01,
        # aux_pi_beta=0.01,  # defaults to 1.0 if not set
        simple_posterior=True,
        contextual_posterior=False,
        use_mi=False,
        use_pr=True,
        use_diversity=False,
        # beam_size=20,
        fix_batch=True,
        fix_train_batch=False,
        avg_type='word',
        # avg_type='slot',
        # slot_weight=10,
        use_aux_kl=False,
        selective_fine_tune=False,
        print_step=300,
        ckpt_step=1416,
        improve_threshold=0.996,
        patient_increase=2.0,
        save_model=True,
        early_stop=False,
        gen_type='greedy',
        preview_batch_num=50,
        k=domain_info.input_length(),
        init_range=0.1,
        pretrain_folder='2021-11-25-16-47-37-mt_gauss',
        forward_only=False
    )
    set_seed(config.seed)

    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    stats_path = (base_path / 'sys_config_log_model').resolve().as_posix()
    if config.forward_only:
        saved_path = os.path.join(stats_path, config.pretrain_folder)
        config = Pack(json.load(open(os.path.join(saved_path, 'config.json'))))
        config['forward_only'] = True
    else:
        saved_path = os.path.join(stats_path, start_time + '-' + os.path.basename(__file__).split('.')[0])
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
    config.saved_path = saved_path
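    # When warm-starting from a pretrained AE, the decoder and latent-space
    # hyperparameters must match the checkpoint, so they are copied from the
    # AE config and override the defaults set above.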
    if ae_config_path is not None:
        config["max_dec_len"] = ae_config.max_dec_len
        config["dec_use_attn"] = ae_config.dec_use_attn
        config["dec_rnn_cell"] = ae_config.dec_rnn_cell
        config["dec_cell_size"] = ae_config.dec_cell_size
        config["dec_attn_mode"] = ae_config.dec_attn_mode
        config["embed_size"] = ae_config.embed_size
        config["y_size"] = ae_config.y_size

        # Denoising setup
        if "noise_type" in ae_config:
            noise_type = ae_config.noise_type
        else:
            noise_type = None
        if noise_type:
            config["noise_type"] = ae_config.noise_type
            config["noise_p"] = ae_config.noise_p
            config["remove_tokens"] = ae_config.remove_tokens
            config["no_special"] = ae_config.no_special
    else:
        noise_type = None

    prepare_dirs_loggers(config)
    logger = logging.getLogger()
    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))

    # save configuration
    with open(os.path.join(saved_path, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)  # sort_keys=True

    # data for AE training
    aux_corpus = corpora.NormMultiWozCorpusAE(config)
    vocab_dict = aux_corpus.vocab_dict if noise_type else None
    aux_train_dial, aux_val_dial, aux_test_dial = aux_corpus.get_corpus()
    aux_train_data = BeliefDbDataLoadersAE('Train', aux_train_dial, config,
                                           noise=noise_type, ind_voc=vocab_dict, logger=logger)
    aux_val_data = BeliefDbDataLoadersAE('Val', aux_val_dial, config)
    aux_test_data = BeliefDbDataLoadersAE('Test', aux_test_dial, config)

    # data for RG (response generation) training
    corpus = corpora.NormMultiWozCorpus(config)
    train_dial, val_dial, test_dial = corpus.get_corpus()
    train_data = BeliefDbDataLoaders('Train', train_dial, config)
    val_data = BeliefDbDataLoaders('Val', val_dial, config)
    test_data = BeliefDbDataLoaders('Test', test_dial, config)

    evaluator = MultiWozEvaluator('SysWoz', config)
    model = SysMTGauss(corpus, config)
    model_dict = model.state_dict()

    # load params from the saved AE model
    if pretrained_folder is not None:
        ae_model_dict = th.load(config.ae_model_path,
                                map_location=lambda storage, location: storage)
        # keys present in the new model but missing from the AE checkpoint
        # belong to the aux encoder; the AE's utt_encoder params are moved
        # into them via a renamed dict
        tmp_k = [k for k in model_dict.keys() if k not in ae_model_dict.keys()]
        aux_model_dict = {}
        for k in tmp_k:
            aux_model_dict[k] = ae_model_dict[k.replace("aux", "utt")]
        model_dict.update(ae_model_dict)   # all params except the aux encoder
        model_dict.update(aux_model_dict)  # aux encoder
        model.load_state_dict(model_dict)

    if config.use_gpu:
        model.cuda()

    # only train the utterance encoder; freeze everything else
    if config.ae_config_path is not None and config.selective_fine_tune:
        for name, param in model.named_parameters():
            if "utt_encoder" not in name:
            # if "decoder" in name or "z_embedding" in name:
                param.requires_grad = False

    best_epoch = None
    if not config.forward_only:
        try:
            best_epoch = mt_train(model, train_data, val_data, test_data,
                                  aux_train_data, aux_val_data, aux_test_data,
                                  config, evaluator, gen=task_generate)
        except KeyboardInterrupt:
            print('Training stopped by keyboard.')
    if best_epoch is None:
        # fall back to the latest saved checkpoint
        model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path)
                            if 'model' in p and 'rl' not in p and 'aux' not in p])
        best_epoch = model_ids[-1]

    print("$$$ Load {}-model".format(best_epoch))
    # config.batch_size = 32
    model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))

    logger.info("Forward Only Evaluation")
    validate(model, val_data, config)
    validate(model, test_data, config)

    with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, val_data, config, evaluator, num_batch=None, dest_f=f)
    with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f)
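    # also dump reconstruction outputs for the AE auxiliary task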
    with open(os.path.join(saved_path, '{}_{}_valid_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, aux_val_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)
    with open(os.path.join(saved_path, '{}_{}_test_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, aux_test_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)


if __name__ == "__main__":
    ### Train from scratch ###
    for i in [1, 2, 3, 4]:
        main(i, None, None)

    ### Load pre-trained AE ###
    # multiwoz 2.0
    # pretrained = {"2020-02-28-16-49-48-sl_gauss_ae": 100}
    # multiwoz 2.1
    # pretrained = {"2020-10-15-17-11-59-sl_gauss_ae": 92}
    # for p in pretrained.keys():
    #     folder = p
    #     id_ = pretrained[p]
    #     main(None, folder, id_)