# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

Discussion:
TODO: Contemplate saving 5-CV split and evaluating models separately
TODO: Contemplate separating cell average and reconstruction calculations
    completely
TODO: Contemplate removing Methods section from class docstring
TODO: Contemplate containing the quadrature application for plots in Mesh
TODO: Contemplate containing coarse mesh generation in Mesh
TODO: Contemplate extracting boundary condition from InitialCondition
TODO: Contemplate containing boundary condition in Mesh
TODO: Ask whether all quadratures depend on freely chosen num_nodes
TODO: Contemplate saving training data for each IC separately
TODO: Contemplate removing TrainingDataGenerator class
TODO: Ask why no fine mesh is used for calculate_coarse_projection()
TODO: Discuss adding kwargs to attributes in documentation
TODO: Discuss descriptions (matrices, cfl number, right-hand side,
    limiting slope, basis, wavelet, etc.)
TODO: Discuss referencing info on SSPRK3

Urgent:
TODO: Fix indexing issue in wavelet coefficient calculation
TODO: Enforce mesh with 2^n cells
TODO: Enforce even number of ghost cells on each side on fine mesh (?)
TODO: Change Heaviside random to uniform(-100, 100)
TODO: Adjust Heaviside to have non-symmetric values (left- and right_value)
TODO: Rename 'adjustment' to 'shift'
TODO: Induce shift in IC class
TODO: Give option to set discontinuity to cell boundary
TODO: Fix typo in LinearAbsolute
TODO: Introduce middle_factor for two-sided Heaviside
TODO: Move plot_approximation_results() into plotting script
TODO: Improve file naming (e.g. use '.' instead of '__')
TODO: Check whether ghost cells are handled/set correctly
TODO: Unify use of 'length' and 'len' in naming
TODO: Unify use of 'initial_condition' and 'init_cond' in naming
TODO: Unify use of 'mesh' and 'grid' in naming
TODO: Check sign change in stiffness matrix to accommodate negative wave speed
TODO: Check whether 'projection' is always an ndarray
TODO: Force input_size for each ANN model to be stencil length

Critical, but not urgent:
TODO: Introduce env files for each SM rule
TODO: Add images to report
TODO: Add verbose output
TODO: Outsource scripts into separate directory
TODO: Check whether all instance variables are sensible
TODO: Combine ANN workflows if feasible
TODO: Investigate profiling for speed up
TODO: Rework Theoretical TC for efficiency
TODO: Extract object initialization from DGScheme
TODO: Replace loops with list comprehension if feasible
TODO: Replace loops/list comprehension with vectorization if feasible

Currently not critical:
TODO: Give option to select plotting color
TODO: Add an environment file for Snakemake
TODO: Build package (module?) for DG scheme
TODO: Investigate g-mesh(?)
TODO: Create g-mesh with Mesh class
TODO: Rename files according to standard
TODO: Allow comparison between ANN training datasets
TODO: Use full path for ANN model state
TODO: Add a default model state
TODO: Look into validators for variable checks

Not feasible yet or doc-related:
TODO: Add README for ANN training
TODO: Give detailed description of wavelet detection
TODO: Use cfl_number for updating, not just time (equation-related?)
TODO: Adjust code to allow classes for all equations
    (Burgers', linear advection, 1D Euler)
TODO: Add ThresholdDetector
TODO: Double-check everything! (also with pylint, pytype, pydoc,
    pycodestyle, pydocstyle)
TODO: Check whether documentation style is correct
TODO: Check whether all types in doc are correct
TODO: Add type annotations to function heads
TODO: Clean up docstrings

"""
import json
import numpy as np
from sympy import Symbol
import math
import seaborn as sns

import Troubled_Cell_Detector
import Initial_Condition
import Limiter
import Quadrature
import Update_Scheme
from Basis_Function import OrthonormalLegendre
from encoding_utils import encode_ndarray
from projection_utils import Mesh

# Symbolic variable in which the basis functions are expressed; it is
# substituted with quadrature nodes in do_initial_projection()
x = Symbol('x')
# Apply seaborn's default plot theme globally (side effect at import time)
sns.set()


class DGScheme:
    """Class for Discontinuous Galerkin Method.

    Approximates linear advection equation using Discontinuous Galerkin Method
    with troubled-cell-based limiting.

    Attributes
    ----------
    basis : Basis object
        Basis for calculation.
    mesh : Mesh
        Mesh for calculation.

    Methods
    -------
    approximate(data_file)
        Approximates projection and saves the results in a JSON file.

    """
    def __init__(self, detector, **kwargs):
        """Initializes DGScheme.

        Parameters
        ----------
        detector : str
            Name of troubled cell detector class.

        Other Parameters
        ----------------
        wave_speed : float, optional
            Speed of wave in rightward direction. Default: 1.
        polynomial_degree : int, optional
            Polynomial degree. Default: 2.
        cfl_number : float, optional
            CFL number to ensure stability. Default: 0.2.
        num_grid_cells : int, optional
            Number of cells in the mesh. Usually a power of 2. Default: 64.
        final_time : float, optional
            Final time for which approximation is calculated. Default: 1.
        left_bound : float, optional
            Left boundary of interval. Default: -1.
        right_bound : float, optional
            Right boundary of interval. Default: 1.
        verbose : bool, optional
            Flag whether commentary in console is wanted. Default: False.
        history_threshold : float, optional
            Number of time steps between two recordings of the troubled-cell
            history. Default: math.ceil(0.2/cfl_number).
        detector_config : dict, optional
            Additional parameters for detector object. Default: {}.
        init_cond : str, optional
            Name of initial condition for evaluation. Default: 'Sine'
        init_config : dict, optional
            Additional parameters for initial condition object. Default: {}.
        limiter : str, optional
            Name of limiter for evaluation. Default: 'ModifiedMinMod'.
        limiter_config : dict, optional
            Additional parameters for limiter object. Default: {}.
        quadrature : str, optional
            Name of quadrature for evaluation. Default: 'Gauss'.
        quadrature_config : dict, optional
            Additional parameters for quadrature object. Default: {}.
        update_scheme : str, optional
            Name of update scheme for evaluation. Default: 'SSPRK3'.

        Raises
        ------
        ValueError
            If unknown keyword arguments are given or if any of the named
            classes does not exist in its module.

        Notes
        -----
        The config dictionaries are copied, so entries added here (e.g. the
        limiter's 'cell_len') never modify the dictionaries passed in.

        """
        # Unpack keyword arguments
        self._wave_speed = kwargs.pop('wave_speed', 1)
        self._cfl_number = kwargs.pop('cfl_number', 0.2)
        self._final_time = kwargs.pop('final_time', 1)
        self._verbose = kwargs.pop('verbose', False)
        self._history_threshold = kwargs.pop('history_threshold',
                                             math.ceil(0.2/self._cfl_number))
        self._detector = detector
        # Copy config dicts so later additions do not leak into caller's dicts
        self._detector_config = dict(kwargs.pop('detector_config', {}))
        self._init_cond = kwargs.pop('init_cond', 'Sine')
        self._init_config = dict(kwargs.pop('init_config', {}))
        self._limiter = kwargs.pop('limiter', 'ModifiedMinMod')
        self._limiter_config = dict(kwargs.pop('limiter_config', {}))
        self._quadrature = kwargs.pop('quadrature', 'Gauss')
        self._quadrature_config = dict(kwargs.pop('quadrature_config', {}))
        self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3')
        self._basis = OrthonormalLegendre(kwargs.pop('polynomial_degree', 2))

        # Initialize mesh with two ghost cells on each side
        self._mesh = Mesh(num_grid_cells=kwargs.pop('num_grid_cells', 64),
                          left_bound=kwargs.pop('left_bound', -1),
                          right_bound=kwargs.pop('right_bound', 1),
                          num_ghost_cells=2)

        # Throw an error if there are extra keyword arguments
        if kwargs:
            extra = ', '.join('"%s"' % key for key in kwargs)
            raise ValueError('Unrecognized arguments: %s' % extra)

        # Make sure all classes actually exist
        if not hasattr(Troubled_Cell_Detector, self._detector):
            raise ValueError('Invalid detector: "%s"' % self._detector)
        if not hasattr(Initial_Condition, self._init_cond):
            raise ValueError('Invalid initial condition: "%s"'
                             % self._init_cond)
        if not hasattr(Limiter, self._limiter):
            raise ValueError('Invalid limiter: "%s"' % self._limiter)
        if not hasattr(Quadrature, self._quadrature):
            raise ValueError('Invalid quadrature: "%s"' % self._quadrature)
        if not hasattr(Update_Scheme, self._update_scheme):
            raise ValueError('Invalid update scheme: "%s"'
                             % self._update_scheme)

        # Must run before the limiter is instantiated, since it adds entries
        # to the limiter config
        self._reset()

        # Replace the string names with the actual class instances
        # (and add the instance variables for the quadrature)
        self._init_cond = getattr(Initial_Condition, self._init_cond)(
            config=self._init_config)
        self._limiter = getattr(Limiter, self._limiter)(
            config=self._limiter_config)
        self._quadrature = getattr(Quadrature, self._quadrature)(
            config=self._quadrature_config)
        self._detector = getattr(Troubled_Cell_Detector, self._detector)(
            config=self._detector_config, mesh=self._mesh, basis=self._basis)
        self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
            polynomial_degree=self._basis.polynomial_degree,
            num_grid_cells=self._mesh.num_grid_cells, detector=self._detector,
            limiter=self._limiter)

    def approximate(self, data_file):
        """Approximates projection.

        Initializes projection and evolves it in time. Each time step consists
        of three parts: A projection update, a troubled-cell detection,
        and limiting based on the detected cells.

        At final time, results are saved in JSON file.

        Parameters
        ----------
        data_file: str
            Path to file in which data will be saved. The extension '.json'
            is appended to it.

        """
        projection = do_initial_projection(
            initial_condition=self._init_cond, mesh=self._mesh,
            basis=self._basis, quadrature=self._quadrature)

        troubled_cell_history = []
        time_history = []
        for iteration, (current_time, cfl_number) in enumerate(
                self._time_steps(), start=1):
            # Update projection; step() also returns the troubled cells
            projection, troubled_cells = self._update_scheme.step(projection,
                                                                  cfl_number)

            # Record every history_threshold-th step; the recorded time is
            # the time at the start of that step
            if iteration % self._history_threshold == 0:
                troubled_cell_history.append(troubled_cells)
                time_history.append(current_time)

        # Save detector-specific data in dictionary
        approx_stats = self._detector.create_data_dict(projection)

        # Save approximation results in dictionary
        approx_stats['wave_speed'] = self._wave_speed
        approx_stats['final_time'] = self._final_time
        approx_stats['time_history'] = time_history
        approx_stats['troubled_cell_history'] = troubled_cell_history

        # Encode all ndarrays to fit JSON format
        approx_stats = {key: encode_ndarray(value)
                        for key, value in approx_stats.items()}

        # Save approximation results in JSON format
        with open(data_file + '.json', 'w') as json_file:
            json_file.write(json.dumps(approx_stats))

    def _time_steps(self):
        """Yields the start time and CFL number of each time step.

        Uses a constant time step derived from CFL number, cell length and
        wave speed. The last step is shortened so that the final time is not
        overshot, and its CFL number is scaled down accordingly.

        Yields
        ------
        tuple
            Pair (current_time, cfl_number) for each time step.

        """
        time_step = abs(self._cfl_number * self._mesh.cell_len /
                        self._wave_speed)

        current_time = 0
        while current_time < self._final_time:
            cfl_number = self._cfl_number
            # Adjust for last time step
            if current_time + time_step > self._final_time:
                time_step = self._final_time - current_time
                # Use the absolute wave speed like the regular time step
                # above; otherwise a negative wave speed would flip the sign
                # of the CFL number on the last step only
                cfl_number = (abs(self._wave_speed) * time_step
                              / self._mesh.cell_len)

            yield current_time, cfl_number
            current_time += time_step

    def _reset(self):
        """Sets config parameters that depend on the mesh."""
        # The limiter needs to know the cell length
        self._limiter_config['cell_len'] = self._mesh.cell_len


def do_initial_projection(initial_condition, mesh, basis, quadrature,
                          adjustment=0):
    """Calculates initial projection.

    Projects the initial condition onto the basis in every non-ghost cell
    using the given quadrature. It then adds one ghost cell on each side of
    the array with periodic values: the left ghost cell copies the last
    interior cell and the right ghost cell copies the first interior cell.

    Parameters
    ----------
    initial_condition : InitialCondition object
        Initial condition used for calculation.
    mesh : Mesh
        Mesh for calculation.
    basis: Basis object
        Basis used for calculation.
    quadrature: Quadrature object
        Quadrature used for evaluation.
    adjustment: float, optional
        Extent of adjustment of each evaluation point in x-direction.
        Default: 0.

    Returns
    -------
    ndarray
        Matrix containing projection of size (p+1, N+2) with N being the
        number of non-ghost cells and p being the polynomial degree of the
        basis. Each column holds the coefficients of one cell.

    """
    num_coeffs = basis.polynomial_degree + 1
    points = range(quadrature.num_nodes)

    # The basis values at the quadrature nodes are the same for every cell,
    # so perform the (slow) symbolic substitution only once
    basis_at_nodes = [[basis.basis[degree].subs(x, quadrature.nodes[point])
                       for point in points]
                      for degree in range(num_coeffs)]

    interior = []
    for eval_point in mesh.non_ghost_cells:
        # Quadrature of initial condition times basis function; the nodes
        # are mapped into the cell via eval_point + cell_len/2 * node
        # (presumably eval_point is the cell center -- verify in Mesh)
        new_row = np.array([
            np.float64(sum(
                initial_condition.calculate(
                    x=eval_point + mesh.cell_len/2
                    * quadrature.nodes[point] - adjustment)
                * basis_at_nodes[degree][point]
                * quadrature.weights[point]
                for point in points))
            for degree in range(num_coeffs)])
        interior.append(basis.inverse_mass_matrix @ new_row)

    # Periodic ghost cells: last interior cell on the left, first on the right
    output_matrix = [interior[-1]] + interior + [interior[0]]

    return np.transpose(np.array(output_matrix))