# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""

from typing import Tuple
import numpy as np
from numpy import ndarray
from sympy import Symbol, lambdify

from .Mesh import Mesh
from .Quadrature import Quadrature
from .Initial_Condition import InitialCondition


# Free symbol in which the sympy basis functions are written; the
# approximate solution is obtained by evaluating those expressions at
# numeric values of ``x``.
x = Symbol('x')


def calculate_approximate_solution(
        projection: ndarray, points: ndarray, polynomial_degree: int,
        basis: ndarray) -> ndarray:
    """Calculate approximate solution.

    Evaluates the first ``polynomial_degree + 1`` basis functions at the
    given points and combines them linearly with the projection
    coefficients.

    Parameters
    ----------
    projection : ndarray
        Projection coefficients with ``polynomial_degree + 1`` rows; row
        ``k`` holds the coefficients of basis function ``k`` (presumably one
        column per cell -- confirm against callers).
    points : ndarray
        One-dimensional sequence of real evaluation points.
    polynomial_degree : int
        Polynomial degree; basis functions ``0 .. polynomial_degree`` are
        used.
    basis : ndarray
        Sequence of sympy expressions in the module-level symbol ``x``.

    Returns
    -------
    ndarray
        Numeric array of shape ``(1, projection.shape[1] * len(points))``
        holding the approximate values. Values are grouped by projection
        column, with all points for one column before the next.

    Notes
    -----
    Each basis function is compiled once with ``sympy.lambdify`` and
    evaluated on all points in a single vectorized call. Substituting every
    point symbolically is slow and produces an object array of sympy
    numbers, which can stay unevaluated (e.g. ``0.5*sqrt(3)``).

    """
    points = np.asarray(points, dtype=float)

    # One numeric row per basis function. A constant basis function
    # lambdifies to a scalar, so every row is broadcast to the points' shape
    # to obtain a regular (polynomial_degree+1, len(points)) matrix.
    basis_matrix = np.array(
        [np.broadcast_to(lambdify(x, basis[degree], 'numpy')(points),
                         points.shape)
         for degree in range(polynomial_degree+1)],
        dtype=float)

    approx = projection.T @ basis_matrix
    return np.reshape(approx, (1, approx.size))


def calculate_exact_solution(
        mesh: Mesh, wave_speed: float, final_time: float,
        quadrature: Quadrature,
        init_cond: InitialCondition) -> Tuple[ndarray, ndarray]:
    """Calculate exact solution.

    The initial condition is evaluated at every grid point shifted back by
    the transport distance ``wave_speed * final_time``. Whole multiples of
    ``mesh.interval_len`` (the ``np.floor`` of the number of periods
    travelled) are added back before evaluation.

    Parameters
    ----------
    mesh : Mesh
        Mesh for evaluation; ``non_ghost_cells``, ``num_cells``,
        ``cell_len`` and ``interval_len`` are read from it.
    wave_speed : float
        Speed of wave in rightward direction.
    final_time : float
        Final time for which approximation is calculated.
    quadrature : Quadrature object
        Quadrature whose ``nodes`` (``num_nodes`` of them) define the
        evaluation points inside each cell.
    init_cond : InitialCondition object
        Initial condition; ``calculate(mesh=..., x=...)`` is called once per
        grid point.

    Returns
    -------
    grid : ndarray
        Evaluation grid, reshaped to a single row ``(1, size)``.
    exact : ndarray
        Exact values at the grid points, reshaped to a single row
        ``(1, size)``.

    """
    distance = wave_speed * final_time
    # With positive interval_len, the net shift ``distance - period_offset``
    # lies in [0, interval_len), i.e. less than one full period.
    num_periods = np.floor(distance / mesh.interval_len)
    period_offset = num_periods * mesh.interval_len

    # Every cell entry is repeated once per quadrature node; the nodes,
    # scaled by half a cell length, are laid out once per cell.
    cell_part = np.repeat(mesh.non_ghost_cells, quadrature.num_nodes)
    node_part = mesh.cell_len/2 * np.tile(quadrature.nodes, mesh.num_cells)
    grid = cell_part + node_part

    values = []
    for point in grid:
        # Trace the point back by the transport distance, minus whole periods.
        shifted = point - distance + period_offset
        values.append(init_cond.calculate(mesh=mesh, x=shifted))
    exact = np.array(values)

    return grid.reshape(1, -1), exact.reshape(1, -1)