# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

Urgent:
TODO: Extract do_initial_projection() from DGScheme -> Done
TODO: Move inverse mass matrix to basis
TODO: Extract calculate_cell_average() from TCD
TODO: Extract calculate_[...]_solution() from Plotting
TODO: Extract plotting from TCD completely
    (maybe give indicator which plots are required instead?)
TODO: Contain all plotting in Plotting
TODO: Remove use of DGScheme from ANN_Data_Generator
TODO: Adapt TCD from Soraya
    (Dropbox->...->TEST_troubled-cell-detector->Troubled_Cell_Detector)
TODO: Add verbose output
TODO: Improve file naming (e.g. use '.' instead of '__')
TODO: Combine ANN workflows
TODO: Add an environment file for Snakemake

Critical, but not urgent:
TODO: Force input_size for each ANN model to be stencil length
TODO: Use full path for ANN model state
TODO: Enforce abstract classes/methods (abc.ABC, abc.abstractmethod)
TODO: Extract object initialization from DGScheme
TODO: Use cfl_number for updating, not just time

Currently not critical:
TODO: Unify use of 'length' and 'len' in naming
TODO: Replace loops with list comprehension if feasible
TODO: Check whether 'projection' is always a np.array()
TODO: Check whether all instance variables are sensible
TODO: Rename files according to standard
TODO: Outsource scripts into separate directory
TODO: Allow comparison between ANN training datasets
TODO: Add a default model state
TODO: Look into validators for variable checks

Not feasible yet or doc-related:
TODO: Adjust code to allow classes for all equations
    (Burgers', linear advection, 1D Euler)
TODO: Double-check everything! (also with pylint, pytype, pydoc,
    pycodestyle, pydocstyle)
TODO: Check whether documentation style is correct
TODO: Check whether all types in doc are correct
TODO: Discuss adding kwargs to attributes in documentation
TODO: Add type annotations to function heads

"""
import os
import json
import numpy as np
from sympy import Symbol
import math
import matplotlib
from matplotlib import pyplot as plt

import Troubled_Cell_Detector
import Initial_Condition
import Limiter
import Quadrature
import Update_Scheme
from Basis_Function import OrthonormalLegendre

# Select the non-interactive 'Agg' backend so figures can be rendered and
# saved to files without a display.
matplotlib.use('Agg')
# Symbolic variable of the basis functions; do_initial_projection()
# substitutes quadrature points for it via sympy's subs().
x = Symbol('x')


def encode_ndarray(obj):
    """Converts NumPy data into JSON-serializable built-in Python objects.

    Arrays become (nested) lists and NumPy scalars become the matching
    Python scalars. Lists and tuples are encoded element-wise, so arrays or
    NumPy scalars nested inside them (e.g. a history list of arrays) are
    converted as well; tuples are returned as lists, which is how JSON
    stores them anyway. Any other object is returned unchanged.

    Parameters
    ----------
    obj : object
        Object to encode.

    Returns
    -------
    object
        JSON-serializable version of `obj`.

    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # e.g. np.int64 is not accepted by json.dumps
        return obj.item()
    if isinstance(obj, (list, tuple)):
        return [encode_ndarray(item) for item in obj]
    return obj


def decode_ndarray(obj):
    """Converts lists read from JSON back into NumPy arrays.

    Lists become ndarrays; any other object is returned unchanged. Ragged
    nested lists (sub-lists of different lengths), which NumPy >= 1.24
    refuses to turn into a regular array, are returned as a 1-D object
    array whose elements are the sub-lists.

    Parameters
    ----------
    obj : object
        Object to decode.

    Returns
    -------
    object
        `obj` as ndarray if it is a list, otherwise `obj` itself.

    """
    if not isinstance(obj, list):
        return obj
    try:
        return np.asarray(obj)
    except ValueError:
        # Inhomogeneous nesting, e.g. a different number of entries per
        # recorded time step
        return np.array(obj, dtype=object)


class DGScheme:
    """Class for Discontinuous Galerkin Method.

    Approximates linear advection equation using Discontinuous Galerkin Method
    with troubled-cell-based limiting.

    Attributes
    ----------
    interval_len : float
        Length of the interval between left and right boundary.
    cell_len : float
        Length of a cell in mesh.
    basis : Basis object
        Basis for calculation.
    mesh : ndarray
        List of mesh valuation points.

    Methods
    -------
    approximate(data_file)
        Approximates projection and saves the results in a JSON file.
    build_training_data(adjustment, stencil_length, add_reconstructions,
                        initial_condition=None)
        Builds training data set.

    Notes
    -----
    All attributes are stored as private instance variables
    (e.g. ``_cell_len``).

    """
    def __init__(self, detector, **kwargs):
        """Initializes DGScheme.

        Parameters
        ----------
        detector : str
            Name of troubled cell detector class.

        Other Parameters
        ----------------
        wave_speed : float, optional
            Speed of wave in rightward direction. Default: 1.
        polynomial_degree : int, optional
            Polynomial degree. Default: 2.
        cfl_number : float, optional
            CFL number to ensure stability. Default: 0.2.
        num_grid_cells : int, optional
            Number of cells in the mesh. Usually a power of 2. Default: 64.
        final_time : float, optional
            Final time for which approximation is calculated. Default: 1.
        left_bound : float, optional
            Left boundary of interval. Default: -1.
        right_bound : float, optional
            Right boundary of interval. Default: 1.
        verbose : bool, optional
            Flag whether commentary in console is wanted. Default: False.
        history_threshold : int, optional
            Number of iterations between two recordings of the
            troubled-cell history. Default: math.ceil(0.2/cfl_number).
        detector_config : dict, optional
            Additional parameters for detector object. Default: {}.
        init_cond : str, optional
            Name of initial condition for evaluation. Default: 'Sine'
        init_config : dict, optional
            Additional parameters for initial condition object. Default: {}.
        limiter : str, optional
            Name of limiter for evaluation. Default: 'ModifiedMinMod'.
        limiter_config : dict, optional
            Additional parameters for limiter object. Default: {}.
        quadrature : str, optional
            Name of quadrature for evaluation. Default: 'Gauss'.
        quadrature_config : dict, optional
            Additional parameters for quadrature object. Default: {}.
        update_scheme : str, optional
            Name of update scheme for evaluation. Default: 'SSPRK3'.

        Raises
        ------
        ValueError
            If unknown keyword arguments are given or if one of the named
            classes does not exist in its module.

        Notes
        -----
        The config dicts are copied, so dicts passed in by the caller are
        never modified (e.g. by adding 'cell_len' to the limiter config).

        """
        # Unpack keyword arguments
        self._wave_speed = kwargs.pop('wave_speed', 1)
        self._polynomial_degree = kwargs.pop('polynomial_degree', 2)
        self._cfl_number = kwargs.pop('cfl_number', 0.2)
        self._num_grid_cells = kwargs.pop('num_grid_cells', 64)
        self._final_time = kwargs.pop('final_time', 1)
        self._left_bound = kwargs.pop('left_bound', -1)
        self._right_bound = kwargs.pop('right_bound', 1)
        self._verbose = kwargs.pop('verbose', False)
        self._history_threshold = kwargs.pop('history_threshold',
                                             math.ceil(0.2/self._cfl_number))
        self._detector = detector
        self._detector_config = dict(kwargs.pop('detector_config', {}))
        self._init_cond = kwargs.pop('init_cond', 'Sine')
        self._init_config = dict(kwargs.pop('init_config', {}))
        self._limiter = kwargs.pop('limiter', 'ModifiedMinMod')
        self._limiter_config = dict(kwargs.pop('limiter_config', {}))
        self._quadrature = kwargs.pop('quadrature', 'Gauss')
        self._quadrature_config = dict(kwargs.pop('quadrature_config', {}))
        self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3')

        # Throw an error if there are extra keyword arguments
        if len(kwargs) > 0:
            extra = ', '.join('"%s"' % k for k in list(kwargs.keys()))
            raise ValueError('Unrecognized arguments: %s' % extra)

        # Make sure all classes actually exist
        if not hasattr(Troubled_Cell_Detector, self._detector):
            raise ValueError('Invalid detector: "%s"' % self._detector)
        if not hasattr(Initial_Condition, self._init_cond):
            raise ValueError('Invalid initial condition: "%s"'
                             % self._init_cond)
        if not hasattr(Limiter, self._limiter):
            raise ValueError('Invalid limiter: "%s"' % self._limiter)
        if not hasattr(Quadrature, self._quadrature):
            raise ValueError('Invalid quadrature: "%s"' % self._quadrature)
        if not hasattr(Update_Scheme, self._update_scheme):
            raise ValueError('Invalid update scheme: "%s"'
                             % self._update_scheme)

        # Derived quantities (cell length, basis, mesh) are needed by the
        # detector below, so they are set before the objects are created
        self._reset()

        # Replace the string names with the actual class instances
        # (and add the instance variables for the quadrature)
        self._init_cond = getattr(Initial_Condition, self._init_cond)(
            left_bound=self._left_bound, right_bound=self._right_bound,
            config=self._init_config)
        self._limiter = getattr(Limiter, self._limiter)(
            config=self._limiter_config)
        self._quadrature = getattr(Quadrature, self._quadrature)(
            config=self._quadrature_config)
        self._detector = getattr(Troubled_Cell_Detector, self._detector)(
            config=self._detector_config, mesh=self._mesh,
            wave_speed=self._wave_speed, num_grid_cells=self._num_grid_cells,
            polynomial_degree=self._polynomial_degree,
            final_time=self._final_time, left_bound=self._left_bound,
            right_bound=self._right_bound, basis=self._basis,
            init_cond=self._init_cond, quadrature=self._quadrature)
        self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
            polynomial_degree=self._polynomial_degree,
            num_grid_cells=self._num_grid_cells, detector=self._detector,
            limiter=self._limiter)

    def approximate(self, data_file):
        """Approximates projection.

        Initializes projection and evolves it in time. Each time step consists
        of three parts: A projection update, a troubled-cell detection,
        and limiting based on the detected cells.

        Every `history_threshold`-th iteration, the troubled cells returned
        by that step and the time at the start of that step are recorded.
        At final time, the projection and both histories are saved in a JSON
        file under the keys 'projection', 'time_history' and
        'troubled_cell_history'.

        Parameters
        ----------
        data_file : str
            Path (without the '.json' extension) to file in which data will
            be saved.

        """
        projection = do_initial_projection(
            initial_condition=self._init_cond, basis=self._basis,
            quadrature=self._quadrature, num_grid_cells=self._num_grid_cells,
            left_bound=self._left_bound, right_bound=self._right_bound,
            polynomial_degree=self._polynomial_degree)

        time_step = abs(self._cfl_number * self._cell_len / self._wave_speed)

        current_time = 0
        iteration = 0
        troubled_cell_history = []
        time_history = []
        while current_time < self._final_time:
            # Shorten the last time step so that final_time is hit exactly
            # and scale the CFL number accordingly; the absolute wave speed
            # is used to stay consistent with the full time step above
            cfl_number = self._cfl_number
            if current_time+time_step > self._final_time:
                time_step = self._final_time-current_time
                cfl_number = abs(self._wave_speed) * time_step / self._cell_len

            # Update projection
            projection, troubled_cells = self._update_scheme.step(projection,
                                                                  cfl_number)
            iteration += 1

            if (iteration % self._history_threshold) == 0:
                troubled_cell_history.append(troubled_cells)
                time_history.append(current_time)

            current_time += time_step

        # Save approximation results in dictionary
        approx_stats = {'projection': projection, 'time_history': time_history,
                        'troubled_cell_history': troubled_cell_history}

        # Encode all ndarrays to fit JSON format
        approx_stats = {key: encode_ndarray(value)
                        for key, value in approx_stats.items()}

        # Save approximation results in JSON format
        with open(data_file + '.json', 'w') as json_file:
            json.dump(approx_stats, json_file)

    def _reset(self):
        """Sets instance variables derived from the input parameters."""
        # Set additional necessary instance variables
        self._interval_len = self._right_bound-self._left_bound
        self._cell_len = self._interval_len / self._num_grid_cells
        self._basis = OrthonormalLegendre(self._polynomial_degree)

        # Set additional necessary config parameters
        self._limiter_config['cell_len'] = self._cell_len

        # Set mesh of N+4 equidistant points: the N cell centers plus two
        # points outside the interval on each side (1/2 and 3/2 cell lengths
        # beyond the boundary). np.linspace fixes the number of points;
        # np.arange with a float step (and a stop value lying on the grid)
        # can return one point too many due to rounding.
        self._mesh = np.linspace(self._left_bound - 3/2*self._cell_len,
                                 self._right_bound + 3/2*self._cell_len,
                                 self._num_grid_cells + 4)

    def build_training_data(self, adjustment, stencil_length,
                            add_reconstructions, initial_condition=None):
        """Builds training data set.

        Initializes projection and calculates cell averages and
        reconstructions for it.

        Parameters
        ----------
        adjustment : float
            Extent of adjustment of each evaluation point in x-direction.
        stencil_length : int
            Size of training data array.
        add_reconstructions : bool
            Flag whether reconstructions of the middle cell are included.
        initial_condition : InitialCondition object, optional
            Initial condition used for calculation.
            Default: None (i.e. instance variable).

        Returns
        -------
        ndarray
            Matrix containing cell averages and reconstructions for initial
            projection.

        """
        if initial_condition is None:
            initial_condition = self._init_cond
        projection = do_initial_projection(
            initial_condition=initial_condition, basis=self._basis,
            quadrature=self._quadrature, num_grid_cells=self._num_grid_cells,
            left_bound=self._left_bound, right_bound=self._right_bound,
            polynomial_degree=self._polynomial_degree, adjustment=adjustment)

        # Strip the ghost-cell columns before computing the averages
        return self._detector.calculate_cell_average(projection[:, 1:-1],
                                                     stencil_length,
                                                     add_reconstructions)


def do_initial_projection(initial_condition, basis, quadrature,
                          num_grid_cells, left_bound, right_bound,
                          polynomial_degree, adjustment=0):
    """Calculates initial projection.

    Calculates a projection at time step 0 and adds ghost cells on both
    sides of the array.

    For every cell, the coefficient of each basis function is the
    quadrature sum of initial condition times basis function times weight,
    where the initial condition is evaluated at the quadrature points
    mapped into the cell (shifted by `adjustment`) and the basis function
    at the unmapped quadrature points. The ghost cells are filled
    periodically: the left ghost cell copies the last cell and the right
    ghost cell copies the first cell.

    Parameters
    ----------
    initial_condition : InitialCondition object
        Initial condition used for calculation.
    basis : Vector object
        Basis used for calculation.
    quadrature : Quadrature object
        Quadrature used for evaluation.
    num_grid_cells : int
        Number of cells in the mesh. Usually a power of 2.
    left_bound : float
        Left boundary of interval.
    right_bound : float
        Right boundary of interval.
    polynomial_degree : int
        Polynomial degree.
    adjustment : float, optional
        Extent of adjustment of each evaluation point in x-direction.
        Default: 0.

    Returns
    -------
    ndarray
        Matrix containing projection of size (p+1, N+2) with N being the
        number of grid cells and p being the polynomial degree, i.e. one
        column per cell including the two ghost cells.

    """
    # Inverse mass matrix; set to the identity, which presumes an
    # orthonormal basis
    inv_mass = np.eye(polynomial_degree+1)

    eval_points = quadrature.get_eval_points()
    weights = quadrature.get_weights()
    num_points = quadrature.get_num_points()
    basis_vector = basis.get_basis_vector()

    # The basis values at the quadrature points are the same for every
    # cell, so the (costly) symbolic substitution is done only once
    basis_values = [[basis_vector[degree].subs(x, eval_points[point])
                     for point in range(num_points)]
                    for degree in range(polynomial_degree+1)]

    cell_len = (right_bound-left_bound)/num_grid_cells
    cells = []
    for cell in range(num_grid_cells):
        # Center of the current cell
        eval_point = left_bound + (cell+0.5)*cell_len

        new_row = np.array([
            np.float64(sum(
                initial_condition.calculate(
                    eval_point + cell_len/2 * eval_points[point]
                    - adjustment)
                * basis_values[degree][point]
                * weights[point]
                for point in range(num_points)))
            for degree in range(polynomial_degree+1)])
        cells.append(inv_mass @ new_row)

    # Add periodic ghost cells on both sides
    output_matrix = [cells[-1], *cells, cells[0]]

    return np.transpose(np.array(output_matrix))


def plot_approximation_results(detector, data_file, directory, plot_name):
    """Plots given approximation results.

    Reads the approximation results from a JSON file, lets the detector
    generate its plots, and saves every open labeled figure as
    '<directory>/<figure label>/<plot_name>.pdf', creating directories as
    needed. Each figure is closed after saving, so repeated calls neither
    accumulate open figures nor re-save figures left over from earlier
    calls.

    Parameters
    ----------
    detector : TroubledCellDetector object
        Detector whose plot_results() creates the figures.
    data_file : str
        Path (without the '.json' extension) to data file for plotting.
    directory : str
        Path to directory in which plots will be saved.
    plot_name : str
        Name of plot (file name without the '.pdf' extension).

    """
    # Read approximation results
    with open(data_file + '.json') as json_file:
        approx_stats = json.load(json_file)

    # Decode all ndarrays by converting lists
    approx_stats = {key: decode_ndarray(value)
                    for key, value in approx_stats.items()}

    # Plot exact/approximate results, errors, shock tubes,
    # and any detector-dependant plots
    detector.plot_results(**approx_stats)

    # Set path for plot files if not existing already
    os.makedirs(directory, exist_ok=True)

    # Save plots; get_figlabels() returns a list, so closing figures while
    # iterating over it is safe
    for identifier in plt.get_figlabels():
        figure_dir = os.path.join(directory, identifier)
        os.makedirs(figure_dir, exist_ok=True)

        figure = plt.figure(identifier)
        figure.savefig(os.path.join(figure_dir, plot_name + '.pdf'))
        plt.close(figure)