# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
from abc import ABC, abstractmethod
import numpy as np
import time


class UpdateScheme(ABC):
    """Abstract base for schemes that advance a projection by one time step.

    Concrete schemes implement `_apply_stability_method`. This base class
    assembles the stiffness and boundary matrices and provides helpers for
    limiting troubled cells and refreshing periodic ghost cells.

    Attributes
    ----------
    _stiffness_matrix : ndarray
        Square matrix of size ``polynomial_degree+1``. Entry (i, j) is
        ``+-sqrt(i+0.5)*sqrt(j+0.5)``; the sign pattern depends on whether
        the wave speed is positive.
    _boundary_matrix : ndarray
        Square matrix with the same size and scaling. Entry (i, j) carries
        the sign ``(-1)**i`` for a positive wave speed and ``(-1)**j``
        otherwise.

    Methods
    -------
    get_name()
        Returns string of class name.
    step(projection, cfl_number)
        Performs time step.

    """
    def __init__(self, polynomial_degree, mesh, detector, limiter,
                 wave_speed):
        """Stores the configuration and builds the scheme matrices.

        Parameters
        ----------
        polynomial_degree : int
            Polynomial degree.
        mesh : Mesh
            Mesh for calculation. Only its ``num_ghost_cells`` attribute is
            read by this class.
        detector : TroubledCellDetector object
            Troubled cell detector for evaluation. It must provide
            ``get_cells(projection)``.
        limiter : Limiter object
            Limiter for evaluation. It must provide
            ``apply(projection, troubled_cells)``.
        wave_speed : float
            Speed of wave in rightward direction. Its sign (positive vs.
            non-positive) selects which matrices are built.

        """
        self._wave_speed = wave_speed
        self._limiter = limiter
        self._detector = detector
        self._mesh = mesh
        self._polynomial_degree = polynomial_degree

        self._reset()

    def _reset(self):
        """Resets instance variables.

        Rebuilds the stiffness and boundary matrices from the polynomial
        degree and the sign of the wave speed.
        """
        size = self._polynomial_degree + 1
        shape = (size, size)

        # Per-degree factor sqrt(i + 1/2); the outer product scales entry
        # (i, j) by sqrt(i + 1/2) * sqrt(j + 1/2).
        scaling = np.sqrt(np.arange(size) + 0.5)
        degree_matrix = np.outer(scaling, scaling)

        rows, cols = np.indices(shape)
        even_parity = (rows + cols) % 2 == 0
        rightward = self._wave_speed > 0

        # Stiffness matrix: -1 where the mask holds, +1 elsewhere.
        if rightward:
            flip = (cols >= rows) | even_parity
        else:
            flip = (cols > rows) & ~even_parity
        self._stiffness_matrix = np.where(flip, -1.0, 1.0) * degree_matrix

        # Boundary matrix: sign alternates with the row index for a
        # positive wave speed and with the column index otherwise.
        alternating_index = rows if rightward else cols
        boundary_signs = np.where(alternating_index % 2 == 0, 1.0, -1.0)
        self._boundary_matrix = boundary_signs * degree_matrix

    def get_name(self):
        """Returns string of class name."""
        return type(self).__name__

    def step(self, projection, cfl_number):
        """Performs time step.

        Delegates to the subclass-specific `_apply_stability_method`.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.
        troubled_cells : list
            List of indices for all detected troubled cells.

        """
        updated_projection, detected_cells = \
            self._apply_stability_method(projection, cfl_number)
        return updated_projection, detected_cells

    @abstractmethod
    def _apply_stability_method(self, projection, cfl_number):
        """Applies stability method.

        Must be implemented by concrete update schemes.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.
        troubled_cells : list
            List of indices for all detected troubled cells.

        """

    def _apply_limiter(self, current_projection):
        """Applies limiter on troubled cells.

        The detector first selects the troubled cells. The limiter then
        receives the projection together with those cells.

        Parameters
        ----------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.

        Returns
        -------
        new_projection : ndarray
            Matrix of updated projection for each polynomial degree.
        troubled_cells : list
            List of indices for all detected troubled cells.

        """
        cells = self._detector.get_cells(current_projection)
        limited = self._limiter.apply(current_projection, cells)
        return limited, cells

    def _enforce_boundary_condition(self, current_projection):
        """Enforces boundary condition.

        Adjusts ghost cells (in place) to ensure periodic boundary
        condition.

        Parameters
        ----------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.

        Returns
        -------
        current_projection : ndarray
            The same array, with its ghost-cell columns overwritten.

        """
        ghosts = self._mesh.num_ghost_cells
        # Left ghost cells take the last interior columns, right ghost
        # cells take the first interior columns, wrapping the domain.
        current_projection[:, :ghosts] = \
            current_projection[:, -2*ghosts:-ghosts]
        current_projection[:, -ghosts:] = \
            current_projection[:, ghosts:2*ghosts]
        return current_projection


class SSPRK3(UpdateScheme):
    """Class for strong stability-preserving Runge Kutta of order 3.

    Advances the projection ``u`` with the three-stage scheme

        u1 = u + c * L(u)
        u2 = 3/4 * u + 1/4 * (u1 + c * L(u1))
        u3 = 1/3 * u + 2/3 * (u2 + c * L(u2))

    Here ``c`` is the CFL number and ``L`` is the right-hand side from
    `_update_right_hand_side`. After every stage the limiter is applied and
    the periodic ghost cells are refreshed.

    Notes
    -----
    The stage coefficients match the third-order SSP (TVD) Runge-Kutta
    scheme of C.-W. Shu and S. Osher, "Efficient implementation of
    essentially non-oscillatory shock-capturing schemes", J. Comput. Phys.
    77(2), 439-471 (1988).

    """
    # Override method of superclass
    def _apply_stability_method(self, projection, cfl_number):
        """Applies stability method.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.
        troubled_cells : list
            List of indices for all detected troubled cells (of the final
            stage).

        """
        current_projection, __ = self._limit_and_enforce(
            self._apply_first_step(projection, cfl_number))
        current_projection, __ = self._limit_and_enforce(
            self._apply_second_step(projection, current_projection,
                                    cfl_number))
        # Only the troubled cells detected in the final stage are reported.
        return self._limit_and_enforce(
            self._apply_third_step(projection, current_projection,
                                   cfl_number))

    def _limit_and_enforce(self, stage_projection):
        """Limits a stage result and refreshes its periodic ghost cells.

        Parameters
        ----------
        stage_projection : ndarray
            Matrix of projection produced by one Runge-Kutta stage.

        Returns
        -------
        current_projection : ndarray
            Limited projection with updated ghost cells.
        troubled_cells : list
            List of indices for all detected troubled cells.

        """
        current_projection, troubled_cells = self._apply_limiter(
            stage_projection)
        current_projection = self._enforce_boundary_condition(
            current_projection)
        return current_projection, troubled_cells

    def _apply_first_step(self, original_projection, cfl_number):
        """Applies first step of SSPRK3.

        Parameters
        ----------
        original_projection : ndarray
            Matrix of original projection for each polynomial degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        ndarray
            Matrix of updated projection for each polynomial degree.

        """
        right_hand_side = self._update_right_hand_side(original_projection)
        return original_projection + (cfl_number*right_hand_side)

    def _apply_second_step(self, original_projection, current_projection,
                           cfl_number):
        """Applies second step of SSPRK3.

        Parameters
        ----------
        original_projection : ndarray
            Matrix of original projection for each polynomial degree.
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        ndarray
            Matrix of updated projection for each polynomial degree.

        """
        right_hand_side = self._update_right_hand_side(current_projection)
        return 1/4 * (3*original_projection
                      + (current_projection + cfl_number*right_hand_side))

    def _apply_third_step(self, original_projection, current_projection,
                          cfl_number):
        """Applies third step of SSPRK3.

        Parameters
        ----------
        original_projection : ndarray
            Matrix of original projection for each polynomial degree.
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.
        cfl_number : float
            CFL number to ensure stability.

        Returns
        -------
        ndarray
            Matrix of updated projection for each polynomial degree.

        """
        right_hand_side = self._update_right_hand_side(current_projection)
        return 1/3 * (original_projection
                      + 2*(current_projection + cfl_number*right_hand_side))

    def _update_right_hand_side(self, current_projection):
        """Updates right-hand side.

        Each interior cell is coupled to itself through the stiffness
        matrix and to one neighbour through the boundary matrix. The left
        neighbour is used for a positive wave speed and the right neighbour
        otherwise (including zero). Ghost cells are then filled
        periodically.

        Parameters
        ----------
        current_projection : ndarray
            Matrix of projection of current update step for each polynomial
            degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side, same shape as `current_projection`.

        """
        num_ghost_cells = self._mesh.num_ghost_cells
        num_cells = current_projection.shape[1]
        interior = slice(num_ghost_cells, num_cells-num_ghost_cells)

        # Column shift to the neighbouring cell. The stop index is given
        # explicitly because a negative stop of ``-num_ghost_cells+1``
        # would become 0 (an empty slice) for a single ghost cell, and the
        # previous open-ended slice had the wrong width for more than one.
        offset = -1 if self._wave_speed > 0 else 1
        neighbour = slice(num_ghost_cells+offset,
                          num_cells-num_ghost_cells+offset)

        right_hand_side = np.zeros_like(current_projection)
        right_hand_side[:, interior] = \
            2 * (self._boundary_matrix @ current_projection[:, neighbour] +
                 self._stiffness_matrix @ current_projection[:, interior])

        # Set ghost cells to respective (periodic) value
        return self._enforce_boundary_condition(right_hand_side)