# -*- coding: utf-8 -*-
"""Module for equation class.

@author: Laura C. Kühle
"""
from typing import Tuple
from abc import ABC, abstractmethod
import numpy as np
from numpy import ndarray

from .Basis_Function import Basis
from .Boundary_Condition import enforce_boundary
from .Initial_Condition import InitialCondition
from .Mesh import Mesh
from .projection_utils import do_initial_projection
from .Quadrature import Quadrature


class Equation(ABC):
    """Abstract class for equation.

    Attributes
    ----------
    basis : Basis
        Basis for calculation.
    mesh : Mesh
        Mesh for calculation.
    quadrature : Quadrature
        Quadrature for evaluation.
    final_time : float
        Final time for which approximation is calculated.
    wave_speed : float
        Speed of wave in rightward direction.
    cfl_number : float
        CFL number to ensure stability.

    Methods
    -------
    initialize_approximation()
        Initialize projection and time step for approximation.
    create_data_dict()
        Return dictionary with data necessary to construct equation.
    update_time_step(projection, current_time, time_step)
        Update time step.
    solve_exactly(mesh)
        Calculate exact solution.
    update_right_hand_side(projection)
        Update right-hand side.
    """

    def __init__(self, quadrature: Quadrature, init_cond: InitialCondition,
                 basis: Basis, mesh: Mesh, final_time: float,
                 wave_speed: float, cfl_number: float) -> None:
        """Initialize equation class.

        Parameters
        ----------
        quadrature : Quadrature object
            Quadrature for evaluation.
        init_cond : InitialCondition object
            Initial condition for evaluation.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh object
            Mesh for calculation.
        final_time : float
            Final time for which approximation is calculated.
        wave_speed : float
            Speed of wave in rightward direction.
        cfl_number : float
            CFL number to ensure stability.

        Raises
        ------
        ValueError
            If number of ghost cells in mesh is not positive.

        """
        self._quadrature = quadrature
        self._init_cond = init_cond
        self._basis = basis
        self._mesh = mesh
        self._final_time = final_time
        self._wave_speed = wave_speed
        self._cfl_number = cfl_number

        # The right-hand side of every equation reads neighbouring cells, so
        # at least one ghost cell on each side is required.
        if self._mesh.num_ghost_cells <= 0:
            raise ValueError('Number of ghost cells for calculations has to '
                             'be positive.')

        # Subclass hook: runs only after all attributes above are set.
        self._reset()

    @property
    def basis(self) -> Basis:
        """Return basis."""
        return self._basis

    @property
    def mesh(self) -> Mesh:
        """Return mesh."""
        return self._mesh

    @property
    def quadrature(self) -> Quadrature:
        """Return quadrature."""
        return self._quadrature

    @property
    def final_time(self) -> float:
        """Return final time."""
        return self._final_time

    @property
    def wave_speed(self) -> float:
        """Return wave speed."""
        return self._wave_speed

    @property
    def cfl_number(self) -> float:
        """Return CFL number."""
        return self._cfl_number

    def _reset(self) -> None:
        """Reset instance variables.

        Hook called at the end of ``__init__``; the base implementation does
        nothing. Subclasses override it to precompute derived quantities.
        """

    def initialize_approximation(self) -> Tuple[float, ndarray]:
        """Initialize projection and time step for approximation.

        Returns
        -------
        time_step : float
            Initial time step.
        projection : ndarray
            Initial projection.

        """
        return self._initialize_time_step(), self._initialize_projection()

    def _initialize_time_step(self) -> float:
        """Initialize time step.

        Returns ``|cfl_number * cell_len / wave_speed|``. A wave speed of zero
        is not guarded against here.
        """
        return abs(self._cfl_number * self._mesh.cell_len / self._wave_speed)

    @abstractmethod
    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection."""

    def create_data_dict(self) -> dict:
        """Return dictionary with data necessary to construct equation.

        The basis and mesh are serialized via their own ``create_data_dict``.
        The quadrature and initial condition are not included.
        """
        return {'basis': self._basis.create_data_dict(),
                'mesh': self._mesh.create_data_dict(),
                'final_time': self._final_time,
                'wave_speed': self._wave_speed,
                'cfl_number': self._cfl_number
                }

    @abstractmethod
    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation.
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation.

        Returns
        -------
        cfl_number : float
            Updated CFL number to ensure stability.
        time_step : float
            Updated time step for approximation.

        """

    @abstractmethod
    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        grid : ndarray
            Array containing evaluation grid for a function.
        exact : ndarray
            Array containing exact evaluation of a function.

        """

    @abstractmethod
    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side.

        """


class LinearAdvection(Equation):
    """Class for linear advection equation.

    Attributes
    ----------
    volume_integral_matrix : ndarray
        Volume integral matrix.
    flux_matrix : ndarray
        Flux matrix.

    Methods
    -------
    update_right_hand_side(projection)
        Update right-hand side.

    Notes
    -----
    .. math:: u_t + a u_x = 0

    where :math:`a` is the wave speed.

    """

    def _reset(self) -> None:
        """Reset instance variables.

        Build the (p+1) x (p+1) volume integral and flux matrices, where p is
        the polynomial degree of the basis. Both depend on the sign of the
        wave speed; a wave speed of zero takes the non-positive branch.
        """
        num_dofs = self._basis.polynomial_degree + 1
        matrix_shape = (num_dofs, num_dofs)

        # Entry (i, j) is sqrt(i + 1/2) * sqrt(j + 1/2).
        root_vector = np.sqrt(np.arange(num_dofs) + 0.5)
        degree_matrix = np.outer(root_vector, root_vector)

        # Set volume integral matrix
        matrix = np.ones(matrix_shape)
        if self._wave_speed > 0:
            matrix[np.fromfunction(
                lambda i, j: (j >= i) | ((i+j) % 2 == 0), matrix_shape)] = -1.0
        else:
            matrix[np.fromfunction(
                lambda i, j: (j > i) & ((i+j) % 2 == 1), matrix_shape)] = -1.0
        self._volume_integral_matrix = matrix * degree_matrix

        # Set flux matrix: alternating signs along rows for positive wave
        # speed, along columns otherwise.
        if self._wave_speed > 0:
            matrix = np.fromfunction(lambda i, j: (-1.0)**i, matrix_shape)
        else:
            matrix = np.fromfunction(lambda i, j: (-1.0)**j, matrix_shape)
        self._flux_matrix = matrix * degree_matrix

    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection of the initial condition onto the basis."""
        return do_initial_projection(init_cond=self._init_cond,
                                     mesh=self._mesh, basis=self._basis,
                                     quadrature=self._quadrature)

    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        The time step is only changed if it would overshoot the final time.
        In that case it is shortened to end exactly at the final time, and
        the CFL number is recomputed to match.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation (unused).
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation.

        Returns
        -------
        cfl_number : float
            Updated CFL number to ensure stability.
        time_step : float
            Updated time step for approximation.

        """
        cfl_number = self._cfl_number

        # Adjust for final time-step
        if current_time+time_step > self.final_time:
            time_step = self.final_time-current_time
            # NOTE(review): this CFL number carries the sign of the wave
            # speed, whereas the initial time step uses abs(); confirm the
            # time stepper expects a signed value here.
            cfl_number = self.wave_speed * time_step / self._mesh.cell_len

        return cfl_number, time_step

    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        Each quadrature point x is traced back to x - a*T and wrapped into
        the periodic domain. The initial condition is then evaluated there.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        grid : ndarray
            Array containing evaluation grid for a function, shape (1, n).
        exact : ndarray
            Array containing exact evaluation of a function, shape (1, n).

        """
        interval_len = mesh.interval_len
        shift = self.wave_speed * self.final_time
        num_periods = np.floor(shift / interval_len)

        # Quadrature nodes mapped into every non-ghost cell.
        grid = np.repeat(mesh.non_ghost_cells, self._quadrature.num_nodes) + \
            mesh.cell_len / 2 * np.tile(self._quadrature.nodes, mesh.num_cells)

        # Trace points back along the characteristic, removing whole periods.
        points = grid - shift + num_periods * interval_len

        # Project points into correct periodic interval
        left_bound, right_bound = mesh.bounds
        while np.any(points < left_bound):
            points[points < left_bound] += interval_len
        # The comparison must happen inside np.any; comparing the boolean
        # result of np.any with the bound skips wrapping, or never terminates
        # when right_bound < 1.
        while np.any(points > right_bound):
            points[points > right_bound] -= interval_len

        exact = np.array([self._init_cond.calculate(mesh=mesh, x=point) for
                          point in points])

        grid = np.reshape(grid, (1, grid.size))
        exact = np.reshape(exact, (1, exact.size))

        return grid, exact

    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side. Only non-ghost columns are computed
            here; ghost columns stay zero in this body (presumably handled
            by ``enforce_boundary`` — confirm).

        """
        num_ghost_cells = self._mesh.num_ghost_cells
        num_columns = projection.shape[1]

        # Columns belonging to non-ghost cells.
        interior = slice(num_ghost_cells, num_columns - num_ghost_cells)
        # The same columns shifted one cell upwind: left neighbour for a
        # positive wave speed, right neighbour otherwise. Explicit end
        # indices keep this as wide as `interior` for any ghost-cell count.
        offset = -1 if self._wave_speed > 0 else 1
        upwind = slice(num_ghost_cells + offset,
                       num_columns - num_ghost_cells + offset)

        right_hand_side = np.zeros_like(projection)
        right_hand_side[:, interior] = \
            2 * (self._flux_matrix @ projection[:, upwind] +
                 self._volume_integral_matrix @ projection[:, interior])

        return right_hand_side


class Burgers(Equation):
    """Class for Burgers' equation.

    Work in progress: the projection initialization and the time-step
    update are implemented. ``solve_exactly`` and
    ``update_right_hand_side`` are placeholders that currently return None.

    Methods
    -------
    update_time_step(projection, current_time, time_step)
        Update time step from the current maximum velocity.
    update_right_hand_side(projection)
        Update right-hand side (not yet implemented).

    Notes
    -----
    .. math:: u_t + u*u_x = 0

    """

    def _reset(self) -> None:
        """Reset instance variables.

        Nothing is precomputed yet for Burgers' equation.
        """
        # TODO: build the volume integral and flux matrices for Burgers'
        # equation once the right-hand side is implemented.

    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection of the initial condition onto the basis."""
        return do_initial_projection(init_cond=self._init_cond,
                                     mesh=self._mesh, basis=self._basis,
                                     quadrature=self._quadrature)

    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        The incoming ``time_step`` is ignored. A new step is derived from the
        CFL number and the maximum velocity of the current projection, then
        clipped so it does not overshoot the final time. The returned "CFL
        number" is dt/dx, so that the right-hand side is multiplied with
        dt/dx.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation.
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation (ignored).

        Returns
        -------
        cfl_number : float
            Updated CFL number, i.e. time step divided by cell length.
        time_step : float
            Updated time step for approximation.

        """
        # Degree-0 coefficients scaled by sqrt(1/2); presumably the cell
        # averages of u for the normalized basis — confirm against Basis.
        max_velocity = np.max(np.abs(projection[0, :] * np.sqrt(0.5)))

        if max_velocity == 0:
            # No velocity imposes no CFL restriction. An infinite step is
            # clipped to the final time below, which is what dividing by
            # zero produced before, but without the runtime warning.
            time_step = np.inf
        else:
            time_step = self._cfl_number * self._mesh.cell_len / max_velocity

        # Adjust for final time-step
        if current_time+time_step > self._final_time:
            time_step = self._final_time-current_time

        cfl_number = time_step / self._mesh.cell_len
        return cfl_number, time_step

    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        Not yet implemented for Burgers' equation.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        None
            Placeholder; once implemented this returns ``(grid, exact)``.

        """
        # TODO: implement the exact solution for Burgers' equation.
        return None

    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Not yet implemented for Burgers' equation.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        None
            Placeholder; once implemented this returns the right-hand side.

        """
        # TODO: implement the DG right-hand side for Burgers' equation.
        return None