# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
import json
import numpy as np
from sympy import Symbol
import math

from .Basis_Function import Basis
from .encoding_utils import encode_ndarray
from .Initial_Condition import InitialCondition
from .Mesh import Mesh
from .Quadrature import Quadrature
from .Troubled_Cell_Detector import TroubledCellDetector
from .Update_Scheme import UpdateScheme

# Symbolic variable in which the basis functions are expressed;
# do_initial_projection() substitutes it with the quadrature nodes.
x = Symbol('x')


class DGScheme:
    """Class for Discontinuous Galerkin Method.

    Approximates linear advection equation using Discontinuous Galerkin Method
    with troubled-cell-based limiting.

    Methods
    -------
    approximate(data_file)
        Approximates projection and saves the results in a JSON file.

    """
    def __init__(self, detector: TroubledCellDetector,
                 quadrature: Quadrature, init_cond: InitialCondition,
                 update_scheme: UpdateScheme, mesh: Mesh, basis: Basis,
                 wave_speed, **kwargs):
        """Initializes DGScheme.

        Parameters
        ----------
        detector : TroubledCellDetector object
            Troubled cell detector. Used to build the detector-specific
            data dictionary of the final projection.
        quadrature : Quadrature object
            Quadrature for evaluation.
        init_cond : InitialCondition object
            Initial condition for evaluation.
        update_scheme : UpdateScheme object
            Update scheme for evaluation.
        mesh : Mesh object
            Mesh for calculation.
        basis : Basis object
            Basis for calculation.
        wave_speed : float
            Speed of wave in rightward direction. Must be non-zero, since
            the time step is derived by dividing by it.

        Other Parameters
        ----------------
        cfl_number : float, optional
            CFL number to ensure stability. Default: 0.2.
        final_time : float, optional
            Final time for which approximation is calculated. Default: 1.
        verbose : bool, optional
            Flag whether commentary in console is wanted. Default: False.
        history_threshold : int, optional
            Number of iterations between two recorded entries of the
            troubled-cell history. Default: math.ceil(0.2/cfl_number).

        Raises
        ------
        ValueError
            If unrecognized keyword arguments are passed.

        """
        self._detector = detector
        self._quadrature = quadrature
        self._init_cond = init_cond
        self._update_scheme = update_scheme
        self._mesh = mesh
        self._basis = basis
        self._wave_speed = wave_speed

        # Unpack keyword arguments; cfl_number must be read first because
        # the default history threshold depends on it
        self._cfl_number = kwargs.pop('cfl_number', 0.2)
        self._final_time = kwargs.pop('final_time', 1)
        # NOTE(review): stored but not read anywhere in this class
        self._verbose = kwargs.pop('verbose', False)
        self._history_threshold = kwargs.pop('history_threshold',
                                             math.ceil(0.2/self._cfl_number))

        # Throw an error if there are extra keyword arguments
        if kwargs:
            extra = ', '.join('"%s"' % key for key in kwargs)
            raise ValueError('Unrecognized arguments: %s' % extra)

    def approximate(self, data_file):
        """Approximates projection.

        Initializes projection and evolves it in time. Each time step is
        delegated to the update scheme and consists of three parts: A
        projection update, a troubled-cell detection, and limiting based on
        the detected cells.

        At final time, results are saved in JSON file.

        Parameters
        ----------
        data_file: str
            Path to file in which data will be saved, without extension
            (the suffix '.json' is appended).

        """
        projection = do_initial_projection(
            init_cond=self._init_cond, mesh=self._mesh,
            basis=self._basis, quadrature=self._quadrature)

        # The time step is positive regardless of the direction of the wave
        time_step = abs(self._cfl_number * self._mesh.cell_len /
                        self._wave_speed)

        current_time = 0
        iteration = 0
        troubled_cell_history = []
        time_history = []
        while current_time < self._final_time:
            # Shorten the last time step so that it ends at final time
            cfl_number = self._cfl_number
            if current_time+time_step > self._final_time:
                time_step = self._final_time-current_time
                # Use the magnitude of the wave speed so the rescaled CFL
                # number has the same (non-negative) sign as on every
                # regular step, matching the abs() used for time_step
                cfl_number = abs(self._wave_speed) * time_step / \
                    self._mesh.cell_len

            # Update projection
            projection, troubled_cells = self._update_scheme.step(projection,
                                                                  cfl_number)
            iteration += 1

            # Record troubled cells only every history_threshold iterations
            if (iteration % self._history_threshold) == 0:
                troubled_cell_history.append(troubled_cells)
                time_history.append(current_time)

            current_time += time_step

        # Save detector-specific data in dictionary
        approx_stats = self._detector.create_data_dict(projection)

        # Save approximation results in dictionary
        approx_stats['wave_speed'] = self._wave_speed
        approx_stats['final_time'] = self._final_time
        approx_stats['time_history'] = time_history
        approx_stats['troubled_cell_history'] = troubled_cell_history

        # Encode all ndarrays to fit JSON format
        approx_stats = {key: encode_ndarray(value)
                        for key, value in approx_stats.items()}

        # Save approximation results in JSON format
        with open(data_file + '.json', 'w') as json_file:
            json.dump(approx_stats, json_file)


def do_initial_projection(init_cond, mesh, basis, quadrature,
                          x_shift=0):
    """Calculates initial projection.

    Calculates a projection at time step 0 and adds ghost cells on both
    sides of the array. The ghost cells are filled periodically: the left
    ghost cell copies the last regular cell and the right ghost cell copies
    the first regular cell.

    Parameters
    ----------
    init_cond : InitialCondition object
        Initial condition used for calculation.
    mesh : Mesh object
        Mesh for calculation.
    basis: Basis object
        Basis used for calculation.
    quadrature: Quadrature object
        Quadrature used for evaluation.
    x_shift: float, optional
        Shift in x-direction for each evaluation point. Default: 0.

    Returns
    -------
    ndarray
        Matrix containing projection of size (p+1, N+2) with N being the
        number of mesh cells and p being the polynomial degree of the basis.
        Column 0 and column N+1 are the ghost cells.

    """
    num_coeffs = basis.polynomial_degree + 1

    # Initialize output matrix (one column per cell, ghost cells included)
    output_matrix = np.zeros((num_coeffs, mesh.num_cells+2))

    # Evaluate each basis function at the quadrature nodes; converting to
    # float avoids an object-dtype array of sympy numbers
    basis_matrix = np.array([
        np.vectorize(basis.basis[degree].subs, otypes=[float])(
            x, quadrature.nodes)
        for degree in range(num_coeffs)])

    # Calculate points based on initial condition: nodes scaled by half the
    # cell length (column vector) broadcast against the shifted cell
    # positions of the regular cells (one column per cell)
    points = quadrature.nodes[:, np.newaxis] * (mesh.cell_len/2) + \
        np.tile(mesh.non_ghost_cells - x_shift, (quadrature.num_nodes, 1))
    # np.float was removed in NumPy 1.24; the builtin float is equivalent
    init_matrix = np.vectorize(init_cond.calculate, otypes=[float])(
        x=points, mesh=mesh)

    # Set output matrix for regular cells (quadrature-weighted projection)
    output_matrix[:, 1:-1] = (basis_matrix * quadrature.weights) @ init_matrix

    # Set ghost cells (periodic boundary)
    output_matrix[:, 0] = output_matrix[:, -2]
    output_matrix[:, -1] = output_matrix[:, 1]

    return output_matrix