# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

Code-Style: E226, W503
Docstring-Style: D200, D400

TODO: Test new ANN set-up with Soraya
TODO: Remove object set-up (for more flexibility)
TODO: Add documentation
TODO: Improve log output (timer, bit of text) -> Done
TODO: Throw exception for error due to missing classes
TODO: Allow multiple approximations in one config

"""
import numpy as np
import time
import matplotlib
from matplotlib import pyplot as plt
import os
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

import ANN_Model
from Plotting import plot_classification_accuracy, plot_boxplot

# Select matplotlib's non-interactive 'Agg' backend; this module only saves figures to files
# (see the plt.savefig call in evaluate_models) and never shows them on screen.
matplotlib.use('Agg')


class ModelTrainer(object):
    """Train, evaluate, and save an ANN classifier that is set up from a config dict.

    The model class is looked up by name in ``ANN_Model``, the loss function in
    ``torch.nn.modules.loss``, and the optimizer in ``torch.optim``.
    """

    def __init__(self, config):
        """Initialize the trainer; see ``_reset`` for the recognized config keys."""
        self._reset(config)

    def _reset(self, config):
        """Read the configuration and build model, loss function, and optimizer.

        Recognized keys are removed from ``config`` via ``dict.pop``, i.e. the
        dictionary passed by the caller is modified.

        Recognized keys (default in parentheses):
            dir ('test_data'): directory the training data is read from and
                trained models are saved to.
            model_name ('0'): name used in the file names written by ``save_model``.
            batch_size (min(len(training data)//2, 500)): training batch size.
            num_epochs (1000): default number of epochs for ``epoch_training``.
            threshold (1e-5): training stops once the mean validation loss
                falls below this value.
            model ('ThreeLayerReLu'): attribute name in ``ANN_Model``.
            model_config ({}): passed as single positional argument to the model class.
            loss_function ('BCELoss'): attribute name in ``torch.nn.modules.loss``.
            loss_config ({}): keyword arguments for the loss function.
            optimizer ('Adam'): attribute name in ``torch.optim``.
            optimizer_config ({}): keyword arguments for the optimizer.
            learning_rate (1e-2): always written to ``optimizer_config['lr']``,
                overriding any 'lr' entry given there.

        Raises:
            ValueError: If model, loss function, or optimizer name is unknown.
        """
        self._dir = config.pop('dir', 'test_data')
        self._model_name = config.pop('model_name', '0')
        self._training_data = read_training_data(self._dir)

        self._batch_size = config.pop('batch_size', min(len(self._training_data)//2, 500))
        self._num_epochs = config.pop('num_epochs', 1000)
        self._threshold = config.pop('threshold', 1e-5)
        self._model = config.pop('model', 'ThreeLayerReLu')
        self._model_config = config.pop('model_config', {})
        self._loss_function = config.pop('loss_function', 'BCELoss')
        self._loss_config = config.pop('loss_config', {})
        self._optimizer = config.pop('optimizer', 'Adam')
        self._optimizer_config = config.pop('optimizer_config', {})

        # Set learning rate
        self._learning_rate = config.pop('learning_rate', 1e-2)
        self._optimizer_config['lr'] = self._learning_rate

        # Validate the names before they are replaced by the instantiated objects
        if not hasattr(ANN_Model, self._model):
            raise ValueError('Invalid model: "%s"' % self._model)
        if not hasattr(torch.nn.modules.loss, self._loss_function):
            raise ValueError('Invalid loss function: "%s"' % self._loss_function)
        if not hasattr(torch.optim, self._optimizer):
            raise ValueError('Invalid optimizer: "%s"' % self._optimizer)

        self._model = getattr(ANN_Model, self._model)(self._model_config)
        self._loss_function = getattr(torch.nn.modules.loss, self._loss_function)(
            **self._loss_config)
        self._optimizer = getattr(torch.optim, self._optimizer)(
            self._model.parameters(), **self._optimizer_config)
        # NOTE(review): the loss is recorded every 100 epochs (see epoch_training), so a
        # default run only fills the first num_epochs//100 entries; the size is kept as is
        # so that previously saved loss files keep their shape.
        self._validation_loss = torch.zeros(self._num_epochs//10)

    def epoch_training(self, dataset=None, num_epochs=None, verbose=True):
        """Train the model with a random 80/20 training/validation split.

        After every epoch the validation loss, averaged over the validation
        batches, is computed. It is stored (and printed if ``verbose``) every
        100 epochs, and training stops early once it drops below the
        configured threshold.

        Args:
            dataset: Dataset to train on. Defaults to the data read at set-up.
            num_epochs: Maximum number of epochs. Defaults to the configured value.
            verbose: Flag whether progress and timing are printed.
        """
        tic = time.perf_counter()
        # Split data into training and validation set
        if dataset is None:
            dataset = self._training_data
        if num_epochs is None:
            num_epochs = self._num_epochs
        num_samples = len(dataset)
        if verbose:
            print('Splitting data randomly into training and validation set.')
        # The validation set takes the remainder so that both lengths always add up
        # to the number of samples, as random_split requires.
        num_train = round(num_samples * 0.8)
        train_ds, valid_ds = random_split(dataset, [num_train, num_samples - num_train])

        # Load sets
        train_dl = DataLoader(train_ds, batch_size=self._batch_size, shuffle=True)
        valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)

        # Training with Validation
        if verbose:
            print('\nTraining model...')
            print('Number of epochs:', num_epochs)
        tic_train = time.perf_counter()
        for epoch in range(num_epochs):
            self._model.train()
            for x_batch, y_batch in train_dl:
                pred = self._model(x_batch.float())
                loss = self._loss_function(pred, y_batch.float()).mean()

                # Run back propagation, update the weights, and zero gradients for next batch
                loss.backward()
                self._optimizer.step()
                self._optimizer.zero_grad()

            self._model.eval()
            with torch.no_grad():
                valid_loss = sum(
                    self._loss_function(self._model(x_batch_valid.float()), y_batch_valid.float())
                    for x_batch_valid, y_batch_valid in valid_dl)
                mean_valid_loss = valid_loss / len(valid_dl)

                if (epoch + 1) % 100 == 0:
                    # The buffer is sized from the configured number of epochs, but this
                    # method may be called with more epochs (e.g. by test_model); skip
                    # recording instead of failing with an IndexError in that case.
                    slot = (epoch + 1) // 100 - 1
                    if slot < len(self._validation_loss):
                        self._validation_loss[slot] = mean_valid_loss
                    if verbose:
                        print(epoch+1, 'epochs completed. Loss:', mean_valid_loss)

                if mean_valid_loss < self._threshold:
                    break
        toc_train = time.perf_counter()
        if verbose:
            print('Finished training model!')
            print(f'Training time: {toc_train-tic_train:0.4f}s\n')
        toc = time.perf_counter()
        if verbose:
            print(f'Total runtime: {toc-tic:0.4f}s\n')

    def test_model(self, training_set, test_set):
        """Train for 100 epochs on ``training_set`` and score the model on ``test_set``.

        Training continues from the current weights of the model; they are
        not re-initialized by this method.

        Args:
            training_set: Dataset passed to ``epoch_training``.
            test_set: Pair ``(x_test, y_test)`` of tensors. Column 1 of
                ``y_test`` is used as true label (presumably a one-hot
                encoding with column 1 marking troubled cells -- TODO confirm).

        Returns:
            Dict with the keys 'Precision_Smooth', 'Precision_Troubled',
            'Recall_Smooth', 'Recall_Troubled', 'F-Score_Smooth',
            'F-Score_Troubled', 'Accuracy', and 'AUROC'. The '_Smooth' entries
            refer to label 0, the '_Troubled' entries to label 1.
        """
        self.epoch_training(training_set, num_epochs=100, verbose=False)
        self._model.eval()

        x_test, y_test = test_set
        # Pure inference: no gradients are needed for the evaluation
        with torch.no_grad():
            model_score = self._model(x_test.float())
        model_output = torch.argmax(model_score, dim=1)

        y_true = y_test.detach().numpy()[:, 1]
        y_pred = model_output.detach().numpy()
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
        # NOTE(review): AUROC is computed from the hard class predictions,
        # not from the model scores.
        auroc = roc_auc_score(y_true, y_pred)

        return {'Precision_Smooth': precision[0], 'Precision_Troubled': precision[1],
                'Recall_Smooth': recall[0], 'Recall_Troubled': recall[1],
                'F-Score_Smooth': f_score[0], 'F-Score_Troubled': f_score[1],
                'Accuracy': accuracy, 'AUROC': auroc}

    def save_model(self):
        """Save the model's state dict and the recorded validation loss.

        Both files are written to '<dir>/trained models' as
        'model__<model_name>.pt' and 'loss__<model_name>.pt'.
        """
        name = self._model_name

        # Create the target directory; exist_ok avoids the check-then-create race
        model_dir = self._dir + '/trained models'
        os.makedirs(model_dir, exist_ok=True)

        torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
        torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')


def read_training_data(directory, normalized=True):
    """Load the saved training data from ``directory`` as a ``TensorDataset``.

    The input is read from 'normalized_input_data.npy' if ``normalized`` is
    set and from 'input_data.npy' otherwise; the output is always read from
    'output_data.npy'. Both arrays are converted to Torch tensors.
    """
    input_name = 'normalized_input_data.npy' if normalized else 'input_data.npy'
    # Load both arrays first, then convert them to tensors
    input_array = np.load(directory + '/' + input_name)
    output_array = np.load(directory + '/output_data.npy')
    input_tensor = torch.tensor(input_array)
    output_tensor = torch.tensor(output_array)
    return TensorDataset(input_tensor, output_tensor)


def evaluate_models(models, directory, num_iterations=100, colors=None,
                    compare_normalization=False):
    """Compare models with repeated 5-fold cross validation and save evaluation plots.

    In every iteration, the training data is split into five shuffled folds.
    Each model is trained and scored on every fold via its ``test_model``
    method. The collected measures are plotted as boxplots and as mean
    values, and all open figures are saved as PDF in
    '<directory>/model evaluation/<figure label>/'.

    NOTE(review): the same model instances are reused for all folds and
    iterations and ``test_model`` does not re-initialize their weights --
    confirm that this is intended.

    Args:
        models: Dict mapping a model name to an object that provides
            ``test_model(training_set, test_set)``. The names are used in the
            plot labels and, joined by '_', in the file name of the plots.
        directory: Directory the training data is read from and the plots
            are saved to.
        num_iterations: Number of repetitions of the 5-fold cross validation.
        colors: Dict mapping each evaluated measure to a plot color. Its keys
            determine which entries of the ``test_model`` result are
            collected. Defaults to a preset covering all eight measures.
        compare_normalization: Flag whether the models are additionally
            evaluated on the raw, non-normalized data.
    """
    tic = time.perf_counter()
    if colors is None:
        colors = {'Accuracy': 'magenta', 'Precision_Smooth': 'red',
                  'Precision_Troubled': '#8B0000', 'Recall_Smooth': 'blue',
                  'Recall_Troubled': '#00008B', 'F-Score_Smooth': 'green',
                  'F-Score_Troubled': '#006400', 'AUROC': 'yellow'}

    print('Read normalized training data.')
    datasets = {'normalized': read_training_data(directory)}
    if compare_normalization:
        print('Read raw, non-normalized training data.')
        datasets['raw'] = read_training_data(directory, False)
    classification_stats = {measure: {model + ' (' + dataset + ')': [] for model in models
                                      for dataset in datasets} for measure in colors}

    print('\nTraining models with 5-fold cross validation...')
    print('Number of iterations:', num_iterations)
    # Report progress every 10 iterations, or every 10% for more than 100 iterations
    log_interval = max(10, 10 * (num_iterations // 100))
    tic_train = time.perf_counter()
    for iteration in range(num_iterations):
        # The fold indices are drawn once per split and shared by all datasets
        for train_index, test_index in KFold(
                n_splits=5, shuffle=True).split(datasets['normalized']):
            for dataset in datasets:
                training_set = TensorDataset(*datasets[dataset][train_index])
                test_set = datasets[dataset][test_index]

                for model in models:
                    result = models[model].test_model(training_set, test_set)
                    for measure in colors:
                        classification_stats[measure][model + ' (' + dataset + ')'].append(
                            result[measure])
        # Parentheses are required: '%' binds tighter than '+', so the unparenthesized
        # form was always truthy and printed after every single iteration.
        if (iteration + 1) % log_interval == 0:
            print(iteration+1, 'iterations completed.')
    toc_train = time.perf_counter()
    print('Finished training models with 5-fold cross validation!')
    print(f'Training time: {toc_train - tic_train:0.4f}s\n')

    print('Plotting evaluation of trained models.')
    plot_boxplot(classification_stats, colors)
    # Reduce the collected values to their mean for the accuracy plot
    classification_stats = {measure: {model + ' (' + dataset + ')': np.array(
        classification_stats[measure][model + ' (' + dataset + ')']).mean() for model in models
                                      for dataset in datasets} for measure in colors}
    plot_classification_accuracy(classification_stats, colors)

    # Create directory for plot files; exist_ok avoids the check-then-create race
    plot_dir = directory + '/model evaluation'
    os.makedirs(plot_dir, exist_ok=True)

    # Save plots
    print('Saving plots.')
    for identifier in plt.get_figlabels():
        # Every figure gets its own sub-directory named after its label
        os.makedirs(plot_dir + '/' + identifier, exist_ok=True)

        plt.figure(identifier)
        plt.savefig(plot_dir + '/' + identifier + '/' + '_'.join(models.keys()) + '.pdf')
    toc = time.perf_counter()
    print(f'Total runtime: {toc - tic:0.4f}s')