import time
import os
import json
import sys
from pathlib import Path
sys.path.append((Path(__file__).parent / '../').resolve().as_posix())
import torch as th
import pdb
import logging
import random
from latent_dialog.utils import Pack, prepare_dirs_loggers, set_seed
import latent_dialog.corpora as corpora
from latent_dialog.data_loaders import BeliefDbDataLoadersAE, BeliefDbDataLoaders
from latent_dialog.evaluators import MultiWozEvaluator
from latent_dialog.models_task import SysMTGauss
from latent_dialog.main import train, validate, mt_train
import latent_dialog.domain as domain
from experiments_woz.dialog_utils import task_generate

# def main(seed):

def main(seed, pretrained_folder, pretrained_model_id):
    """Train and evaluate the multi-task Gaussian latent model (SysMTGauss) on MultiWOZ.

    Jointly trains a response-generation task and an autoencoding (AE)
    auxiliary task, optionally warm-starting from a pre-trained AE checkpoint.

    Args:
        seed: random seed for the run. May be None ONLY when
            `pretrained_folder` is given, in which case the seed stored in
            the AE config is reused.
        pretrained_folder: name of a run folder under 'sys_config_log_model'
            containing a pre-trained AE ('config.json' + '{id}-model' files),
            or None to train from scratch.
        pretrained_model_id: epoch id of the AE checkpoint to load; ignored
            when `pretrained_folder` is None.

    Raises:
        ValueError: if both `seed` and `pretrained_folder` are None.
    """
    domain_name = 'object_division'
    domain_info = domain.get_domain(domain_name)

    if pretrained_folder is not None:
        # Reuse the data paths (and, further below, several decoder
        # hyper-parameters) from the pre-trained AE run so the fine-tuned
        # model stays architecturally compatible with the loaded weights.
        ae_config_path = os.path.join('sys_config_log_model', pretrained_folder, 'config.json')
        # BUG FIX: use a context manager instead of json.load(open(...)),
        # which leaked the file handle.
        with open(ae_config_path) as cfg_f:
            ae_config = Pack(json.load(cfg_f))
        ae_model_path = os.path.join('sys_config_log_model', pretrained_folder, '{}-model'.format(pretrained_model_id))
        train_path = ae_config.train_path
        valid_path = ae_config.valid_path
        test_path = ae_config.test_path
    else:
        ae_model_path = None
        ae_config_path = None
        train_path = '../data/data_2.1/train_dials.json'
        valid_path = '../data/data_2.1/val_dials.json'
        test_path = '../data/data_2.1/test_dials.json'

    if seed is None:
        # BUG FIX: the original read ae_config.seed unconditionally here and
        # crashed with NameError when called as main(None, None, None);
        # fail fast with a clear message instead.
        if pretrained_folder is None:
            raise ValueError('seed must be provided when no pretrained AE folder is given')
        seed = ae_config.seed

    base_path = Path(__file__).parent
    config = Pack(
        seed=seed,
        ae_model_path=ae_model_path,
        ae_config_path=ae_config_path,
        train_path=train_path,
        valid_path=valid_path,
        test_path=test_path,
        dact_path=(base_path / '../data/norm-multi-woz/damd_dialog_acts.json').resolve().as_posix(),
        ae_zero_pad=True,
        max_vocab_size=1000,
        last_n_model=1,
        max_utt_len=50,
        max_dec_len=50,
        backward_size=2,
        batch_size=128,
        use_gpu=True,
        op='adam',
        init_lr=0.0005,
        l2_norm=1e-05,
        momentum=0.0,
        grad_clip=5.0,
        dropout=0.5,
        max_epoch=50,
        # aux_train_freq=10,
        # aux_max_epoch=1, # epoch per training, i.e every aux_train_freq, train aux_max_epoch time
        shared_train=True,
        embed_size=256,
        num_layers=1,
        utt_rnn_cell='gru',
        utt_cell_size=300,
        bi_utt_cell=True,
        enc_use_attn=True,
        dec_use_attn=False,
        dec_rnn_cell='lstm',
        dec_cell_size=300,
        dec_attn_mode='cat',
        y_size=200,
        beta=0.01,
        # aux_pi_beta = 0.01, # default to 1.0 if not set
        simple_posterior=True,
        contextual_posterior=False,
        use_mi=False,
        use_pr=True,
        use_diversity=False,
        #
        beam_size=20,
        fix_batch=True,
        fix_train_batch=False,
        avg_type='word',
        # avg_type='slot',
        # slot_weight=10,
        use_aux_kl=False,
        selective_fine_tune=False,
        print_step=300,
        ckpt_step=1416,
        improve_threshold=0.996,
        patient_increase=2.0,
        save_model=True,
        early_stop=False,
        gen_type='greedy',
        preview_batch_num=50,
        k=domain_info.input_length(),
        init_range=0.1,
        pretrain_folder='2021-11-25-16-47-37-mt_gauss',
        forward_only=False
    )
    set_seed(config.seed)

    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    stats_path = (base_path / 'sys_config_log_model').resolve().as_posix()
    if config.forward_only:
        # Evaluation-only mode: replace the fresh config entirely with the one
        # saved by the run in `pretrain_folder`, keeping only forward_only=True.
        saved_path = os.path.join(stats_path, config.pretrain_folder)
        # BUG FIX: context-managed open (was an unclosed file handle).
        with open(os.path.join(saved_path, 'config.json')) as cfg_f:
            config = Pack(json.load(cfg_f))
        config['forward_only'] = True
    else:
        # New run directory named by timestamp + this script's basename.
        saved_path = os.path.join(stats_path, start_time + '-' + os.path.basename(__file__).split('.')[0])
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
    config.saved_path = saved_path

    if ae_config_path is not None:
        # Decoder / latent hyper-parameters must match the AE checkpoint we
        # are about to load, so copy them over from its config.
        config["max_dec_len"] = ae_config.max_dec_len
        config["dec_use_attn"] = ae_config.dec_use_attn
        config["dec_rnn_cell"] = ae_config.dec_rnn_cell
        config["dec_cell_size"] = ae_config.dec_cell_size
        config["dec_attn_mode"] = ae_config.dec_attn_mode
        config["embed_size"] = ae_config.embed_size
        config["y_size"] = ae_config.y_size
        # Denoising setup: propagate noise options only if the AE used them.
        noise_type = ae_config.noise_type if "noise_type" in ae_config else None
        if noise_type:
            config["noise_type"] = ae_config.noise_type
            config["noise_p"] = ae_config.noise_p
            config["remove_tokens"] = ae_config.remove_tokens
            config["no_special"] = ae_config.no_special
    else:
        noise_type = None

    prepare_dirs_loggers(config)
    logger = logging.getLogger()
    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))

    # save configuration
    with open(os.path.join(saved_path, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)  # sort_keys=True

    # data for AE training
    aux_corpus = corpora.NormMultiWozCorpusAE(config)
    # Only the (noisy) AE train loader needs the vocab for noise injection.
    vocab_dict = aux_corpus.vocab_dict if noise_type else None
    aux_train_dial, aux_val_dial, aux_test_dial = aux_corpus.get_corpus()

    aux_train_data = BeliefDbDataLoadersAE('Train', aux_train_dial, config, noise=noise_type, ind_voc=vocab_dict, logger=logger)
    aux_val_data = BeliefDbDataLoadersAE('Val', aux_val_dial, config)
    aux_test_data = BeliefDbDataLoadersAE('Test', aux_test_dial, config)

    # data for RG training
    corpus = corpora.NormMultiWozCorpus(config)
    train_dial, val_dial, test_dial = corpus.get_corpus()

    train_data = BeliefDbDataLoaders('Train', train_dial, config)
    val_data = BeliefDbDataLoaders('Val', val_dial, config)
    test_data = BeliefDbDataLoaders('Test', test_dial, config)

    evaluator = MultiWozEvaluator('SysWoz', config)

    model = SysMTGauss(corpus, config)
    model_dict = model.state_dict()

    # load params from saved ae_model
    if pretrained_folder is not None:
        ae_model_dict = th.load(config.ae_model_path, map_location=lambda storage, location: storage)
        # Parameters present in the MT model but absent from the AE checkpoint
        # are the aux_* copies of the AE's utt_* encoder parameters.
        tmp_k = [k for k in model_dict.keys() if k not in ae_model_dict.keys()]
        aux_model_dict = {}
        for k in tmp_k:
            aux_model_dict[k] = ae_model_dict[k.replace("aux", "utt")]  # utt_encoder param in ae model is now moved to aux_params in a new dict
        model_dict.update(ae_model_dict)   # all params except aux_encoder
        model_dict.update(aux_model_dict)  # aux encoder
        model.load_state_dict(model_dict)

    if config.use_gpu:
        model.cuda()

    # only train utt_encoder
    if config.ae_config_path is not None and config.selective_fine_tune:
        for name, param in model.named_parameters():
            if "utt_encoder" not in name:
                # if "decoder" in name or "z_embedding" in name:
                param.requires_grad = False

    best_epoch = None
    if not config.forward_only:
        try:
            best_epoch = mt_train(model, train_data, val_data, test_data, aux_train_data, aux_val_data, aux_test_data, config, evaluator, gen=task_generate)
        except KeyboardInterrupt:
            print('Training stopped by keyboard.')
    if best_epoch is None:
        # Fall back to the latest non-RL, non-aux checkpoint on disk.
        model_ids = sorted([int(p.replace('-model', '')) for p in os.listdir(saved_path) if 'model' in p and 'rl' not in p and 'aux' not in p])
        best_epoch = model_ids[-1]

    print("$$$ Load {}-model".format(best_epoch))
    # config.batch_size = 32
    model.load_state_dict(th.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))

    logger.info("Forward Only Evaluation")

    validate(model, val_data, config)
    validate(model, test_data, config)

    with open(os.path.join(saved_path, '{}_{}_valid_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, val_data, config, evaluator, num_batch=None, dest_f=f)

    with open(os.path.join(saved_path, '{}_{}_test_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, test_data, config, evaluator, num_batch=None, dest_f=f)

    with open(os.path.join(saved_path, '{}_{}_valid_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, aux_val_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)

    with open(os.path.join(saved_path, '{}_{}_test_AE_file.txt'.format(start_time, best_epoch)), 'w') as f:
        task_generate(model, aux_test_data, config, evaluator, num_batch=None, dest_f=f, aux_mt=True)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)

if __name__ == "__main__":
    ### Train from scratch ###
    for run_seed in (1, 2, 3, 4):
        main(run_seed, None, None)

    ### Load pre-trained AE ###
    # multiwoz 2.0
    # pretrained = {"2020-02-28-16-49-48-sl_gauss_ae": 100}

    # multiwoz 2.1
    # pretrained = {"2020-10-15-17-11-59-sl_gauss_ae": 92}

    # for folder, model_id in pretrained.items():
    #     main(None, folder, model_id)