# -*- coding: utf-8 -*-
"""Module for equation class.

@author: Laura C. Kühle
"""
from typing import Tuple
from abc import ABC, abstractmethod
import numpy as np
from numpy import ndarray
from sympy import Symbol

from .Basis_Function import Basis
from .Boundary_Condition import enforce_boundary
from .Initial_Condition import InitialCondition
from .Mesh import Mesh
from .projection_utils import do_initial_projection
from .Quadrature import Quadrature


# Symbolic variable in which the basis functions and their derivatives are
# expressed; it is substituted with quadrature nodes when evaluating them.
x = Symbol(name='x')


class Equation(ABC):
    """Abstract class for equation.

    Holds the quadrature, initial condition, basis, mesh, final time, wave
    speed and CFL number shared by all equations and defines the interface
    concrete equations have to implement.

    Methods
    -------
    initialize_approximation()
        Initialize projection and time step for approximation.
    create_data_dict()
        Return dictionary with data necessary to construct equation.
    update_time_step(projection, current_time, time_step)
        Update time step.
    solve_exactly(mesh)
        Calculate exact solution.
    update_right_hand_side(projection)
        Update right-hand side.

    """

    def __init__(self, quadrature: Quadrature, init_cond: InitialCondition,
                 basis: Basis, mesh: Mesh, final_time: float,
                 wave_speed: float, cfl_number: float) -> None:
        """Initialize equation class.

        Parameters
        ----------
        quadrature : Quadrature object
            Quadrature for evaluation.
        init_cond : InitialCondition object
            Initial condition for evaluation.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh object
            Mesh for calculation.
        final_time : float
            Final time for which approximation is calculated.
        wave_speed : float
            Speed of wave in rightward direction.
        cfl_number : float
            CFL number to ensure stability.

        Raises
        ------
        ValueError
            If number of ghost cells in mesh is not positive.

        """
        self._quadrature = quadrature
        self._init_cond = init_cond
        self._basis = basis
        self._mesh = mesh
        self._final_time = final_time
        self._wave_speed = wave_speed
        self._cfl_number = cfl_number

        # The right-hand side computations access neighbouring (ghost)
        # cells, so at least one ghost cell is required
        if self._mesh.num_ghost_cells <= 0:
            raise ValueError('Number of ghost cells for calculations has to '
                             'be positive.')

        self._reset()

    @property
    def name(self) -> str:
        """Return string of class name."""
        return self.__class__.__name__

    @property
    def basis(self) -> Basis:
        """Return basis."""
        return self._basis

    @property
    def mesh(self) -> Mesh:
        """Return mesh."""
        return self._mesh

    @property
    def quadrature(self) -> Quadrature:
        """Return quadrature."""
        return self._quadrature

    @property
    def final_time(self) -> float:
        """Return final time."""
        return self._final_time

    @property
    def wave_speed(self) -> float:
        """Return wave speed."""
        return self._wave_speed

    @property
    def cfl_number(self) -> float:
        """Return CFL number."""
        return self._cfl_number

    def _reset(self) -> None:
        """Reset instance variables.

        Hook called at the end of `__init__`; subclasses override it to
        precompute quantities derived from the constructor arguments.

        """
        pass

    def initialize_approximation(self) -> Tuple[float, ndarray]:
        """Initialize projection and time step for approximation.

        Returns
        -------
        time_step : float
            Initial time step.
        projection : ndarray
            Initial projection.

        """
        return self._initialize_time_step(), self._initialize_projection()

    def _initialize_time_step(self) -> float:
        """Initialize time step.

        Returns
        -------
        float
            Time step ``|cfl_number * cell_len / wave_speed|``.

        """
        return abs(self._cfl_number * self._mesh.cell_len / self._wave_speed)

    @abstractmethod
    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection."""
        pass

    def create_data_dict(self):
        """Return dictionary with data necessary to construct equation.

        Quadrature and initial condition are not part of the dictionary.

        """
        return {'basis': self._basis.create_data_dict(),
                'mesh': self._mesh.create_data_dict(),
                'final_time': self._final_time,
                'wave_speed': self._wave_speed,
                'cfl_number': self._cfl_number
                }

    @abstractmethod
    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation.
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation.

        Returns
        -------
        cfl_number : float
            Updated CFL number to ensure stability.
        time_step : float
            Updated time step for approximation.

        """
        pass

    @abstractmethod
    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        grid : ndarray
            Array containing evaluation grid for a function.
        exact : ndarray
            Array containing exact evaluation of a function.

        """
        pass

    @abstractmethod
    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side.

        """
        pass


class LinearAdvection(Equation):
    """Class for linear advection equation.

    Attributes
    ----------
    volume_integral_matrix : ndarray
        Volume integral matrix.
    flux_matrix : ndarray
        Flux matrix.

    Methods
    -------
    update_right_hand_side(projection)
        Update right-hand side.

    Notes
    -----
    .. math:: u_t + a u_x = 0

    where :math:`a` is the constant wave speed.

    """

    def _reset(self) -> None:
        """Reset instance variables.

        Precompute the volume integral and flux matrices of shape
        (polynomial_degree+1, polynomial_degree+1). Their sign patterns
        depend on the sign of the wave speed (upwind direction).

        """
        matrix_shape = (self._basis.polynomial_degree+1,
                        self._basis.polynomial_degree+1)
        root_vector = np.array([np.sqrt(degree+0.5) for degree in range(
            self._basis.polynomial_degree+1)])
        # Outer product: entry (i, j) equals sqrt(i+0.5) * sqrt(j+0.5)
        degree_matrix = np.matmul(root_vector[:, np.newaxis],
                                  root_vector[:, np.newaxis].T)

        # Set volume integral matrix
        matrix = np.ones(matrix_shape)
        if self._wave_speed > 0:
            matrix[np.fromfunction(
                lambda i, j: (j >= i) | ((i+j) % 2 == 0), matrix_shape)] = -1.0
        else:
            matrix[np.fromfunction(
                lambda i, j: (j > i) & ((i+j) % 2 == 1), matrix_shape)] = -1.0
        self._volume_integral_matrix = matrix * degree_matrix

        # Set flux matrix
        matrix = np.fromfunction(lambda i, j: (-1.0)**i, matrix_shape) \
            if self._wave_speed > 0 \
            else np.fromfunction(lambda i, j: (-1.0)**j, matrix_shape)
        self._flux_matrix = matrix * degree_matrix

    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection."""
        return do_initial_projection(init_cond=self._init_cond,
                                     mesh=self._mesh, basis=self._basis,
                                     quadrature=self._quadrature)

    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        The time step is only shortened if it would overshoot the final
        time; the CFL number is then recomputed for the shortened step.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation.
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation.

        Returns
        -------
        cfl_number : float
            Updated CFL number to ensure stability.
        time_step : float
            Updated time step for approximation.

        """
        cfl_number = self._cfl_number

        # Adjust for final time-step
        if current_time+time_step > self.final_time:
            time_step = self.final_time-current_time
            cfl_number = abs(self.wave_speed * time_step / self._mesh.cell_len)

        return cfl_number, time_step

    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        The initial condition is transported with the wave speed, so the
        exact solution at each grid point is the initial condition evaluated
        at ``x - wave_speed * final_time``, wrapped periodically into the
        mesh bounds.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        grid : ndarray
            Array of shape (1, N) containing evaluation grid for a function.
        exact : ndarray
            Array of shape (1, N) containing exact evaluation of a function.

        """
        num_periods = np.floor(self.wave_speed * self.final_time /
                               mesh.interval_len)

        grid = np.repeat(mesh.non_ghost_cells, self._quadrature.num_nodes) + \
            mesh.cell_len / 2 * np.tile(self._quadrature.nodes, mesh.num_cells)

        # Trace characteristics back to t=0, removing whole periods
        points = grid - self.wave_speed * self.final_time + \
            num_periods * mesh.interval_len

        # Project points into correct periodic interval
        left_bound, right_bound = mesh.bounds
        while np.any(points < left_bound):
            points[points < left_bound] += mesh.interval_len
        # Compare element-wise before reducing; comparing the boolean result
        # of np.any with the bound either skips wrapping or never terminates
        while np.any(points > right_bound):
            points[points > right_bound] -= mesh.interval_len

        exact = np.array([self._init_cond.calculate(mesh=mesh, x=point) for
                          point in points])

        grid = np.reshape(grid, (1, grid.size))
        exact = np.reshape(exact, (1, exact.size))

        return grid, exact

    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Combines the flux contribution of the upwind neighbour (left for a
        positive wave speed, right otherwise) with the volume integral of
        each non-ghost cell. Ghost-cell columns are left at zero here.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side.

        """
        num_ghost_cells = self._mesh.num_ghost_cells
        num_columns = projection.shape[1]

        interior = slice(num_ghost_cells, num_columns - num_ghost_cells)
        # Explicit stop index keeps the shifted slice as wide as the interior
        # for any number of ghost cells (an open-ended slice does not)
        shift = -1 if self._wave_speed > 0 else 1
        upwind = slice(num_ghost_cells + shift,
                       num_columns - num_ghost_cells + shift)

        right_hand_side = np.zeros_like(projection)
        right_hand_side[:, interior] = \
            2 * (self._flux_matrix @ projection[:, upwind] +
                 self._volume_integral_matrix @ projection[:, interior])

        return right_hand_side


class Burgers(Equation):
    """Class for inviscid Burgers' equation.

    Methods
    -------
    update_right_hand_side(projection)
        Update right-hand side.
    solve_exactly(mesh)
        Calculate exact solution.

    Notes
    -----
    .. math:: u_t + u*u_x = 0

    The right-hand side is assembled from a volume integral and a local
    Lax-Friedrichs boundary flux. The helper methods assume exactly one ghost
    cell on each side of the mesh (projection columns ``0`` and
    ``num_cells+1``).

    """

    def _reset(self) -> None:
        """Reset instance variables.

        Nothing is precomputed: volume and boundary terms depend on the
        current projection and are assembled in `update_right_hand_side`.

        """

    @enforce_boundary()
    def _initialize_projection(self) -> ndarray:
        """Initialize projection."""
        return do_initial_projection(init_cond=self._init_cond,
                                     mesh=self._mesh, basis=self._basis,
                                     quadrature=self._quadrature)

    def update_time_step(self, projection: ndarray, current_time: float,
                         time_step: float) -> Tuple[float, float]:
        """Update time step.

        Adapt CFL number to ensure right-hand side is multiplied with dt/dx.
        The passed time step is not used; a new one is computed from the
        maximal velocity of the current projection.

        Parameters
        ----------
        projection : ndarray
            Current projection during approximation.
        current_time : float
            Current time during approximation.
        time_step : float
            Length of time step during approximation.

        Returns
        -------
        cfl_number : float
            Updated CFL number to ensure stability.
        time_step : float
            Updated time step for approximation.

        """
        cfl_number = self._cfl_number

        # Velocity estimate from the degree-0 coefficients (scaled by
        # sqrt(0.5)), taken over all cells including ghost cells
        max_velocity = max(abs(projection[0, :] * np.sqrt(0.5)))
        time_step = cfl_number * self._mesh.cell_len / max_velocity
        cfl_number = time_step / self._mesh.cell_len

        # Adjust for final time-step
        if current_time+time_step > self._final_time:
            time_step = self._final_time-current_time
            cfl_number = time_step / self._mesh.cell_len

        return cfl_number, time_step

    def solve_exactly(self, mesh: Mesh) -> Tuple[ndarray, ndarray]:
        """Calculate exact solution.

        Exact solutions are implemented for the 'Sine' initial condition
        (implicit characteristic solver) and the 'DiscontinuousConstant'
        initial condition (rarefaction wave). For any other initial
        condition an empty list is returned as exact solution.

        Parameters
        ----------
        mesh : Mesh
            Mesh for evaluation.

        Returns
        -------
        grid : ndarray
            Array of shape (1, N) containing evaluation grid for a function.
        exact : ndarray
            Array containing exact evaluation of a function.

        """
        grid = np.repeat(mesh.non_ghost_cells, self._quadrature.num_nodes) + \
            mesh.cell_len / 2 * np.tile(self._quadrature.nodes, mesh.num_cells)
        grid = np.reshape(grid, (1, grid.size))

        init_cond_name = self._init_cond.get_name()
        if init_cond_name == 'Sine':
            # u(x,t) = u(x - u0*time, 0)
            exact = self.implicit_burgers_solver(grid, mesh)
        elif init_cond_name == 'DiscontinuousConstant':
            exact = self.rarefaction_wave(grid)
        else:
            # No exact solution implemented for this initial condition
            exact = []

        return grid, exact

    def rarefaction_wave(self, grid_values):
        """Evaluate the rarefaction-wave solution at the final time.

        With ``t`` the final time, the solution is -1 for ``x <= -t``,
        ``x/t`` for ``-t < x < t`` and 1 otherwise.

        Parameters
        ----------
        grid_values : ndarray
            Evaluation points of shape (1, N).

        Returns
        -------
        ndarray
            Exact solution with the same shape and dtype as `grid_values`.

        """
        final_time = self._final_time
        uexact = 0 * grid_values
        points = grid_values[0]

        # x/t is evaluated everywhere but only kept where -t < x < t, so
        # division warnings (e.g. for t == 0) are irrelevant
        with np.errstate(divide='ignore', invalid='ignore'):
            fan = points / final_time
        uexact[0] = np.where(points <= -final_time, -1,
                             np.where(points < final_time, fan, 1))

        return uexact

    def implicit_burgers_solver(self, grid_values, mesh):
        """Compute exact solution for a sine initial condition.

        For each point, the implicit relation
        ``u = sin(factor*pi*(x - u*t))`` is solved with Newton's method in
        coordinates rescaled by ``2 / interval_len``, starting from the
        closest sampled characteristic.

        Parameters
        ----------
        grid_values : ndarray
            Evaluation points; flattened internally.
        mesh : Mesh
            Mesh whose interval length is used for the rescaling.

        Returns
        -------
        ndarray
            Exact solution of shape (1, grid_values.size).

        Notes
        -----
        Code adapted from DG code by Hesthaven.

        """
        # Temporary fix until further clarification
        height_adjustment = 0
        stretch_factor = 1

        # Shock speed = 1/2*(left+right values)
        sspeed = height_adjustment
        uexact = np.zeros(np.size(grid_values))
        factor = self._init_cond._factor

        # Sample initial condition to locate characteristics
        xx = np.linspace(0, 1, 200)
        uu = np.sin(factor * np.pi * xx)

        # Scale time for domain
        te = self._final_time * (2 / mesh.interval_len)

        # Initial characteristic guess (where is the solution)
        xt0 = (2 / mesh.interval_len) * (
            grid_values-sspeed * self._final_time).reshape(
            np.size(grid_values))

        for ix in range(len(xt0)):
            # Wrap point into [-1, 1]
            if xt0[ix] > 1:
                xt0[ix] -= 2 * np.floor(0.5 * (xt0[ix]+1))
            elif xt0[ix] < -1:
                xt0[ix] += 2 * np.floor(0.5 * (1-xt0[ix]))

            # Solve for |x| and restore the sign afterwards (odd symmetry)
            ay = abs(xt0[ix])
            # Finding the closest characteristic
            i0 = 0
            for j in range(200):
                xt = xx[j]+uu[j] * te
                if ay < xt:
                    break
                else:
                    i0 = j

            # Newton iteration starting from the closest characteristic
            us = uu[i0]
            un = us
            for k in range(400):
                us = un
                x0 = ay-us * te
                un = us-(us-np.sin(factor * np.pi * x0)) / (
                        1+factor * np.pi * np.cos(factor * np.pi * x0) * te)

                if abs(un-us) < 1e-15:
                    break

            y = np.sign(xt0[ix]) * un
            if abs(ay-1) < 1e-15:
                y = 0
            if abs(self._final_time) < 1e-15:
                y = np.sin(factor * np.pi * xt0[ix])
            uexact[ix] = y

        burgers_exact = uexact.reshape((1, np.size(grid_values)))

        return height_adjustment + stretch_factor * burgers_exact

    @enforce_boundary()
    def update_right_hand_side(self, projection: ndarray) -> ndarray:
        """Update right-hand side.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of right-hand side of shape
            (polynomial_degree+1, num_cells+2).

        """
        num_cells = self._mesh.num_cells
        volume_integral_matrix = self._burgers_volume_integral(projection)
        boundary_matrix = self._burgers_boundary_matrix(projection)

        # Non-ghost cells occupy columns 1..num_cells
        interior = 2 * (volume_integral_matrix[:, 1:num_cells+1] +
                        boundary_matrix[:, 1:num_cells+1])

        # Set ghost cells to respective value
        # (Periodic, Updated to Dirichlet in enforce_boundary_condition
        # for DiscontinuousConstant problem)
        return np.concatenate((interior[:, -1:], interior, interior[:, :1]),
                              axis=1)

    def _burgers_volume_integral(self, projection):
        """Compute volume integral of u^2/2 against basis derivatives.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of shape (polynomial_degree+1, num_cells+2) whose ghost
            columns are filled periodically.

        """
        basis_vector = self._basis.basis
        derivative_basis_vector = self._basis.derivative
        num_degrees = self._basis.polynomial_degree + 1
        nodes = self._quadrature.nodes
        num_nodes = self._quadrature.num_nodes

        # Basis values and weighted basis derivatives at the quadrature nodes
        # do not depend on the cell, so evaluate them only once
        basis_at_nodes = [[basis_vector[degree].subs(x, nodes[point])
                           for point in range(num_nodes)]
                          for degree in range(num_degrees)]
        weighted_derivatives = [
            [derivative_basis_vector[degree].subs(x, nodes[point])
             * self._quadrature.weights[point]
             for point in range(num_nodes)]
            for degree in range(num_degrees)]

        # Initialize matrix and set first entry to accommodate for ghost cell
        volume_integral_matrix = [0]

        for cell in range(self._mesh.num_cells):
            # Approximation u at Gauss-Legendre points
            approx_eval = np.array(
                [sum(projection[degree][cell + 1] *
                     basis_at_nodes[degree][point]
                     for degree in range(num_degrees))
                 for point in range(num_nodes)])

            # Quadrature evaluation of the volume integral per degree
            new_row = np.array([
                np.float64(sum(list(approx_eval**2 *
                                    weighted_derivatives[degree]))/2)
                for degree in range(num_degrees)])
            volume_integral_matrix.append(new_row)

        # Set ghost cells to respective value
        # (Periodic, Updated to Dirichlet in enforce_boundary_condition
        # for DiscontinuousConstant problem)
        volume_integral_matrix[0] = volume_integral_matrix[
            self._mesh.num_cells]
        volume_integral_matrix.append(volume_integral_matrix[1])

        return np.transpose(np.array(volume_integral_matrix))

    def _burgers_local_Lax_Friedrichs(self, projection):
        """Compute local Lax-Friedrichs flux at every cell interface.

        Interface ``i`` (``0 <= i <= num_cells``) separates projection
        columns ``i`` and ``i+1``, so interface 0 is evaluated against the
        left ghost cell. Computing it (instead of leaving a placeholder)
        provides the left flux of the first cell and keeps the result a
        regular, non-ragged array.

        NOTE(review): like `LinearAdvection.update_right_hand_side`, this
        relies on the ghost columns of `projection` being up to date.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Array of shape (num_cells+1, 1); row ``i`` is the flux at
            interface ``i``.

        """
        num_cells = self._mesh.num_cells
        # Calculating the left and right boundaries for each cell,
        # u_{j-1/2}^+, u_{j+1/2}^-
        boundary_left, boundary_right = self._calculate_boundary_points(
            projection, self._basis.polynomial_degree + 1)

        # Values left (^-) and right (^+) of each interface
        approx_minus = boundary_right[0, :num_cells + 1]
        approx_plus = boundary_left[0, 1:num_cells + 2]
        max_velocity = np.maximum(np.abs(approx_minus), np.abs(approx_plus))

        # Respective fluxes f(u) = u^2/2
        flux_minus = approx_minus**2 / 2
        flux_plus = approx_plus**2 / 2

        numerical_flux = 0.5 * (flux_minus + flux_plus - max_velocity *
                                (approx_plus - approx_minus))

        return numerical_flux[:, np.newaxis]

    def _burgers_boundary_matrix(self, projection):
        """Compute boundary (numerical flux) contribution for every cell.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of shape (polynomial_degree+1, num_cells+2) whose ghost
            columns are filled periodically.

        """
        burgers_LF_flux = self._burgers_local_Lax_Friedrichs(projection)

        # Initializing boundary matrix
        boundary_matrix = [0]

        for j in range(self._mesh.num_cells):
            # Cell j+1 sees flux j+1 at its right and flux j at its left
            # interface
            new_row = [np.sqrt(degree + 0.5) *
                       (-burgers_LF_flux[j + 1] + (-1)**degree *
                        burgers_LF_flux[j])
                       for degree in range(self._basis.polynomial_degree + 1)]
            boundary_matrix.append(new_row)

        # Set Ghost cells to respective value
        # (Periodic, Updated to Dirichlet in enforce_boundary_condition
        # for DiscontinuousConstant problem)
        boundary_matrix[0] = boundary_matrix[self._mesh.num_cells]
        boundary_matrix.append(boundary_matrix[1])

        # Entries are shape-(1,) arrays, so drop the trailing axis after
        # transposing
        boundary_matrix = np.transpose(np.array(boundary_matrix))

        return boundary_matrix[0]

    @staticmethod
    def _calculate_boundary_points(projection, max_degree):
        """Evaluate the approximation at both boundaries of every cell.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        max_degree : int
            Number of polynomial degrees to sum over.

        Returns
        -------
        boundary_left : ndarray
            Shape (1, number of columns): values at the left cell boundaries.
        boundary_right : ndarray
            Shape (1, number of columns): values at the right cell
            boundaries.

        """
        # Approximation at j-1/2 ^+
        boundary_left = [sum(projection[degree][cell] * (-1) ** degree *
                             np.sqrt(degree + 0.5)
                             for degree in range(max_degree))
                         for cell in range(len(projection[0]))]

        # Approximation at j+1/2 ^-
        boundary_right = [sum(projection[degree][cell] * np.sqrt(degree + 0.5)
                              for degree in range(max_degree))
                          for cell in range(len(projection[0]))]

        return np.reshape(np.array(boundary_left), (1, len(boundary_left))), \
            np.reshape(np.array(boundary_right), (1, len(boundary_right)))