# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

TODO: Adjust troubled-cell (TC) detection for wavelet detectors (use a
    sliding window over all cells instead of every second cell)
TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection

"""
from abc import ABC, abstractmethod
import numpy as np
import torch

import ANN_Model
from projection_utils import Mesh


class TroubledCellDetector(ABC):
    """Abstract base class for troubled-cell detection.

    A troubled cell is a cell of the mesh that contains instabilities.
    Concrete detectors implement `get_cells` to locate such cells and may
    override `_reset` to read their own settings from the config.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def __init__(self, config, basis, mesh):
        """Initializes TroubledCellDetector.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Subclasses may consume
            (pop) entries from it in `_reset`.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        """
        self._basis = basis
        self._mesh = mesh
        # Basis and mesh are stored first so that `_reset` overrides can use
        # them while reading the config.
        self._reset(config)

    def _reset(self, config):
        """Hook to (re)initialize detector-specific instance variables.

        The base implementation does nothing.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """

    def get_name(self):
        """Returns string of class name."""
        return type(self).__name__

    @abstractmethod
    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        """

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        dict
            Keys 'projection', 'basis' and 'mesh'; the latter two hold the
            data dictionaries of the basis and mesh objects.

        """
        basis_data = self._basis.create_data_dict()
        mesh_data = self._mesh.create_data_dict()
        return dict(projection=projection, basis=basis_data, mesh=mesh_data)


class NoDetection(TroubledCellDetector):
    """Detector that never reports any troubled cells.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns no troubled cells, regardless of the input.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree (ignored).

        Returns
        -------
        list
            Always an empty list of troubled-cell indices.

        """
        return list()


class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Every cell is classified by a neural network whose input is computed by
    the basis from a stencil of cells centred on that cell (optionally
    including reconstructions).

    Attributes
    ----------
    stencil_len : int
        Number of cells in the stencil used as input for each cell.
    add_reconstructions : bool
        Flag whether reconstructions are added to the input data.
    model : ANN_Model class instance
        Model used for evaluation; expected to behave like a
        torch.nn.Module (load_state_dict, eval, callable).

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. The keys 'stencil_len',
            'add_reconstructions', 'model', 'model_config' and
            'model_state' are read and removed from it.

        Raises
        ------
        ValueError
            If 'model' does not name a class in ANN_Model.

        """
        super()._reset(config)

        self._stencil_len = config.pop('stencil_len', 3)
        self._add_reconstructions = config.pop('add_reconstructions', True)
        model_name = config.pop('model', 'ThreeLayerReLu')

        # The model input is the stencil plus two extra values when
        # reconstructions are added
        num_datapoints = self._stencil_len
        if self._add_reconstructions:
            num_datapoints += 2
        self._model_config = config.pop('model_config', {
            'input_size': num_datapoints, 'first_hidden_size': 8,
            'second_hidden_size': 4, 'output_size': 2,
            'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                'models/model__Adam.pt')

        if not hasattr(ANN_Model, model_name):
            raise ValueError('Invalid model: "%s"' % model_name)
        self._model = getattr(ANN_Model, model_name)(self._model_config)

        # Load the model state and set it to evaluation mode.
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; model_state must point to a trusted file.
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Replace the single ghost cell at each end by periodic copies wide
        # enough for the stencil. mode='wrap' also handles num_ghost_cells
        # == 0 correctly, where a [-0:] slice would select every column.
        num_ghost_cells = self._stencil_len//2
        projection = np.pad(projection[:, 1:-1],
                            ((0, 0), (num_ghost_cells, num_ghost_cells)),
                            mode='wrap')

        # Calculate input data for every interior cell from its stencil
        input_data = torch.from_numpy(np.vstack([
            self._basis.calculate_cell_average(
                projection=projection[
                           :, cell-num_ghost_cells:cell+num_ghost_cells+1],
                stencil_length=self._stencil_len,
                add_reconstructions=self._add_reconstructions)
            for cell in range(num_ghost_cells,
                              projection.shape[1]-num_ghost_cells)]))

        # Inference only: no autograd graph is needed
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()),
                                        dim=1)

        # Class label 1 marks a troubled cell
        return [cell for cell, label in enumerate(model_output.tolist())
                if label == 1]


class WaveletDetector(TroubledCellDetector):
    """Abstract class for wavelet coefficient based troubled-cell detection.

    Adjacent pairs of fine-mesh cells (2i, 2i+1) are combined into one cell
    of a coarser mesh with half as many cells, using the multiwavelet
    projection matrices of the basis. The resulting multiwavelet
    coefficients are passed to `_get_cells`, which subclasses implement to
    decide which cells are troubled.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """

    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Set additional necessary parameter; with an odd number of fine
        # cells the last one is not part of any pair
        self._num_coarse_grid_cells = self._mesh.num_grid_cells//2
        self._wavelet_projection_left, self._wavelet_projection_right \
            = self._basis.multiwavelet_projection

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(
            projection[:, 1:-1])
        return self._get_cells(multiwavelet_coeffs, projection)

    def _combine_cell_pairs(self, projection, left_matrix, right_matrix):
        """Combines adjacent pairs of fine cells into coarse-mesh entries.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, without ghost
            cells.
        left_matrix, right_matrix : ndarray
            Matrices applied to the left (even index) and right (odd index)
            cell of each pair, respectively.

        Returns
        -------
        ndarray
            Matrix with one column per coarse cell.

        """
        output_matrix = [
            0.5*(projection[:, 2*i] @ left_matrix
                 + projection[:, 2*i+1] @ right_matrix)
            for i in range(self._num_coarse_grid_cells)]
        return np.transpose(np.array(output_matrix))

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser grid.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, without ghost
            cells.

        Returns
        -------
        ndarray
            Matrix of wavelet coefficients.

        """
        return self._combine_cell_pairs(projection,
                                        self._wavelet_projection_left,
                                        self._wavelet_projection_right)

    @abstractmethod
    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        pass

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        ndarray
            Matrix of projection on coarse grid for each polynomial degree.

        """
        basis_projection_left, basis_projection_right \
            = self._basis.basis_projection

        # Remove ghost cells, then combine cell pairs on the coarse mesh
        return self._combine_cell_pairs(projection[:, 1:-1],
                                        basis_projection_left,
                                        basis_projection_right)

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        dict
            General detector data plus 'multiwavelet_coeffs' and
            'coarse_projection'.

        """
        # Create general directory
        data_dict = super().create_data_dict(projection)

        # Save multiwavelet-specific data in dictionary. Ghost cells are
        # stripped exactly as in `get_cells`, so the stored coefficients
        # match the ones used for detection.
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(
            projection[:, 1:-1])
        coarse_projection = self._calculate_coarse_projection(projection)
        data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs
        data_dict['coarse_projection'] = coarse_projection

        return data_dict


class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    The coarse cells are split into consecutive folds. Within each fold the
    degree-zero multiwavelet coefficients are sorted, the first and third
    quartiles (Q1, Q3) are estimated by linear interpolation, and every cell
    whose coefficient lies below ``Q1 - whisker_len*(Q3-Q1)`` or above
    ``Q3 + whisker_len*(Q3-Q1)`` is reported as troubled.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot.
    whisker_len : int
        Length of Boxplot whiskers.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        self._fold_len = config.pop('fold_len', 16)
        self._whisker_len = config.pop('whisker_len', 3)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            Sorted list of indices for all detected troubled cells.

        """
        num_cells = self._num_coarse_grid_cells
        if num_cells == 0:
            # Nothing to examine (also avoids a fold length of 0 below)
            return []

        # Folds cannot be longer than the available data. A local variable is
        # used so the configured fold length is not overwritten.
        fold_len = min(self._fold_len, num_cells)
        num_folds = num_cells//fold_len

        # Quartile positions only depend on the fold length.
        # NOTE(review): for fold_len < 4, boundary_index is 0 and the
        # [boundary_index-1] lookups wrap around to the largest value of the
        # fold, so the quartile estimate is unreliable for such short folds.
        boundary_index = fold_len//4
        balance_factor = fold_len/4.0 - boundary_index

        # Pair each degree-zero coefficient with its cell index so sorting by
        # value keeps track of the cell
        indexed_coeffs = [(multiwavelet_coeffs[0, cell], cell)
                          for cell in range(num_cells)]

        troubled_cells = []
        # NOTE(review): if num_cells is not a multiple of fold_len, the
        # trailing cells after num_folds*fold_len belong to no fold and are
        # never checked.
        for fold in range(num_folds):
            sorted_fold = sorted(
                indexed_coeffs[fold*fold_len:(fold+1)*fold_len])

            first_quartile = (1-balance_factor) \
                * sorted_fold[boundary_index-1][0] \
                + balance_factor * sorted_fold[boundary_index][0]
            third_quartile = (1-balance_factor) \
                * sorted_fold[3*boundary_index-1][0] \
                + balance_factor * sorted_fold[3*boundary_index][0]

            fence_width = self._whisker_len * (third_quartile-first_quartile)
            lower_bound = first_quartile - fence_width
            upper_bound = third_quartile + fence_width

            # Check for lower extreme outliers (smallest values first)
            for coeff, cell in sorted_fold:
                if coeff < lower_bound:
                    troubled_cells.append(cell)
                else:
                    break

            # Check for upper extreme outliers (largest values first)
            for coeff, cell in reversed(sorted_fold):
                if coeff > upper_bound:
                    troubled_cells.append(cell)
                else:
                    break

        return sorted(troubled_cells)


class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on the projection averages and
    a cutoff factor.

    A coarse cell is flagged when its largest absolute multiwavelet
    coefficient (over all polynomial degrees), divided by a normalization
    derived from the largest degree-zero projection value, exceeds a
    threshold derived from the cutoff factor and the mesh.

    Attributes
    ----------
    cutoff_factor : float
        Cutoff factor above which a cell is considered troubled.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations.
        # NOTE(review): the author remarked "or 2 or 3" for this default,
        # presumably as alternative values for the np.sqrt(2) prefactor;
        # confirm.
        self._cutoff_factor = config.pop('cutoff_factor',
                                         np.sqrt(2) * self._mesh.cell_len)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree, including one
            ghost cell at each end.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Normalize by the largest absolute degree-zero value over all
        # interior fine cells (ghost cells at both ends excluded), but by at
        # least 1. The maximum is taken over every fine cell rather than only
        # the first num_coarse_grid_cells of them.
        interior_values = np.abs(projection[0, 1:-1])
        largest_value = interior_values.max() if interior_values.size else 0
        max_avg = np.sqrt(0.5) * max(1, largest_value)

        return [cell for cell in range(self._num_coarse_grid_cells)
                if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg)]

    def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
        """Checks whether a cell is troubled.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        cell : int
            Index of cell.
        max_avg : float
            Maximum average of projection.

        Returns
        -------
        bool
            Flag whether cell is troubled.

        """
        # Largest coefficient magnitude of this cell over all degrees
        max_value = max(abs(multiwavelet_coeffs[degree][cell])
                        for degree in range(
                            self._basis.polynomial_degree+1))/max_avg
        # Threshold: cutoff factor divided by the fine cell length times
        # twice the number of coarse cells
        eps = self._cutoff_factor \
            / (self._mesh.cell_len*self._num_coarse_grid_cells*2)

        return max_value > eps