# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

TODO: Adjust troubled-cell (TC) detection for wavelet detectors (sliding window over all cells instead of every second cell)
TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection

"""
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from sympy import Symbol

import ANN_Model
from Plotting import plot_solution_and_approx, plot_semilog_error, plot_error, plot_shock_tube, \
    plot_details, calculate_approximate_solution, calculate_exact_solution

# Render all figures with matplotlib's non-interactive Agg backend
matplotlib.use('Agg')
# Module-level SymPy symbols
x, z = Symbol('x'), Symbol('z')


class TroubledCellDetector(object):
    """Class for troubled-cell detection.

    Detects troubled cells, i.e., cells in the mesh containing instabilities.

    Attributes
    ----------
    interval_len : float
        Length of the interval between left and right boundary.
    cell_len : float
        Length of a cell in mesh.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    calculate_cell_average_and_reconstructions(projection, stencil_length)
        Calculates cell averages and reconstructions for a given projection.
    plot_results(projection, troubled_cell_history, time_history)
        Plots results and troubled cells of a projection given its evaluation history.

    """
    def __init__(self, config, mesh, wave_speed, polynomial_degree, num_grid_cells, final_time,
                 left_bound, right_bound, basis, init_cond, quadrature):
        """Initializes TroubledCellDetector.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Keys used by the detector are removed
            from this dict (via ``pop``) as they are consumed.
        mesh : array
            List of mesh valuation points.
        wave_speed : float
            Speed of wave in rightward direction.
        polynomial_degree : int
            Polynomial degree.
        num_grid_cells : int
            Number of cells in the mesh. Usually exponential of 2.
        final_time : float
            Final time for which approximation is calculated.
        left_bound : float
            Left boundary of interval.
        right_bound : float
            Right boundary of interval.
        basis : Basis object
            Basis for calculation.
        init_cond : InitialCondition object
            Initial condition for evaluation.
        quadrature : Quadrature object
            Quadrature for evaluation.

        """
        self._mesh = mesh
        self._wave_speed = wave_speed
        self._polynomial_degree = polynomial_degree
        self._num_grid_cells = num_grid_cells
        self._final_time = final_time
        self._left_bound = left_bound
        self._right_bound = right_bound
        self._interval_len = right_bound - left_bound
        self._cell_len = self._interval_len / num_grid_cells
        self._basis = basis
        self._init_cond = init_cond
        self._quadrature = quadrature

        # Set parameters from config if existing
        self._colors = config.pop('colors', {})

        self._check_colors()
        self._reset(config)

    def _check_colors(self):
        """Checks plot colors.

        Checks whether colors for plots were given and sets them if required.

        """
        self._colors['exact'] = self._colors.get('exact', 'k-')
        self._colors['approx'] = self._colors.get('approx', 'y')

    def _reset(self, config):
        """Resets instance variables.

        The base implementation only applies the seaborn plot style; subclasses extend it
        to read their own settings from ``config``.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        sns.set()

    def get_name(self):
        """Returns string of class name."""
        return self.__class__.__name__

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        The base implementation detects nothing and returns None; subclasses override it.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        """
        pass

    def calculate_cell_average_and_reconstructions(self, projection, stencil_length):
        """Calculates cell averages and reconstructions for a given projection.

        Calculate the cell averages of all cells in a projection. Reconstructions are only
        calculated for the middle cell and added left and right to it, respectively.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.
        stencil_length : int
            Size of data array.

        Returns
        -------
        np.array
            Matrix containing cell averages and reconstructions for initial projection,
            i.e. ``stencil_length + 2`` float64 columns per evaluation row.

        """
        basis_vector = self._basis.get_basis_vector()
        cell_averages = calculate_approximate_solution(projection, [0], 0, basis_vector)
        left_reconstructions = calculate_approximate_solution(
            projection, [-1], self._polynomial_degree, basis_vector)
        right_reconstructions = calculate_approximate_solution(
            projection, [1], self._polynomial_degree, basis_vector)

        # Keep the middle cell as a length-1 slice so that every piece stays 2D and the
        # pieces can be joined column-wise for any stencil length (zipping slices of
        # different widths only worked when every slice had a single column)
        middle_idx = stencil_length//2
        middle = slice(middle_idx, middle_idx+1)
        data = np.concatenate((cell_averages[:, :middle_idx], left_reconstructions[:, middle],
                               cell_averages[:, middle], right_reconstructions[:, middle],
                               cell_averages[:, middle_idx+1:]), axis=1)
        # Convert explicitly, since the evaluated values may not be stored as floats
        return data.astype(np.float64)

    def plot_results(self, projection, troubled_cell_history, time_history):
        """Plots results and troubled cells of a projection given its evaluation history.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.
        troubled_cell_history : list
            List of detected troubled cells for each time step.
        time_history : list
            List of value of each time step.

        """
        plot_shock_tube(self._num_grid_cells, troubled_cell_history, time_history)
        max_error = self._plot_mesh(projection)

        print('p =', self._polynomial_degree)
        print('N =', self._num_grid_cells)
        print('maximum error =', max_error)

    def _plot_mesh(self, projection):
        """Plots exact and approximate solution as well as errors.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        max_error
            Maximum error between exact and approximate solution.

        """
        grid, exact = calculate_exact_solution(
            self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
            self._interval_len, self._quadrature, self._init_cond)
        # Drop the ghost cell on each side of the projection
        approx = calculate_approximate_solution(
            projection[:, 1:-1], self._quadrature.get_eval_points(), self._polynomial_degree,
            self._basis.get_basis_vector())

        pointwise_error = np.abs(exact-approx)
        max_error = np.max(pointwise_error)

        plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
        plt.legend(['Exact', 'Approx'])
        plot_semilog_error(grid, pointwise_error)
        plot_error(grid, exact, approx)

        return max_error


class NoDetection(TroubledCellDetector):
    """Detector that never reports any troubled cell.

    Whatever projection it receives, ``get_cells`` answers with an empty list.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns no troubled cells.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree; not inspected.

        Returns
        -------
        list
            Always an empty list.

        """
        detected = list()
        return detected


class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Attributes
    ----------
    stencil_length : int
        Size of input data array.
    model : ANNModel object
        ANN model instance for evaluation.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables and loads the trained model.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Consumed keys: 'stencil_len', 'model',
            'model_config' and 'model_state'.

        Raises
        ------
        ValueError
            If the requested model is not defined in ``ANN_Model``.

        """
        super()._reset(config)

        self._stencil_len = config.pop('stencil_len', 3)
        self._model = config.pop('model', 'ThreeLayerReLu')
        # The default input size matches the output of
        # calculate_cell_average_and_reconstructions: one average per stencil cell plus
        # the left and right reconstruction of the middle cell
        self._model_config = config.pop('model_config', {
            'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4,
            'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained models/model__Adam.pt')

        if not hasattr(ANN_Model, self._model):
            raise ValueError('Invalid model: "%s"' % self._model)
        self._model = getattr(ANN_Model, self._model)(self._model_config)

        # Load the model state and set it to evaluation mode
        # NOTE(review): torch.load unpickles the file, so only load trusted model states
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Reset ghost cells to adjust for stencil length: drop the existing ghost cell on
        # each side and pad periodically with as many cells as the stencil needs
        num_ghost_cells = self._stencil_len//2
        projection = projection[:, 1:-1]
        num_cells = projection.shape[1]
        # Slice from the front: [-num_ghost_cells:] would select the whole array when
        # num_ghost_cells is 0
        projection = np.concatenate((projection[:, num_cells-num_ghost_cells:], projection,
                                     projection[:, :num_ghost_cells]),
                                    axis=1)

        # Calculate input data depending on stencil length (one row per interior cell)
        input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
            projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
            for cell in range(num_ghost_cells, num_cells+num_ghost_cells)]))

        # Determine troubled cells (class 1); gradients are not needed for inference
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()), dim=1)
        return torch.nonzero(model_output == 1).flatten().tolist()


class WaveletDetector(TroubledCellDetector):
    """Class for troubled-cell detection based on wavelet coefficients.

    Each pair of neighboring fine-grid cells is combined into one cell of a coarse grid
    with half as many cells. The resulting multiwavelet coefficients (obtained with the
    basis' multiwavelet projections) are handed to ``_get_cells``, which subclasses
    override to decide which cells are troubled; this base class detects none.

    """
    def _check_colors(self):
        """Checks plot colors.

        Checks whether colors for plots were given and sets them if required.

        """
        self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.')
        self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.')
        self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-')
        self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y')

    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Set additional necessary parameter
        self._num_coarse_grid_cells = self._num_grid_cells//2
        self._wavelet_projection_left, self._wavelet_projection_right \
            = self._basis.get_multiwavelet_projections()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Remove ghost cells before pairing neighboring cells
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1: -1])
        return self._get_cells(multiwavelet_coeffs, projection)

    def _combine_cell_pairs(self, projection, left_matrix, right_matrix):
        """Combines each pair of neighboring cells using a left and a right matrix.

        Column ``i`` of the result is
        ``0.5 * (projection[:, 2i] @ left_matrix + projection[:, 2i+1] @ right_matrix)``.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree (without ghost cells).
        left_matrix, right_matrix
            Matrices applied to the left and right cell of each pair.

        Returns
        -------
        np.array
            Matrix with one column per coarse cell.

        """
        output_matrix = [0.5*(projection[:, 2*i] @ left_matrix
                              + projection[:, 2*i+1] @ right_matrix)
                         for i in range(self._num_coarse_grid_cells)]
        return np.transpose(np.array(output_matrix))

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser grid.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree (without ghost cells).

        Returns
        -------
        np.array
            Matrix of wavelet coefficients.

        """
        return self._combine_cell_pairs(projection, self._wavelet_projection_left,
                                        self._wavelet_projection_right)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        The base implementation detects no troubled cells.

        Parameters
        ----------
        multiwavelet_coeffs : np.array
            Matrix of multiwavelet coefficients.
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        return []

    def plot_results(self, projection, troubled_cell_history, time_history):
        """Plots results and troubled cells of a projection given its evaluation history.

        Plots results on coarse and fine grid, errors, troubled cells, and coefficient details.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.
        troubled_cell_history : list
            List of detected troubled cells for each time step.
        time_history : list
            List of value of each time step.

        """
        # Remove ghost cells before pairing, exactly as in get_cells; otherwise the
        # coefficients would pair a ghost cell with the first interior cell
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1:-1])
        coarse_projection = self._calculate_coarse_projection(projection)
        plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection,
                     self._basis.get_basis_vector(), self._basis.get_wavelet_vector(),
                     multiwavelet_coeffs, self._num_coarse_grid_cells,
                     self._polynomial_degree)
        super().plot_results(projection, troubled_cell_history, time_history)

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree (including ghost cells).

        Returns
        -------
        np.array
            Matrix of projection on coarse grid for each polynomial degree.

        """
        basis_projection_left, basis_projection_right = self._basis.get_basis_projections()

        # Remove ghost cells, then calculate projection on coarse mesh
        return self._combine_cell_pairs(projection[:, 1:-1], basis_projection_left,
                                        basis_projection_right)

    def _plot_mesh(self, projection):
        """Plots exact and approximate solution as well as errors.

        Solutions on the coarse mesh are plotted first, followed by those on the fine mesh.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        max_error
            Maximum error between exact and approximate solution on the fine mesh.

        """
        grid, exact = calculate_exact_solution(
            self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
            self._interval_len, self._quadrature, self._init_cond)
        approx = calculate_approximate_solution(
            projection[:, 1:-1], self._quadrature.get_eval_points(), self._polynomial_degree,
            self._basis.get_basis_vector())

        pointwise_error = np.abs(exact-approx)
        max_error = np.max(pointwise_error)

        self._plot_coarse_mesh(projection)
        plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'],
                                 self._colors['fine_approx'])
        plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', 'Approx (Fine)'])
        plot_semilog_error(grid, pointwise_error)
        plot_error(grid, exact, approx)

        return max_error

    def _plot_coarse_mesh(self, projection):
        """Plots exact and approximate solution for a coarse projection.

        Parameters
        ----------
        projection : np.array
            Matrix of projection for each polynomial degree.

        """
        coarse_cell_len = 2*self._cell_len
        # Points offset by half a coarse cell from the left bound, with one extra point
        # on each side; built from integer indices because np.arange with a non-integer
        # step may include an additional endpoint due to rounding
        coarse_mesh = self._left_bound + coarse_cell_len * (
            np.arange(self._num_coarse_grid_cells+2) - 0.5)

        coarse_projection = self._calculate_coarse_projection(projection)

        # Plot exact and approximate solutions for coarse mesh
        grid, exact = calculate_exact_solution(
            coarse_mesh[1:-1], coarse_cell_len, self._wave_speed, self._final_time,
            self._interval_len, self._quadrature, self._init_cond)
        approx = calculate_approximate_solution(
            coarse_projection, self._quadrature.get_eval_points(), self._polynomial_degree,
            self._basis.get_basis_vector())
        plot_solution_and_approx(
            grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])


class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    The degree-0 multiwavelet coefficients of the coarse cells are split into folds of
    consecutive cells. Within each fold, cells whose coefficient lies outside
    ``[Q1 - whisker_len*IQR, Q3 + whisker_len*IQR]`` are flagged as troubled.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot.
    whisker_len : int
        Length of Boxplot whiskers.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        self._fold_len = config.pop('fold_len', 16)
        self._whisker_len = config.pop('whisker_len', 3)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Only complete folds are evaluated: if the number of coarse cells is not a multiple
        of the fold length, the trailing cells are not checked.

        Parameters
        ----------
        multiwavelet_coeffs : np.array
            Matrix of multiwavelet coefficients.
        projection : np.array
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            Sorted list of indices for all detected troubled cells.

        """
        # Folds cannot be longer than the number of coarse cells
        fold_len = min(self._fold_len, self._num_coarse_grid_cells)
        if fold_len <= 0:
            return []

        indexed_coeffs = [[multiwavelet_coeffs[0, i], i]
                          for i in range(self._num_coarse_grid_cells)]
        num_folds = self._num_coarse_grid_cells//fold_len
        troubled_cells = []

        for fold in range(num_folds):
            # Sort by coefficient value (ties broken by cell index)
            sorted_fold = sorted(indexed_coeffs[fold*fold_len:(fold+1)*fold_len])

            first_quartile = self._get_quartile(sorted_fold, fold_len/4.0)
            third_quartile = self._get_quartile(sorted_fold, 3*fold_len/4.0)
            inter_quartile_range = third_quartile - first_quartile

            lower_bound = first_quartile - self._whisker_len * inter_quartile_range
            upper_bound = third_quartile + self._whisker_len * inter_quartile_range

            # Check for lower extreme outliers and add respective cells
            for cell in sorted_fold:
                if cell[0] < lower_bound:
                    troubled_cells.append(cell[1])
                else:
                    break

            # Check for upper extreme outliers and add respective cells
            for cell in reversed(sorted_fold):
                if cell[0] > upper_bound:
                    troubled_cells.append(cell[1])
                else:
                    break

        return sorted(troubled_cells)

    @staticmethod
    def _get_quartile(sorted_fold, position):
        """Interpolates the coefficient value at a 1-based position in a sorted fold.

        The result lies between the entries at 1-based positions ``floor(position)`` and
        ``floor(position)+1``, weighted by the fractional part of ``position``. The lower
        index is clamped to the first entry so that short folds do not wrap around.

        Parameters
        ----------
        sorted_fold : list
            List of ``[coefficient, cell_index]`` pairs sorted by coefficient.
        position : float
            1-based position, e.g. ``len/4`` for the first or ``3*len/4`` for the third
            quartile; must be smaller than ``len(sorted_fold)``.

        Returns
        -------
        float
            Interpolated coefficient value.

        """
        index = int(position)
        balance_factor = position - index
        return (1-balance_factor) * sorted_fold[max(index-1, 0)][0] \
            + balance_factor * sorted_fold[index][0]


class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on the projection averages and a cutoff factor.

    A coarse cell is flagged as troubled if its largest absolute multiwavelet coefficient
    (over all polynomial degrees), normalized by the maximum cell average, exceeds a
    threshold derived from the cutoff factor.

    Attributes
    ----------
    cutoff_factor : float
        Cutoff factor above which a cell is considered troubled.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        # NOTE: 2 or 3 were noted as alternative cutoff factors to the default below
        self._cutoff_factor = config.pop('cutoff_factor', np.sqrt(2) * self._cell_len)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : np.array
            Matrix of multiwavelet coefficients.
        projection : np.array
            Matrix of projection for each polynomial degree (including one ghost cell
            on each side).

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Normalize by the largest absolute degree-0 coefficient over all interior cells
        # (ghost cells excluded), but by at least 1. Previously only the first
        # num_coarse_grid_cells interior cells, i.e. the left half, were considered.
        max_avg = np.sqrt(0.5) * max(1, max((abs(average) for average in projection[0][1:-1]),
                                            default=0))

        return [cell for cell in range(self._num_coarse_grid_cells)
                if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg)]

    def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
        """Checks whether a cell is troubled.

        Parameters
        ----------
        multiwavelet_coeffs : np.array
            Matrix of multiwavelet coefficients.
        cell : int
            Index of cell.
        max_avg : float
            Maximum average of projection.

        Returns
        -------
        boolean
            Flag whether cell is troubled.

        """
        max_value = max(abs(multiwavelet_coeffs[degree][cell])
                        for degree in range(self._polynomial_degree+1))/max_avg
        eps = self._cutoff_factor / (self._cell_len*self._num_coarse_grid_cells*2)

        return max_value > eps