# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

TODO: Contemplate whether calculating projections during initialization can
    save time

"""
import numpy as np
from sympy import Symbol, integrate

from projection_utils import calculate_approximate_solution

# SymPy symbols shared by the basis classes below: `x` is the variable the
# basis polynomials are built in, `z` the variable of the wavelets and the
# integration variable of the projection integrals.
x = Symbol('x')
z = Symbol('z')


class Basis:
    """Base class for a polynomial basis with an associated wavelet.

    Subclasses override the ``_build_*`` hooks to supply a concrete basis.
    The defaults here produce empty basis/wavelet vectors and no inverse
    mass matrix.

    Attributes
    ----------
    basis : list
        Basis functions, built by ``_build_basis_vector`` from the
        module-level symbol ``x`` (see ``get_basis_vector()``).
    wavelet : list
        Wavelet functions, built by ``_build_wavelet_vector`` from the
        module-level symbol ``z`` (see ``get_wavelet_vector()``).
    inv_mass : ndarray
        Inverse mass matrix (see ``get_inverse_mass_matrix()``).

    Methods
    -------
    get_basis_vector()
        Returns basis vector.
    get_wavelet_vector()
        Returns wavelet vector.
    get_inverse_mass_matrix()
        Returns inverse mass matrix.
    get_basis_projections()
        Returns basis projections.
    get_multiwavelet_projections()
        Returns multiwavelet projections.
    calculate_cell_average(projection, stencil_length, add_reconstructions)
        Calculate cell averages for a given projection.

    """
    def __init__(self, polynomial_degree):
        """Initializes Basis.

        Parameters
        ----------
        polynomial_degree : int
            Polynomial degree.

        """
        self._polynomial_degree = polynomial_degree
        self._basis = self._build_basis_vector(x)
        self._wavelet = self._build_wavelet_vector(z)
        self._inv_mass = self._build_inverse_mass_matrix()

    def get_basis_vector(self):
        """Returns basis vector."""
        return self._basis

    def _build_basis_vector(self, eval_point):
        """Constructs basis vector.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Evaluation point.

        Returns
        -------
        list
            Vector containing basis evaluated at evaluation point. Empty in
            the base class; subclasses override this.

        """
        return []

    def get_wavelet_vector(self):
        """Returns wavelet vector."""
        return self._wavelet

    def _build_wavelet_vector(self, eval_point):
        """Constructs wavelet vector.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Evaluation point.

        Returns
        -------
        list
            Vector containing wavelet evaluated at evaluation point. Empty in
            the base class; subclasses override this.

        """
        return []

    def get_inverse_mass_matrix(self):
        """Returns inverse mass matrix."""
        return self._inv_mass

    def _build_inverse_mass_matrix(self):
        """Constructs inverse mass matrix.

        Returns
        -------
        ndarray or None
            Inverse mass matrix. The base class has no basis and returns
            None; subclasses override this.

        """
        return None

    def get_basis_projections(self):
        """Returns basis projections (None in the base class)."""
        return None

    def get_multiwavelet_projections(self):
        """Returns multiwavelet projections (None in the base class)."""
        return None

    def calculate_cell_average(self, projection, stencil_length,
                               add_reconstructions=True):
        """Calculate cell averages for a given projection.

        Calculate the cell averages of all cells in a projection.
        If desired, reconstructions are calculated for the middle cell
        and added left and right to it, respectively.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        stencil_length : int
            Size of data array.
        add_reconstructions: bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        ndarray
            Matrix containing cell averages (and reconstructions) for given
            projection.

        """
        # Cell averages are the degree-0 approximation evaluated at the cell
        # centre (0).
        cell_averages = calculate_approximate_solution(
            projection, np.array([0]), 0, self._basis)

        if not add_reconstructions:
            return np.array(list(map(np.float64, cell_averages)))

        middle_idx = stencil_length // 2
        left_reconstructions, right_reconstructions = \
            self._calculate_reconstructions(
                projection[:, middle_idx:middle_idx+1])
        return self._combine_with_reconstructions(
            cell_averages, left_reconstructions, right_reconstructions,
            middle_idx)

    @staticmethod
    def _combine_with_reconstructions(cell_averages, left_reconstructions,
                                      right_reconstructions, middle_idx):
        """Insert middle-cell reconstructions around the middle cell average.

        Parameters
        ----------
        cell_averages : array_like
            Rows of cell averages, one entry per cell of the stencil.
        left_reconstructions, right_reconstructions : array_like
            Per-row left/right reconstruction of the middle cell. Each entry
            may be a scalar or a one-element array.
        middle_idx : int
            Index of the middle cell within a row.

        Returns
        -------
        ndarray
            float64 array whose rows are ``[averages before middle, left
            reconstruction, middle average, right reconstruction, averages
            after middle]``.

        """
        # np.hstack flattens the mix of slices and scalars into one flat row.
        # Converting the tuple directly with np.float64 fails for stencils
        # wider than three cells, because the tuple is then ragged
        # (multi-element slices next to scalars).
        return np.array([
            np.hstack((averages[:middle_idx], left, averages[middle_idx],
                       right, averages[middle_idx+1:])).astype(np.float64)
            for averages, left, right in zip(
                cell_averages, left_reconstructions, right_reconstructions)])

    def _calculate_reconstructions(self, projection):
        """Calculate left and right reconstructions for a given projection.

        Evaluates the full-degree approximation at the left (-1) and right
        (1) boundary of each cell.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        tuple
            ``(left_reconstructions, right_reconstructions)`` as returned by
            ``calculate_approximate_solution``.

        """
        left_reconstructions = calculate_approximate_solution(
            projection, np.array([-1]), self._polynomial_degree, self._basis)
        right_reconstructions = calculate_approximate_solution(
            projection, np.array([1]), self._polynomial_degree, self._basis)
        return left_reconstructions, right_reconstructions


class Legendre(Basis):
    """Basis of (non-normalized) Legendre polynomials.

    Only the basis vector is specialized here; the wavelet vector, inverse
    mass matrix and projections keep the defaults inherited from Basis.
    """
    def _build_basis_vector(self, eval_point):
        """Evaluate the Legendre basis at a point.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Point (or symbol) at which the polynomials are evaluated.

        Returns
        -------
        list
            Legendre polynomials P_0, ..., P_p at `eval_point`, where p is
            the polynomial degree.

        """
        legendre_values = self._calculate_legendre_vector(eval_point)
        return legendre_values

    def _calculate_legendre_vector(self, eval_point):
        """Evaluate Legendre polynomials up to the polynomial degree.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Point (or symbol) at which the polynomials are evaluated.

        Returns
        -------
        list
            Entry n holds P_n(eval_point) for n = 0, ..., polynomial degree
            (empty if the degree is negative).

        """
        values = []
        for n in range(self._polynomial_degree + 1):
            if n >= 2:
                # Bonnet's recursion: n P_n = (2n-1) x P_{n-1} - (n-1) P_{n-2}
                values.append((2.0*n - 1)/n * eval_point * values[-1]
                              - (n-1)/n * values[-2])
            elif n == 1:
                values.append(eval_point)
            else:
                # Adding 0*eval_point gives P_0 the same kind of value as the
                # other entries (e.g. an array for array input, a SymPy
                # object for a Symbol).
                values.append(1.0 + 0*eval_point)
        return values


class OrthonormalLegendre(Legendre):
    """Class for orthonormal Legendre basis.

    Basis functions are the Legendre polynomials scaled by sqrt(n + 1/2),
    which makes them orthonormal on [-1, 1]; the wavelets are hardcoded
    Alpert multiwavelets for polynomial degrees 0 to 4.

    Methods
    -------
    get_basis_projections()
        Returns basis projections.
    get_multiwavelet_projections()
        Returns multiwavelet projections.

    """
    def _build_basis_vector(self, eval_point):
        """Constructs basis vector.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Evaluation point.

        Returns
        -------
        list
            Orthonormal Legendre polynomials, i.e. P_n(eval_point) scaled by
            sqrt(n + 0.5), for n = 0, ..., polynomial degree.

        """
        leg_vector = self._calculate_legendre_vector(eval_point)
        return [value * np.sqrt(degree+0.5)
                for degree, value in enumerate(leg_vector)]

    def _build_wavelet_vector(self, eval_point):
        """Constructs wavelet vector.

        Parameters
        ----------
        eval_point : float or sympy.Symbol
            Evaluation point.

        Returns
        -------
        list
            Vector containing wavelet evaluated at evaluation point.

        Raises
        ------
        ValueError
            If the polynomial degree is not between 0 and 4.

        Notes
        -----
        Hardcoded version only for now.

        """
        degree = self._polynomial_degree

        if degree == 0:
            return [np.sqrt(0.5) + eval_point*0]
        if degree == 1:
            return [np.sqrt(1.5) * (-1 + 2*eval_point),
                    np.sqrt(0.5) * (-2 + 3*eval_point)]
        if degree == 2:
            return [1/3 * np.sqrt(0.5) *
                    (1 - 24*eval_point + 30*(eval_point**2)),
                    1/2 * np.sqrt(1.5) *
                    (3 - 16*eval_point + 15*(eval_point**2)),
                    1/3 * np.sqrt(2.5) *
                    (4 - 15*eval_point + 12*(eval_point**2))]
        if degree == 3:
            return [np.sqrt(15/34) *
                    (1 + 4*eval_point - 30*(eval_point**2)
                     + 28*(eval_point**3)),
                    np.sqrt(1/42) *
                    (-4 + 105*eval_point - 300*(eval_point**2)
                     + 210*(eval_point**3)),
                    1/2 * np.sqrt(35/34) *
                    (-5 + 48*eval_point - 105*(eval_point**2)
                     + 64*(eval_point**3)),
                    1/2 * np.sqrt(5/42) *
                    (-16 + 105*eval_point - 192*(eval_point**2)
                     + 105*(eval_point**3))]
        if degree == 4:
            return [np.sqrt(1/186) *
                    (1 + 30*eval_point + 210*(eval_point**2)
                     - 840*(eval_point**3) + 630*(eval_point**4)),
                    0.5 * np.sqrt(1/38) *
                    (-5 - 144*eval_point + 1155*(eval_point**2)
                     - 2240*(eval_point**3) + 1260*(eval_point**4)),
                    np.sqrt(35/14694) *
                    (22 - 735*eval_point + 3504*(eval_point**2)
                     - 5460*(eval_point**3) + 2700*(eval_point**4)),
                    1/8 * np.sqrt(21/38) *
                    (35 - 512*eval_point + 1890*(eval_point**2)
                     - 2560*(eval_point**3) + 1155*(eval_point**4)),
                    0.5 * np.sqrt(7/158) *
                    (32 - 315*eval_point + 960*(eval_point**2)
                     - 1155*(eval_point**3) + 480*(eval_point**4))]

        # Implicit string concatenation; a backslash-newline inside the
        # literal would embed the next line's indentation in the message.
        raise ValueError('Invalid value: Alpert\'s wavelet is only available '
                         'up to degree 4 for this application')

    def _build_inverse_mass_matrix(self):
        """Constructs inverse mass matrix.

        The basis is orthonormal, so the mass matrix and its inverse are
        both the identity.

        Returns
        -------
        ndarray
            Identity matrix of size polynomial degree + 1.

        """
        return np.identity(self._polynomial_degree + 1)

    def get_basis_projections(self):
        """Returns basis projection.

        Returns
        -------
        tuple of list
            ``(left, right)`` matrices built from the integrals of the
            product of two basis vectors for each degree combination; the
            second basis function is evaluated at 0.5*(z-1) for the left
            and 0.5*(z+1) for the right matrix.

        """
        basis_projection_left = self._build_basis_matrix(z, 0.5 * (z - 1))
        basis_projection_right = self._build_basis_matrix(z, 0.5 * (z + 1))
        return basis_projection_left, basis_projection_right

    def _build_basis_matrix(self, first_param, second_param):
        """Constructs a basis matrix.

        Entry (i, j) is the integral over z in [-1, 1] of basis function i
        evaluated at `first_param` times basis function j evaluated at
        `second_param`.

        Parameters
        ----------
        first_param : sympy expression
            Expression in ``z`` substituted for ``x`` in the first factor.
        second_param : sympy expression
            Expression in ``z`` substituted for ``x`` in the second factor.

        Returns
        -------
        list of list of float
            Matrix containing the integral of basis products.

        """
        matrix = []
        for i in range(self._polynomial_degree + 1):
            row = []
            for j in range(self._polynomial_degree + 1):
                entry = integrate(self._basis[i].subs(x, first_param)
                                  * self._basis[j].subs(x, second_param),
                                  (z, -1, 1))
                row.append(np.float64(entry))
            matrix.append(row)
        return matrix

    def get_multiwavelet_projections(self):
        """Returns wavelet projection.

        Returns
        -------
        tuple of list
            ``(left, right)`` matrices built from the integrals of the
            product of a basis vector and a wavelet vector for each degree
            combination.

        """
        wavelet_projection_left = self._build_multiwavelet_matrix(
            z, -0.5*(z-1), True)
        wavelet_projection_right = self._build_multiwavelet_matrix(
            z, 0.5*(z+1), False)
        return wavelet_projection_left, wavelet_projection_right

    def _build_multiwavelet_matrix(self, first_param, second_param,
                                   is_left_matrix):
        """Constructs a multiwavelet matrix.

        Entry (i, j) is the integral over z in [-1, 1] of basis function i
        evaluated at `first_param` times wavelet j evaluated at
        `second_param`. For the left matrix each entry is additionally
        multiplied by (-1)**(j + polynomial_degree + 1).

        Parameters
        ----------
        first_param : sympy expression
            Expression in ``z`` substituted for ``x`` in the basis factor.
        second_param : sympy expression
            Expression in ``z`` substituted for ``z`` in the wavelet factor.
        is_left_matrix : bool
            Flag whether the left matrix is calculated.

        Returns
        -------
        list of list of float
            Matrix containing the integral of products of a basis and a
            wavelet vector.

        """
        matrix = []
        for i in range(self._polynomial_degree+1):
            row = []
            for j in range(self._polynomial_degree+1):
                entry = integrate(self._basis[i].subs(x, first_param)
                                  * self._wavelet[j].subs(z, second_param),
                                  (z, -1, 1))
                if is_left_matrix:
                    entry = entry * (-1)**(j + self._polynomial_degree + 1)
                row.append(np.float64(entry))
            matrix.append(row)
        return matrix

    def calculate_cell_average(self, projection, stencil_length,
                               add_reconstructions=True):
        """Calculate cell averages for a given projection.

        Calculate the cell averages of all cells in a projection.
        If desired, reconstructions are calculated for the middle cell
        and added left and right to it, respectively.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.
        stencil_length : int
            Size of data array.
        add_reconstructions: bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        ndarray
            Matrix containing cell averages (and reconstructions) for given
            projection.

        Notes
        -----
        To increase speed, this function uses a simplified calculation
        specific to the orthonormal Legendre polynomial basis.

        """
        # Only the degree-0 coefficient contributes to the mean over [-1, 1];
        # the constant basis function is sqrt(0.5) = 1/sqrt(2).
        cell_averages = np.array([projection[0] / np.sqrt(2)])

        if not add_reconstructions:
            return np.array(list(map(np.float64, cell_averages)))

        middle_idx = stencil_length // 2
        left_reconstructions, right_reconstructions = \
            self._calculate_reconstructions(
                projection[:, middle_idx:middle_idx+1])
        # np.hstack flattens the mix of slices and scalars into one flat row.
        # Converting the tuple directly with np.float64 fails for stencils
        # wider than three cells, because the tuple is then ragged
        # (multi-element slices next to scalars).
        return np.array([
            np.hstack((averages[:middle_idx], left, averages[middle_idx],
                       right, averages[middle_idx+1:])).astype(np.float64)
            for averages, left, right in zip(
                cell_averages, left_reconstructions, right_reconstructions)])

    def _calculate_reconstructions(self, projection):
        """Calculate left and right reconstructions for a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection coefficients, indexed as
            ``projection[degree][cell]``.

        Returns
        -------
        left_reconstructions : list
            Value of the projected polynomial at the left cell boundary (-1)
            for each cell.
        right_reconstructions : list
            Value of the projected polynomial at the right cell boundary (1)
            for each cell.

        Notes
        -----
        To increase speed, this function uses a simplified calculation
        specific to the orthonormal Legendre polynomial basis: the n-th
        basis function equals sqrt(n + 0.5) * (-1)**n at -1 and
        sqrt(n + 0.5) at 1.

        """
        left_reconstructions = [
            sum(projection[degree][cell] * (-1)**degree
                * np.sqrt(degree + 0.5)
                for degree in range(self._polynomial_degree+1))
            for cell in range(len(projection[0]))]
        right_reconstructions = [
            sum(projection[degree][cell] * np.sqrt(degree + 0.5)
                for degree in range(self._polynomial_degree+1))
            for cell in range(len(projection[0]))]
        return left_reconstructions, right_reconstructions