# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

Code-Style: E226, W503
Docstring-Style: D200, D400

TODO: Add README for ANN training

"""
import numpy as np
import time
import os
import torch
import json
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, \
    roc_auc_score

import ANN_Model


class ModelTrainer:
    """Class for ANN model training.

    Trains and tests a model with set loss function and optimizer.

    Attributes
    ----------
    model : torch.nn.Module
        ANN model instance for evaluation.
    loss_function : torch.nn.modules.loss
        Function to evaluate loss during model training.
    optimizer : torch.optim
        Optimizer for model training.
    validation_loss : torch.Tensor
        Validation loss values, one entry per 100 training epochs.

    Methods
    -------
    epoch_training(dataset, num_epochs, verbose)
        Trains model for a given number of epochs.
    test_model(training_set, test_set, reset_model)
        Evaluates predictions of a model.
    save_model(directory, model_name)
        Saves state and validation loss of a model.

    """
    def __init__(self, config: dict) -> None:
        """Initializes ModelTrainer.

        Parameters
        ----------
        config : dict
            Additional parameters for model trainer. The dictionary is not
            modified.

        """
        self._reset(config)

    def _reset(self, config: dict) -> None:
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for model trainer. The dictionary is not
            modified.

        Raises
        ------
        ValueError
            If the model, loss function, or optimizer name is unknown.

        """
        # Work on a copy so the caller's configuration is not consumed and
        # can be reused, e.g. for building several trainers
        config = dict(config)

        self._batch_size = config.pop('batch_size', 500)
        self._num_epochs = config.pop('num_epochs', 1000)
        self._threshold = config.pop('threshold', 1e-5)

        # Set random seed
        seed = config.pop('random_seed', None)
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.use_deterministic_algorithms(True)

        model = config.pop('model', 'ThreeLayerReLu')
        model_config = dict(config.pop('model_config', {}))
        loss_function = config.pop('loss_function', 'BCELoss')
        loss_config = dict(config.pop('loss_config', {}))
        optimizer = config.pop('optimizer', 'Adam')
        # Copied so adding the learning rate does not modify a nested
        # dictionary supplied by the caller
        optimizer_config = dict(config.pop('optimizer_config', {}))

        # Set learning rate
        optimizer_config['lr'] = config.pop('learning_rate', 1e-2)

        if not hasattr(ANN_Model, model):
            raise ValueError('Invalid model: "%s"' % model)
        if not hasattr(torch.nn.modules.loss, loss_function):
            raise ValueError('Invalid loss function: "%s"' % loss_function)
        if not hasattr(torch.optim, optimizer):
            raise ValueError('Invalid optimizer: "%s"' % optimizer)

        # Remember the specification so the model can be rebuilt from scratch
        self._model_spec = (model, model_config)
        self._loss_spec = (loss_function, loss_config)
        self._optimizer_spec = (optimizer, optimizer_config)
        self._build_model()

    def _build_model(self) -> None:
        """Builds a fresh model, loss function, optimizer, and loss record."""
        model, model_config = self._model_spec
        loss_function, loss_config = self._loss_spec
        optimizer, optimizer_config = self._optimizer_spec

        # Pass a copy in case the model constructor modifies its config,
        # so every rebuild sees the full configuration
        self._model = getattr(ANN_Model, model)(dict(model_config))
        self._loss_function = getattr(torch.nn.modules.loss, loss_function)(
            **loss_config)
        self._optimizer = getattr(torch.optim, optimizer)(
            self._model.parameters(), **optimizer_config)
        self._validation_loss = torch.zeros(self._num_epochs//10)

    def _record_validation_loss(self, index: int,
                                loss: torch.Tensor) -> None:
        """Stores a validation loss, growing the record if it is too short."""
        missing = index + 1 - len(self._validation_loss)
        if missing > 0:
            # Training ran longer than the record was sized for
            self._validation_loss = torch.cat(
                (self._validation_loss, torch.zeros(missing)))
        self._validation_loss[index] = loss

    def epoch_training(self, dataset: torch.utils.data.dataset.TensorDataset,
                       num_epochs: int = None, verbose: bool = True) -> None:
        """Trains model for a given number of epochs.

        Trains model and saves the validation loss every 100 epochs. The
        training stops after the given number of epochs or if the mean
        validation loss falls below the threshold.

        Parameters
        ----------
        dataset : torch.utils.data.dataset.TensorDataset
            Training dataset. 80% of the samples (rounded) are used for
            training, the remaining ones for validation.
        num_epochs : int, optional
            Number of epochs for training.
            Default: None (i.e. instance variable).
        verbose : bool, optional
            Flag whether commentary in console is wanted. Default: True.

        """
        tic = time.perf_counter()
        if num_epochs is None:
            num_epochs = self._num_epochs

        # Split data into training and validation set; the validation size is
        # derived from the training size so the lengths always add up to the
        # number of samples
        num_samples = len(dataset)
        if verbose:
            print('Splitting data randomly into training and validation set.')
        num_train = round(num_samples*0.8)
        train_ds, valid_ds = random_split(dataset,
                                          [num_train, num_samples-num_train])

        # Load sets
        train_dl = DataLoader(train_ds, batch_size=self._batch_size,
                              shuffle=True)
        valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)

        # Train with validation
        if verbose:
            print('\nTraining model...')
            print('Number of epochs:', num_epochs)
        tic_train = time.perf_counter()
        for epoch in range(num_epochs):
            self._model.train()
            for x_batch, y_batch in train_dl:
                pred = self._model(x_batch.float())
                loss = self._loss_function(pred, y_batch.float()).mean()

                # Run back propagation, update the weights,
                # and zero gradients for next epoch
                loss.backward()
                self._optimizer.step()
                self._optimizer.zero_grad()

            # Without validation samples (tiny datasets) there is no loss to
            # record or to compare against the threshold
            if len(valid_dl) == 0:
                continue

            # Determine validation loss averaged over the validation batches;
            # '.mean()' reduces per-sample losses just like during training
            self._model.eval()
            with torch.no_grad():
                valid_loss = sum(
                    self._loss_function(self._model(x_valid.float()),
                                        y_valid.float()).mean()
                    for x_valid, y_valid in valid_dl) / len(valid_dl)

            # Report validation loss
            if (epoch+1) % 100 == 0:
                self._record_validation_loss((epoch+1)//100 - 1, valid_loss)
                if verbose:
                    print(epoch+1, 'epochs completed. Loss:', valid_loss)

            # Interrupt if threshold is reached
            if valid_loss < self._threshold:
                break
        toc_train = time.perf_counter()
        if verbose:
            print('Finished training model!')
            print(f'Training time: {toc_train-tic_train:0.4f}s\n')
        toc = time.perf_counter()
        if verbose:
            print(f'Total runtime: {toc-tic:0.4f}s\n')

    def test_model(self, training_set: torch.utils.data.dataset.TensorDataset,
                   test_set: torch.utils.data.dataset.TensorDataset,
                   reset_model: bool = True) -> dict:
        """Evaluates predictions of a model.

        Trains a model for 50 epochs and compares the predicted and true
        results by evaluating precision, recall, and f-score for both
        classes, as well as accuracy and AUROC score.

        Parameters
        ----------
        training_set : torch.utils.data.dataset.TensorDataset
            Training dataset.
        test_set : torch.utils.data.dataset.TensorDataset
            Test data, unpacked as a pair of input and output tensors.
        reset_model : bool, optional
            Flag whether model, loss function, optimizer, and loss record are
            rebuilt from the configuration before training. This keeps
            repeated calls (e.g. cross-validation folds) independent, so no
            call is tested on data an earlier call trained on.
            Default: True.

        Returns
        -------
        dict
            Dictionary containing classification evaluation data.

        """
        if reset_model:
            self._build_model()

        # Train model
        self.epoch_training(training_set, num_epochs=50, verbose=False)
        self._model.eval()

        # Classify data; gradients are not needed for evaluation
        x_test, y_test = test_set
        with torch.no_grad():
            model_score = self._model(x_test.float())
        model_output = torch.argmax(model_score, dim=1)

        # Evaluate classification
        # NOTE(review): presumably y_test is one-hot with column 1 marking
        # troubled cells (cf. result keys) — confirm against the data writer
        y_true = y_test.detach().numpy()[:, 1]
        y_pred = model_output.detach().numpy()
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f_score, support = precision_recall_fscore_support(
            y_true, y_pred, zero_division=0)
        # NOTE(review): AUROC is computed from hard class predictions, not
        # from the scores in model_score — confirm this is intended
        auroc = roc_auc_score(y_true, y_pred)

        return {'Precision_Smooth': precision[0],
                'Precision_Troubled': precision[1],
                'Recall_Smooth': recall[0],
                'Recall_Troubled': recall[1],
                'F-Score_Smooth': f_score[0],
                'F-Score_Troubled': f_score[1],
                'Accuracy': accuracy,
                'AUROC': auroc}

    def save_model(self, directory: str,
                   model_name: str = 'test_model') -> None:
        """Saves state and validation loss of a model.

        Files are written to '<directory>/trained models' as
        '<model_name>.model.pt' and '<model_name>.loss.pt'.

        Parameters
        ----------
        directory : str
            Path to directory in which model is saved.
        model_name : str, optional
            Name of model for saving. Default: 'test_model'.

        """
        # Create directory if not existing already
        model_dir = os.path.join(directory, 'trained models')
        os.makedirs(model_dir, exist_ok=True)

        # Save model and loss
        torch.save(self._model.state_dict(),
                   os.path.join(model_dir, model_name + '.model.pt'))
        torch.save(self._validation_loss,
                   os.path.join(model_dir, model_name + '.loss.pt'))


def read_training_data(directory: str, normalized: bool = True) -> \
        torch.utils.data.dataset.TensorDataset:
    """Loads the stored training data of a directory as a torch dataset.

    Parameters
    ----------
    directory : str
        Path to directory containing 'input_data.normalized.npy' or
        'input_data.raw.npy' as well as 'output_data.npy'.
    normalized : bool, optional
        Flag whether the normalized instead of the raw input data is loaded.
        Default: True.

    Returns
    -------
    torch.utils.data.dataset.TensorDataset
        Dataset pairing the input data with the output data.

    """
    variant = 'normalized' if normalized else 'raw'
    input_array = np.load(directory + '/input_data.' + variant + '.npy')
    output_array = np.load(directory + '/output_data.npy')
    return TensorDataset(torch.tensor(input_array), torch.tensor(output_array))


def evaluate_models(models: dict, directory: str, num_iterations: int = 100,
                    compare_normalization: bool = False) -> None:
    """Evaluates the classification of a given set of models.

    Runs the given number of iterations of shuffled 5-fold cross validation,
    tests every model on every fold (for each dataset variant), and saves
    the collected measures in a JSON file.

    Parameters
    ----------
    models : dict
        Dictionary mapping model names to objects providing a
        ``test_model(training_set, test_set)`` method.
    directory : str
        Path to directory containing the training data. The results are
        saved in its 'model evaluation' subdirectory.
    num_iterations : int, optional
        Number of iterations for evaluation. Default: 100.
    compare_normalization : bool, optional
        Flag whether both normalized and raw data should be evaluated.
        Default: False.

    """
    tic = time.perf_counter()

    # Read training data
    print('Read normalized training data.')
    datasets = {'normalized': read_training_data(directory)}
    if compare_normalization:
        print('Read raw, non-normalized training data.')
        datasets['raw'] = read_training_data(directory, False)

    # One result list per model/dataset combination for every measure
    result_keys = [model_name + ' (' + dataset_name + ')'
                   for model_name in models for dataset_name in datasets]
    classification_stats = {}

    # Report progress every max(10, 10*(num_iterations//100)) iterations
    report_interval = max(10, 10*(num_iterations//100))

    # Train models for evaluation
    print('\nTraining models with 5-fold cross validation...')
    print('Number of iterations:', num_iterations)
    tic_train = time.perf_counter()
    for iteration in range(num_iterations):
        # Split data for cross validation; the same indices are applied to
        # every dataset variant
        for train_index, test_index in KFold(
                n_splits=5, shuffle=True).split(datasets['normalized']):
            for dataset_name, dataset in datasets.items():
                training_set = TensorDataset(*dataset[train_index])
                test_set = dataset[test_index]

                # Save results for each model on split dataset
                for model_name, trainer in models.items():
                    result = trainer.test_model(training_set, test_set)
                    key = model_name + ' (' + dataset_name + ')'
                    for measure, value in result.items():
                        if measure not in classification_stats:
                            classification_stats[measure] = {
                                result_key: [] for result_key in result_keys}
                        classification_stats[measure][key].append(value)

        # Report status (parenthesized: '%' binds tighter than '+')
        if (iteration+1) % report_interval == 0:
            print(iteration+1, 'iterations completed.')
    toc_train = time.perf_counter()
    print('Finished training models with 5-fold cross validation!')
    print(f'Training time: {toc_train - tic_train:0.4f}s\n')

    # Create results directory if not existing already
    plot_dir = os.path.join(directory, 'model evaluation')
    os.makedirs(plot_dir, exist_ok=True)

    # Save evaluation results in JSON format
    print('Saving evaluation results in JSON format.')
    with open(os.path.join(plot_dir, '_'.join(models.keys()) + '.json'),
              'w') as json_file:
        json.dump(classification_stats, json_file)
    toc = time.perf_counter()
    print(f'Total runtime: {toc - tic:0.4f}s')