# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

TODO: Give option to compare multiple models
TODO: Use sklearn for classification -> Done
TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
TODO: Add log to pipeline
TODO: Remove object set-up
TODO: Optimize Snakefile-vs-config relation
TODO: Add boxplot over CFV
TODO: Improve maximum selection runtime
TODO: Fix tensor mapping warning

"""
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import KFold
# from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score

import ANN_Model
from Plotting import plot_classification_accuracy


class ModelTrainer(object):
    """Train and evaluate an ANN classification model.

    Reads training data from '<dir>/input_data.npy' and
    '<dir>/output_data.npy', trains the configured model with an 80/20
    train/validation split, evaluates it with repeated 5-fold
    cross-validation, and saves the trained state dict plus the
    validation-loss history.
    """

    def __init__(self, config):
        """Set up the trainer from a configuration dict.

        Parameters
        ----------
        config : dict
            Configuration; recognized keys are consumed via pop()
            (see _reset for the full list and defaults).
        """
        self._reset(config)

    def _reset(self, config):
        """Initialize all training state from 'config' (keys are popped).

        Raises
        ------
        ValueError
            If the configured model, loss function, or optimizer name
            does not exist in ANN_Model / torch.nn / torch.optim.
        """
        self._dir = config.pop('dir', 'test_data')
        self._model_name = config.pop('model_name', '0')
        self._read_training_data()

        # Default batch size: half the dataset, capped at 500.
        self._batch_size = config.pop('batch_size', min(len(self._training_data)//2, 500))
        self._num_epochs = config.pop('num_epochs', 1000)
        self._threshold = config.pop('threshold', 1e-5)
        self._model = config.pop('model', 'ThreeLayerReLu')
        self._model_config = config.pop('model_config', {})
        self._loss_function = config.pop('loss_function', 'BCELoss')
        self._loss_config = config.pop('loss_config', {})
        self._optimizer = config.pop('optimizer', 'Adam')
        self._optimizer_config = config.pop('optimizer_config', {})

        # The learning rate is injected into the optimizer config so a
        # user-supplied 'optimizer_config' cannot silently override it.
        self._learning_rate = config.pop('learning_rate', 1e-2)
        self._optimizer_config['lr'] = self._learning_rate

        # Validate the names before instantiating to fail fast with a
        # clear message instead of an AttributeError.
        if not hasattr(ANN_Model, self._model):
            raise ValueError('Invalid model: "%s"' % self._model)
        if not hasattr(torch.nn.modules.loss, self._loss_function):
            raise ValueError('Invalid loss function: "%s"' % self._loss_function)
        if not hasattr(torch.optim, self._optimizer):
            raise ValueError('Invalid optimizer: "%s"' % self._optimizer)

        # Replace the string names with the instantiated objects.
        self._model = getattr(ANN_Model, self._model)(self._model_config)
        self._loss_function = getattr(torch.nn.modules.loss, self._loss_function)(
            **self._loss_config)
        self._optimizer = getattr(torch.optim, self._optimizer)(
            self._model.parameters(), **self._optimizer_config)
        # One slot per 10 epochs, although epoch_training only records
        # every 100 epochs — kept at the original size so saved loss
        # tensors stay shape-compatible.
        self._validation_loss = torch.zeros(self._num_epochs//10)

    def _read_training_data(self):
        """Load the training set from .npy files into a TensorDataset."""
        input_file = self._dir + '/input_data.npy'
        output_file = self._dir + '/output_data.npy'
        # from_numpy avoids an extra copy of the loaded arrays.
        self._training_data = TensorDataset(torch.from_numpy(np.load(input_file)),
                                            torch.from_numpy(np.load(output_file)))

    def epoch_training(self, dataset=None, num_epochs=None):
        """Train the model with an 80/20 train/validation split.

        Stops early once the average validation loss drops below
        'self._threshold'. Every 100 epochs the average validation loss
        is recorded in 'self._validation_loss' and printed.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset, optional
            Data to train on; defaults to the loaded training data.
        num_epochs : int, optional
            Number of epochs; defaults to 'self._num_epochs'.
        """
        if dataset is None:
            dataset = self._training_data
        if num_epochs is None:
            num_epochs = self._num_epochs

        # Split 80/20. Computing the validation size as the remainder
        # guarantees the parts always sum to len(dataset), which two
        # independent round() calls do not.
        num_samples = len(dataset)
        num_train = round(num_samples * 0.8)
        train_ds, valid_ds = random_split(dataset, [num_train, num_samples - num_train])

        # Load sets; validation uses larger batches since no gradients
        # are kept.
        train_dl = DataLoader(train_ds, batch_size=self._batch_size, shuffle=True)
        valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)

        # Training with validation
        for epoch in range(num_epochs):
            self._model.train()
            for x_batch, y_batch in train_dl:
                pred = self._model(x_batch.float())
                loss = self._loss_function(pred, y_batch.float()).mean()

                # Run back propagation, update the weights, and zero
                # gradients for the next batch
                loss.backward()
                self._optimizer.step()
                self._optimizer.zero_grad()

            self._model.eval()
            with torch.no_grad():
                valid_loss = sum(
                    self._loss_function(self._model(x_batch_valid.float()),
                                        y_batch_valid.float())
                    for x_batch_valid, y_batch_valid in valid_dl)
                avg_valid_loss = valid_loss / len(valid_dl)

                if (epoch+1) % 100 == 0:
                    self._validation_loss[(epoch+1)//100 - 1] = avg_valid_loss
                    print(epoch+1, avg_valid_loss)

                if avg_valid_loss < self._threshold:
                    break

    def test_model(self, num_iterations=100):
        """Evaluate the model with repeated 5-fold cross-validation.

        Averages [precision, recall, accuracy] over all folds and
        iterations, plots the result, and saves every open figure as a
        PDF under '<dir>/model evaluation/<figure label>/'.

        Parameters
        ----------
        num_iterations : int, optional
            Number of times the 5-fold cross-validation is repeated.
        """
        classification_stats = []
        dataset = self._training_data
        for iteration in range(num_iterations):
            for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset):
                # dataset[indices] already yields tensors; wrapping them
                # in torch.tensor() again only triggered the tensor
                # copy-construct warning.
                training_set = TensorDataset(*dataset[train_index])
                test_set = dataset[test_index]

                classification_stats.append(self._test_fold(training_set, test_set))

        # Average the per-fold [precision, recall, accuracy] triples.
        classification_stats = np.array(classification_stats).mean(axis=0)

        plot_classification_accuracy([self._model_name], *classification_stats)

        # exist_ok avoids the race between the existence check and the
        # directory creation.
        plot_dir = self._dir + '/model evaluation'
        os.makedirs(plot_dir, exist_ok=True)

        # Save each open figure under a directory named by its label
        for identifier in plt.get_figlabels():
            os.makedirs(plot_dir + '/' + identifier, exist_ok=True)
            plt.figure(identifier)
            plt.savefig(plot_dir + '/' + identifier + '/' + self._model_name + '.pdf')

    def _test_fold(self, training_set, test_set):
        """Train for 100 epochs on 'training_set', evaluate on 'test_set'.

        Returns
        -------
        list
            [precision, recall, accuracy], where precision and recall
            are taken for the first output class.
        """
        self.epoch_training(training_set, num_epochs=100)
        self._model.eval()

        x_test, y_test = test_set
        # Map the argmax of the network output to a one-hot prediction.
        model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                                     for value in torch.max(self._model(x_test.float()), 1)[1]])

        y_true = y_test.detach().numpy()
        y_pred = model_output.detach().numpy()
        accuracy = accuracy_score(y_true, y_pred)
        # zero_division=0 keeps the original 0.0 result for classes with
        # no predicted samples, without emitting a warning.
        precision, recall, f_score, support = precision_recall_fscore_support(
            y_true, y_pred, zero_division=0)

        return [precision[0], recall[0], accuracy]

    def save_model(self):
        """Save the model state dict and validation-loss history (.pt)."""
        name = self._model_name

        # Ensure the model directory exists (race-free).
        model_dir = self._dir + '/trained models'
        os.makedirs(model_dir, exist_ok=True)

        torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
        torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')

    def _classify(self):
        # Placeholder for a future classification entry point.
        pass


# Loss Functions: BCELoss, BCEWithLogitsLoss,
# CrossEntropyLoss (not working), MSELoss (with reduction='sum')
# Optimizer: Adam, SGD
# trainer = ModelTrainer({'num_epochs': 1000})
# trainer.epoch_training()
# trainer.test_model()
# trainer.save_model()