# -*- coding: utf-8 -*-
"""Module for polynomial basis.

@author: Laura C. Kühle

"""
from abc import ABC, abstractmethod
from functools import cache, cached_property
from typing import Tuple

import numpy as np
from numpy import ndarray
from sympy import Symbol, integrate

from projection_utils import calculate_approximate_solution

# Shared SymPy symbols: the symbolic basis vector is built in ``x``, the
# symbolic wavelet vector in ``z``; ``z`` is also the integration variable
# of the projection matrices.
x, z = Symbol('x'), Symbol('z')


class Basis(ABC):
    """Abstract class for polynomial basis.

    Attributes
    ----------
    polynomial_degree : int
        Polynomial degree.
    basis : ndarray
        Array of basis.
    wavelet : ndarray
        Array of wavelet.
    inverse_mass_matrix : ndarray
        Inverse mass matrix.
    basis_projection : Tuple[ndarray, ndarray]
        Two arrays containing integrals of all basis vector
        combinations evaluated on the left and right cell boundary,
        respectively.
    multiwavelet_projection : Tuple[ndarray, ndarray]
        Two arrays containing integrals of all basis vector/
        wavelet vector combinations evaluated on the left and right cell
        boundary, respectively.

    Methods
    -------
    calculate_cell_average(projection, stencil_length, add_reconstructions)
        Calculate cell averages for a given projection.
    create_data_dict()
        Return dictionary with data necessary to construct basis.

    Notes
    -----
    ``basis``, ``wavelet`` and ``inverse_mass_matrix`` are computed on
    first access and then stored on the instance itself, so the cached
    data is released together with the instance.

    """

    def __init__(self, polynomial_degree: int) -> None:
        """Initialize Basis.

        Parameters
        ----------
        polynomial_degree : int
            Polynomial degree.

        """
        self._polynomial_degree = polynomial_degree

    @property
    def polynomial_degree(self) -> int:
        """Return polynomial degree."""
        return self._polynomial_degree

    @cached_property
    def basis(self) -> ndarray:
        """Return basis vector (built once in the module symbol ``x``)."""
        return self._build_basis_vector(x)

    @abstractmethod
    def _build_basis_vector(self, eval_point: float) -> ndarray:
        """Construct basis vector.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing basis evaluated at evaluation point.

        """
        pass

    @cached_property
    def wavelet(self) -> ndarray:
        """Return wavelet vector (built once in the module symbol ``z``)."""
        return self._build_wavelet_vector(z)

    @abstractmethod
    def _build_wavelet_vector(self, eval_point: float) -> ndarray:
        """Construct wavelet vector.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing wavelet evaluated at evaluation point.

        """
        pass

    @cached_property
    def inverse_mass_matrix(self) -> ndarray:
        """Return inverse mass matrix (built once per instance)."""
        return self._build_inverse_mass_matrix()

    @abstractmethod
    def _build_inverse_mass_matrix(self) -> ndarray:
        """Construct inverse mass matrix.

        Returns
        -------
        ndarray
            Inverse mass matrix.

        """
        pass

    # The abstract properties carry no cache: their bodies are empty, and
    # concrete subclasses decide themselves how to memoize the result.
    @property
    @abstractmethod
    def basis_projection(self) -> Tuple[ndarray, ndarray]:
        """Return basis projection."""
        pass

    @property
    @abstractmethod
    def multiwavelet_projection(self) -> Tuple[ndarray, ndarray]:
        """Return wavelet projection."""
        pass

    def calculate_cell_average(self, projection: ndarray, stencil_length: int,
                               add_reconstructions: bool = True) -> ndarray:
        """Calculate cell averages for a given projection.

        Calculate the cell averages of all cells in a projection.
        If desired, reconstructions are calculated for the middle cell
        and added left and right to it, respectively.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        stencil_length : int
            Size of data array.
        add_reconstructions : bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        ndarray
            Matrix containing cell averages (and reconstructions) for given
            projection.

        """
        # Evaluate the approximation at point 0 with degree 0.
        cell_averages = calculate_approximate_solution(
            projection, np.array([0]), 0, self.basis)

        if not add_reconstructions:
            return np.array(list(map(np.float64, cell_averages)))

        # Reconstructions are only computed for the middle cell of the
        # stencil and placed directly left and right of its average.
        middle_idx = stencil_length // 2
        left_reconstructions, right_reconstructions = \
            self._calculate_reconstructions(
                projection[:, middle_idx:middle_idx+1])
        # NOTE(review): this row-wise assembly depends on the shapes
        # returned by calculate_approximate_solution, which is defined in
        # another module -- confirm before restructuring.
        return np.array(list(map(
            np.float64, zip(cell_averages[:, :middle_idx],
                            left_reconstructions,
                            cell_averages[:, middle_idx],
                            right_reconstructions,
                            cell_averages[:, middle_idx+1:]))))

    def _calculate_reconstructions(self, projection: ndarray) \
            -> Tuple[ndarray, ndarray]:
        """Calculate left and right reconstructions for a given projection.

        The approximation is evaluated with the full polynomial degree at
        the points -1 (left) and 1 (right).

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        left_reconstruction : ndarray
            List containing left reconstructions for given projection.
        right_reconstruction : ndarray
            List containing right reconstructions for given projection.

        """
        left_reconstructions = calculate_approximate_solution(
            projection, np.array([-1]), self._polynomial_degree,
            self.basis)
        right_reconstructions = calculate_approximate_solution(
            projection, np.array([1]), self._polynomial_degree,
            self.basis)
        return left_reconstructions, right_reconstructions

    def create_data_dict(self) -> dict:
        """Return dictionary with data necessary to construct basis."""
        return {'polynomial_degree': self.polynomial_degree}


class Legendre(Basis):
    """Basis made of Legendre polynomials up to the polynomial degree."""

    def _build_basis_vector(self, eval_point: float) -> ndarray:
        """Construct basis vector.

        The basis vector is the plain Legendre vector.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing basis evaluated at evaluation point.

        """
        return self._calculate_legendre_vector(eval_point)

    def _calculate_legendre_vector(self, eval_point: float) -> ndarray:
        """Construct Legendre vector.

        The first two entries are set directly; every further entry follows
        from the two preceding ones via the three-term recurrence.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing Legendre polynomial evaluated at evaluation
            point, one entry per degree from 0 to the polynomial degree.

        """
        polynomials = []
        for order in range(self._polynomial_degree+1):
            if order == 0:
                # Presumably written as 1 + 0*eval_point so the constant
                # entry takes on the type of eval_point (e.g. symbolic).
                current = 1.0 + 0*eval_point
            elif order == 1:
                current = eval_point
            else:
                previous = polynomials[-1]
                before_previous = polynomials[-2]
                current = (2.0*order - 1)/order * eval_point * previous \
                    - (order-1)/order * before_previous
            polynomials.append(current)
        return np.array(polynomials)


class OrthonormalLegendre(Legendre):
    """Class for orthonormal Legendre basis.

    Notes
    -----
    ``basis_projection`` and ``multiwavelet_projection`` are computed on
    first access and then stored on the instance itself, so the cached
    matrices are released together with the instance.

    """

    def _build_basis_vector(self, eval_point: float) -> ndarray:
        """Construct basis vector.

        Each Legendre polynomial of degree ``k`` is scaled by
        ``sqrt(k + 0.5)``.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing basis evaluated at evaluation point.

        """
        leg_vector = self._calculate_legendre_vector(eval_point)
        return np.array([leg_vector[degree] * np.sqrt(degree+0.5)
                         for degree in range(self._polynomial_degree+1)])

    def _build_wavelet_vector(self, eval_point: float) -> ndarray:
        """Construct wavelet vector.

        Parameters
        ----------
        eval_point : float
            Evaluation point.

        Returns
        -------
        ndarray
            Vector containing wavelet evaluated at evaluation point.

        Raises
        ------
        ValueError
            If the polynomial degree is not one of 0, 1, 2, 3 or 4.

        Notes
        -----
        Hardcoded version only for now.

        """
        degree = self._polynomial_degree

        if degree == 0:
            return np.array([np.sqrt(0.5) + eval_point*0])
        if degree == 1:
            return np.array([np.sqrt(1.5) * (-1 + 2*eval_point),
                             np.sqrt(0.5) * (-2 + 3*eval_point)])
        if degree == 2:
            return np.array([1/3 * np.sqrt(0.5) *
                             (1 - 24*eval_point + 30*(eval_point**2)),
                             1/2 * np.sqrt(1.5) *
                             (3 - 16*eval_point + 15*(eval_point**2)),
                             1/3 * np.sqrt(2.5) *
                             (4 - 15*eval_point + 12*(eval_point**2))])
        if degree == 3:
            return np.array([np.sqrt(15/34) *
                             (1 + 4*eval_point - 30*(eval_point**2)
                              + 28*(eval_point**3)),
                             np.sqrt(1/42) *
                             (-4 + 105*eval_point - 300*(eval_point**2)
                              + 210*(eval_point**3)),
                             1/2 * np.sqrt(35/34) *
                             (-5 + 48*eval_point - 105*(eval_point**2)
                             + 64*(eval_point**3)),
                             1/2 * np.sqrt(5/42) *
                             (-16 + 105*eval_point - 192*(eval_point**2)
                             + 105*(eval_point**3))])
        if degree == 4:
            return np.array([np.sqrt(1/186) *
                             (1 + 30*eval_point + 210*(eval_point**2)
                              - 840*(eval_point**3) + 630*(eval_point**4)),
                             0.5 * np.sqrt(1/38) *
                             (-5 - 144*eval_point + 1155*(eval_point**2)
                             - 2240*(eval_point**3) + 1260*(eval_point**4)),
                             np.sqrt(35/14694) *
                             (22 - 735*eval_point + 3504*(eval_point**2)
                             - 5460*(eval_point**3) + 2700*(eval_point**4)),
                             1/8 * np.sqrt(21/38) *
                             (35 - 512*eval_point + 1890*(eval_point**2)
                             - 2560*(eval_point**3) + 1155*(eval_point**4)),
                             0.5 * np.sqrt(7/158) *
                             (32 - 315*eval_point + 960*(eval_point**2)
                             - 1155*(eval_point**3) + 480*(eval_point**4))])

        # Adjacent literals instead of a backslash continuation inside the
        # string, which would embed the source indentation in the message.
        raise ValueError('Invalid value: Alpert\'s wavelet is only available '
                         'up to degree 4 for this application')

    def _build_inverse_mass_matrix(self) -> ndarray:
        """Construct inverse mass matrix.

        Returns
        -------
        ndarray
            Inverse mass matrix: the identity matrix with one row per
            polynomial degree from 0 to the polynomial degree.

        """
        return np.eye(self._polynomial_degree + 1)

    @cached_property
    def basis_projection(self) -> Tuple[ndarray, ndarray]:
        """Return basis projection.

        Construct matrices containing the integrals of the
        product of two basis vectors for every degree combination evaluated
        at the left and right cell boundary.

        Returns
        -------
        left_basis_projection : ndarray
            Array containing the left basis projection.
        right_basis_projection : ndarray
            Array containing the right basis projection.

        """
        left_basis_projection = self._build_basis_matrix(z, 0.5 * (z - 1))
        right_basis_projection = self._build_basis_matrix(z, 0.5 * (z + 1))
        return left_basis_projection, right_basis_projection

    def _build_basis_matrix(self, first_param: float, second_param: float) \
            -> ndarray:
        """Construct a basis matrix.

        Entry ``(i, j)`` is the integral over ``z`` from -1 to 1 of basis
        function ``i`` with ``x`` replaced by `first_param` times basis
        function ``j`` with ``x`` replaced by `second_param`.

        Parameters
        ----------
        first_param : float
            First parameter.
        second_param : float
            Second parameter.

        Returns
        -------
        ndarray
            Matrix containing the integral of basis products.

        """
        basis = self.basis
        # The substitutions do not depend on the other loop index, so they
        # are carried out once per basis function instead of once per entry.
        first_factors = [poly.subs(x, first_param) for poly in basis]
        second_factors = [poly.subs(x, second_param) for poly in basis]
        return np.array([
            [np.float64(integrate(first * second, (z, -1, 1)))
             for second in second_factors]
            for first in first_factors])

    @cached_property
    def multiwavelet_projection(self) -> Tuple[ndarray, ndarray]:
        """Return wavelet projection.

        Construct matrices containing the integrals of the
        product of a basis vector and a wavelet vector for every degree
        combination evaluated at the left and right cell boundary.

        Returns
        -------
        left_wavelet_projection : ndarray
            Array containing the left multiwavelet projection.
        right_wavelet_projection : ndarray
            Array containing the right multiwavelet projection.

        """
        left_wavelet_projection = self._build_multiwavelet_matrix(
            z, -0.5*(z-1), True)
        right_wavelet_projection = self._build_multiwavelet_matrix(
            z, 0.5*(z+1), False)
        return left_wavelet_projection, right_wavelet_projection

    def _build_multiwavelet_matrix(self, first_param: float,
                                   second_param: float, is_left_matrix: bool) \
            -> ndarray:
        """Construct a multiwavelet matrix.

        Entry ``(i, j)`` is the integral over ``z`` from -1 to 1 of basis
        function ``i`` with ``x`` replaced by `first_param` times wavelet
        function ``j`` with ``z`` replaced by `second_param`. For the left
        matrix each entry is additionally multiplied by
        ``(-1)**(j + polynomial_degree + 1)``.

        Parameters
        ----------
        first_param : float
            First parameter.
        second_param : float
            Second parameter.
        is_left_matrix : bool
            Flag whether the left matrix is calculated.

        Returns
        -------
        ndarray
            Matrix containing the integral of products of a basis and a wavelet
            vector.

        """
        # Substitute once per function rather than once per matrix entry.
        basis_factors = [poly.subs(x, first_param) for poly in self.basis]
        wavelet_factors = [poly.subs(z, second_param)
                           for poly in self.wavelet]
        matrix = []
        for basis_factor in basis_factors:
            row = []
            for j, wavelet_factor in enumerate(wavelet_factors):
                entry = integrate(basis_factor * wavelet_factor, (z, -1, 1))
                if is_left_matrix:
                    entry = entry * (-1)**(j + self._polynomial_degree + 1)
                row.append(np.float64(entry))
            matrix.append(row)
        return np.array(matrix)

    def calculate_cell_average(self, projection: ndarray, stencil_length: int,
                               add_reconstructions: bool = True) -> ndarray:
        """Calculate cell averages for a given projection.

        Calculate the cell averages of all cells in a projection.
        If desired, reconstructions are calculated for the middle cell
        and added left and right to it, respectively.

        Notes
        -----
            To increase speed, this function uses a simplified calculation
            specific to the orthonormal Legendre polynomial basis: the
            cell averages are the degree-0 coefficients divided by
            ``sqrt(2)``.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        stencil_length : int
            Size of data array.
        add_reconstructions : bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        ndarray
            Matrix with a single row containing cell averages (and
            reconstructions) for given projection.

        """
        cell_averages = np.array([projection[0] / np.sqrt(2)])

        if not add_reconstructions:
            return cell_averages.astype(np.float64)

        middle_idx = stencil_length // 2
        left_reconstructions, right_reconstructions = \
            self._calculate_reconstructions(
                projection[:, middle_idx:middle_idx+1])
        # Join the pieces as flat sequences so the result is well-defined
        # for any number of cells on either side of the middle cell.
        extended = np.concatenate((
            cell_averages[0, :middle_idx],
            left_reconstructions,
            cell_averages[:, middle_idx],
            right_reconstructions,
            cell_averages[0, middle_idx+1:]))
        return extended.astype(np.float64)[np.newaxis, :]

    def _calculate_reconstructions(self, projection: ndarray) \
            -> Tuple[list, list]:
        """Calculate left and right reconstructions for a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        left_reconstruction : list
            List containing left reconstructions for given projection.
        right_reconstruction : list
            List containing right reconstructions for given projection.

        Notes
        -----
            To increase speed, this function uses a simplified calculation
            specific to the orthonormal Legendre polynomial basis: the
            coefficients are weighted by ``sqrt(degree + 0.5)``, with an
            alternating sign ``(-1)**degree`` for the left reconstruction.

        """
        degrees = range(self._polynomial_degree+1)
        cells = range(len(projection[0]))
        left_reconstructions = [
            sum(projection[degree][cell] * (-1)**degree
                * np.sqrt(degree + 0.5)
                for degree in degrees)
            for cell in cells]
        right_reconstructions = [
            sum(projection[degree][cell] * np.sqrt(degree + 0.5)
                for degree in degrees)
            for cell in cells]
        return left_reconstructions, right_reconstructions