# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

"""
from typing import Tuple
from abc import ABC, abstractmethod
import numpy as np
from numpy import ndarray
import torch

from . import ANN_Model
from .Boundary_Condition import enforce_boxplot_boundaries, \
    enforce_fold_boundaries
from .Mesh import Mesh


class TroubledCellDetector(ABC):
    """Abstract base class for troubled-cell detection.

    A detector marks troubled cells, i.e., cells in the mesh containing
    instabilities. Concrete detectors have to implement `get_cells` and may
    override `_reset` to read their own settings from the configuration.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def __init__(self, config, basis, mesh):
        """Initializes TroubledCellDetector.

        Stores basis and mesh, then hands the configuration to `_reset`.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        """
        self._basis = basis
        self._mesh = mesh

        # Basis and mesh are stored first so that overriding `_reset`
        # implementations can already access them.
        self._reset(config)

    def _reset(self, config):
        """Resets instance variables.

        The base implementation does nothing; it is a hook for subclasses.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """

    def get_name(self):
        """Returns string of class name."""
        return type(self).__name__

    @abstractmethod
    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        """

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        dict
            Dictionary with the given projection under 'projection' and the
            data dictionaries of basis and mesh under 'basis' and 'mesh'.

        """
        data_dict = {'projection': projection}
        data_dict['basis'] = self._basis.create_data_dict()
        data_dict['mesh'] = self._mesh.create_data_dict()
        return data_dict


class NoDetection(TroubledCellDetector):
    """Detector that never marks any cell as troubled.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns no troubled cells.

        The projection is ignored entirely.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            Empty list, as no cell is ever detected as troubled.

        """
        no_troubled_cells = []
        return no_troubled_cells


class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Attributes
    ----------
    stencil_len : int
        Number of cells that form the input stencil of one cell.
    add_reconstructions : bool
        Flag passed on to the basis when calculating the input data. If
        set, the default model input size is the stencil length plus two.
    window_mask : ndarray
        Index array of shape (stencil_len, num_cells) selecting the stencil
        of each cell from the padded projection.
    model : torch.nn.Module
        ANN model instance for evaluation.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Reads (and removes) the entries 'stencil_len',
        'add_reconstructions', 'model', 'model_config', and 'model_state'
        from the configuration, builds the model and loads its state.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        Raises
        ------
        ValueError
            If the requested model is not available in ANN_Model.

        """
        super()._reset(config)

        self._stencil_len = config.pop('stencil_len', 3)
        self._add_reconstructions = config.pop('add_reconstructions', True)

        # Set mask for stencil sliding window: entry (i, j) is the index
        # of the i-th stencil cell belonging to cell j.
        cell_offsets = np.arange(self._mesh.num_cells)[None, :]
        stencil_offsets = np.arange(self._stencil_len)[:, None]
        self._window_mask = cell_offsets + stencil_offsets

        model_name = config.pop('model', 'ThreeLayerReLu')
        num_datapoints = self._stencil_len
        if self._add_reconstructions:
            num_datapoints += 2
        self._model_config = config.pop('model_config', {
            'input_size': num_datapoints, 'first_hidden_size': 8,
            'second_hidden_size': 4, 'output_size': 2,
            'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                'models/model__Adam.pt')

        # Keep the name in a local variable so that the model attribute
        # only ever holds the instantiated model.
        if not hasattr(ANN_Model, model_name):
            raise ValueError('Invalid model: "%s"' % model_name)
        self._model = getattr(ANN_Model, model_name)(self._model_config)

        # Load the model state and set it to evaluation mode
        # NOTE(review): torch.load unpickles the given file, so only model
        # states from trusted sources should be configured here.
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Reset ghost cells to adjust for stencil length: strip the ghost
        # cells of the mesh and pad with cells from the opposite end
        # instead.
        num_ghost_cells = self._stencil_len//2
        projection = projection[:, self._mesh.num_ghost_cells:
                                -self._mesh.num_ghost_cells]
        projection = np.concatenate((projection[:, -num_ghost_cells:],
                                     projection,
                                     projection[:, :num_ghost_cells]), axis=1)

        # Calculate input data depending on stencil length
        projection_window = projection[:, self._window_mask]
        input_data = torch.from_numpy(self._basis.calculate_cell_average(
            projection=projection_window, stencil_len=self._stencil_len,
            add_reconstructions=self._add_reconstructions))

        # Determine troubled cells. This is pure inference, so gradient
        # tracking is switched off to avoid building an autograd graph.
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()),
                                        dim=1)

        # Every cell whose predicted class index is non-zero is troubled.
        return np.flatnonzero(model_output.numpy()).tolist()


class WaveletDetector(TroubledCellDetector):
    """Abstract class for wavelet coefficient based troubled-cell detection.

    Bundles the calculations shared by detectors working on multiwavelet
    coefficients: calculating the coefficients from a projection, selecting
    one coefficient per cell, and calculating the coarse projection.

    Methods
    -------
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """

    def _reset(self, config):
        """Resets instance variables.

        Reads (and removes) the entry 'multiwavelet_degree' from the
        configuration and stores the multiwavelet projections of the basis.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        Raises
        ------
        ValueError
            If multiwavelet degree is not in ['first', 'last', 'max'].

        """
        super()._reset(config)

        degree = config.pop('multiwavelet_degree', 'first')
        self._wavelet_degree = degree
        if degree not in ('first', 'last', 'max'):
            raise ValueError('Invalid entry for multiwavelet degree. It must '
                             'be either "first", "last", or "max".')

        # Set wavelet projections
        left_projection, right_projection \
            = self._basis.multiwavelet_projection
        self._wavelet_projection_left = left_projection
        self._wavelet_projection_right = right_projection

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser mesh.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of multiwavelet coefficients.

        """
        num_ghost_cells = self._mesh.num_ghost_cells

        # Columns without ghost cells, and the same window shifted by one
        # column to the right (i.e., the right neighbour of each cell).
        cells = projection[:, num_ghost_cells:-num_ghost_cells]
        right_neighbours = projection[
            :, num_ghost_cells+1:projection.shape[-1] - num_ghost_cells+1]

        left_part = self._wavelet_projection_left.T @ cells
        right_part = self._wavelet_projection_right.T @ right_neighbours
        return 0.5 * (left_part + right_part)

    def _select_degree(self, wavelet_matrix):
        """Select degree of wavelet coefficients for troubled cell detection.

        Select either the first, last, or highest magnitude degree for each
        cell from the multiwavelet coefficients.

        Parameters
        ----------
        wavelet_matrix : ndarray
            Matrix of multiwavelet coefficients.

        Returns
        -------
        ndarray
            Matrix of multiwavelet coefficients of selected degree.

        """
        if self._wavelet_degree == 'first':
            return wavelet_matrix[0]
        if self._wavelet_degree == 'last':
            return wavelet_matrix[-1]

        # Keep the signed value of whichever extreme is larger in
        # magnitude; the minimum only wins if it strictly exceeds the
        # maximum in magnitude.
        smallest = np.min(wavelet_matrix, axis=0)
        largest = np.max(wavelet_matrix, axis=0)
        return np.where(-smallest > largest, smallest, largest)

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of projection on coarse mesh for each polynomial degree.

        """
        left_projection, right_projection = self._basis.basis_projection

        # Remove ghost cells
        num_ghost_cells = self._mesh.num_ghost_cells
        inner_cells = projection[:, num_ghost_cells:-num_ghost_cells]

        # Calculate projection on coarse mesh by combining each cell at an
        # even position with the cell following it
        even_cells = inner_cells[:, ::2]
        odd_cells = inner_cells[:, 1::2]
        return 0.5 * (left_projection.T @ even_cells +
                      right_projection.T @ odd_cells)

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        dict
            General data dictionary extended by the entries
            'multiwavelet_coeffs' and 'coarse_projection'.

        """
        # Create general dictionary
        data_dict = super().create_data_dict(projection)

        # Save multiwavelet-specific data in dictionary
        data_dict.update(
            multiwavelet_coeffs=self._calculate_wavelet_coeffs(projection),
            coarse_projection=self._calculate_coarse_projection(projection))

        return data_dict


class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot.
    whisker_len : int
        Length of Boxplot whiskers.
    adjust_outer_fences : bool
        Flag whether outer fences should be adjusted using global mean.
    extreme_outlier_only : bool
        Flag whether outliers also have to be detected in neighbouring folds.
    quantile_method : str
        Method used to calculate quantiles.
    fold_indices : ndarray
        Array with indices for elements of each fold (including
        overlaps).

    Methods
    -------
    get_cells(projection)
        Calculate troubled cells in a given projection.

    """
    def _reset(self, config):
        """Reset instance variables.

        Reads (and removes) the entries 'fold_len', 'whisker_len',
        'adjust_outer_fences', 'extreme_outlier_only', 'quantile_method',
        and 'num_overlapping_cells' from the configuration.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations; a fold is never longer than
        # the mesh itself.
        self._fold_len = min(self._mesh.num_cells,
                             config.pop('fold_len', 16))
        self._whisker_len = config.pop('whisker_len', 3)
        self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
        self._extreme_outlier_only = config.pop('extreme_outlier_only', True)

        self._quantile_method = config.pop('quantile_method', 'weibull')
        num_overlapping_cells = config.pop('num_overlapping_cells', 1)
        self._fold_indices = self._compute_folds(num_overlapping_cells)

    @enforce_fold_boundaries
    def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
        """Compute indices for all folds used in Boxplot calculation.

        Each fold covers its own cells plus the given number of
        overlapping cells on either side. The indices computed here may lie
        outside of the mesh (negative or too large) for the outermost folds.

        Parameters
        ----------
        num_overlapping_cells : int
            Number of cells overlapping between adjacent folds.

        Returns
        -------
        fold_indices : ndarray
            Array of projection indices in each fold.

        """
        num_folds = self._mesh.num_cells//self._fold_len

        # Row `fold` holds the first index of the fold shifted by all
        # offsets from -num_overlapping_cells up to (excluding)
        # fold_len + num_overlapping_cells.
        fold_starts = np.arange(num_folds) * self._fold_len
        fold_offsets = np.arange(-num_overlapping_cells,
                                 self._fold_len + num_overlapping_cells)
        return (fold_starts[:, None] + fold_offsets[None, :]).astype(np.int32)

    def get_cells(self, projection):
        """Calculate troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Determine bounds. The bound arrays hold one entry per fold,
        # framed by one additional entry at either end.
        coeffs = self._select_degree(self._calculate_wavelet_coeffs(
            projection))
        lower_bounds, upper_bounds = self._compute_bounds(coeffs)

        # Adjust outer fences if flag is set: the fences are pushed
        # outwards to at least the global mean of the coefficient
        # magnitudes.
        if self._adjust_outer_fences:
            global_mean = np.mean(abs(coeffs))
            lower_bounds[lower_bounds > -global_mean] = -global_mean
            upper_bounds[upper_bounds < global_mean] = global_mean

        # Select outliers as troubled cells by comparing each coefficient
        # with the bounds of its own fold
        lower_outlier = coeffs < np.repeat(lower_bounds[1:-1], self._fold_len)
        upper_outlier = coeffs > np.repeat(upper_bounds[1:-1], self._fold_len)

        # Adjust for extreme outliers if flag is set: a coefficient
        # additionally has to violate the bounds of the preceding
        # ([:-2]) and the following ([2:]) entry.
        if self._extreme_outlier_only:
            lower_outlier = np.logical_and(lower_outlier, np.logical_and(
                coeffs < np.repeat(lower_bounds[:-2], self._fold_len),
                coeffs < np.repeat(lower_bounds[2:], self._fold_len)))
            upper_outlier = np.logical_and(upper_outlier, np.logical_and(
                coeffs > np.repeat(upper_bounds[:-2], self._fold_len),
                coeffs > np.repeat(upper_bounds[2:], self._fold_len)))

        troubled_cells = np.flatnonzero(np.logical_or(lower_outlier,
                                                      upper_outlier)).tolist()

        return troubled_cells

    @enforce_boxplot_boundaries
    def _compute_bounds(self, coeffs: ndarray) -> Tuple[ndarray, ndarray]:
        """Compute lower and upper bound for Boxplot outliers.

        The bounds of each fold are its first and third quartile moved
        outwards by the whisker length times the interquartile range.

        Parameters
        ----------
        coeffs : ndarray
            Matrix of multiwavelet coefficients of projection.

        Returns
        -------
        lower_bounds : ndarray
            Array of lower bounds for outlier.
        upper_bounds : ndarray
            Array of upper bounds for outlier.

        """
        # Determine quartiles of folds
        folds = coeffs[self._fold_indices]
        first_quartiles = np.quantile(folds, 0.25, axis=1,
                                      method=self._quantile_method)
        third_quartiles = np.quantile(folds, 0.75, axis=1,
                                      method=self._quantile_method)
        quartile_range = third_quartiles - first_quartiles

        # Determine bounds based on quartiles of a boxplot. The first and
        # last entry are left at zero here.
        # NOTE(review): presumably those two entries are filled in by the
        # enforce_boxplot_boundaries decorator - confirm.
        lower_bounds = np.zeros(len(first_quartiles) + 2)
        upper_bounds = np.zeros(len(first_quartiles) + 2)

        lower_bounds[1:-1] = first_quartiles \
            - self._whisker_len * quartile_range
        upper_bounds[1:-1] = third_quartiles \
            + self._whisker_len * quartile_range

        return lower_bounds, upper_bounds


class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on theoretical thresholding.

    Attributes
    ----------
    threshold : float
        Threshold above which a cell is considered troubled.

    Methods
    -------
    get_cells(projection)
        Calculate troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Reads (and removes) the entry 'cutoff_factor' from the
        configuration and derives the threshold from it.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        cell_len = self._mesh.cell_len
        default_cutoff = np.sqrt(2) * cell_len
        cutoff_factor = config.pop('cutoff_factor', default_cutoff)

        # The cutoff factor is scaled by the product of cell length and
        # number of cells.
        total_len = cell_len * self._mesh.num_cells
        self._threshold = cutoff_factor / total_len

    def get_cells(self, projection):
        """Calculate troubled cells in a given projection.

        A cell is troubled if the magnitude of its selected multiwavelet
        coefficient, scaled by a normalization value, exceeds the threshold.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        wavelet_matrix = self._calculate_wavelet_coeffs(projection)
        coeffs = self._select_degree(wavelet_matrix)

        # Normalization value based on the largest magnitude in the first
        # row of the projection, but never less than one
        largest_entry = np.max(np.abs(projection[0]))
        max_avg = np.sqrt(0.5) * max(1, largest_entry)

        exceeds_threshold = np.abs(coeffs)/max_avg > self._threshold
        return np.flatnonzero(exceeds_threshold).tolist()