# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""

from typing import Tuple
import numpy as np
from numpy import ndarray
from sympy import Symbol

from Quadrature import Quadrature
from Initial_Condition import InitialCondition


# Module-wide SymPy symbol. Basis functions handed to
# calculate_approximate_solution are evaluated by substituting numeric
# evaluation points for this symbol (see the ``subs(x, ...)`` call there).
x = Symbol('x')


def calculate_approximate_solution(
        projection: ndarray, points: ndarray, polynomial_degree: int,
        basis: ndarray) -> ndarray:
    """Calculate approximate solution.

    Evaluates ``sum_k projection[k][cell] * basis[k](point)`` for every cell
    and every evaluation point.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree. Row ``k`` holds the
        coefficient of ``basis[k]`` for every cell; only the first
        ``polynomial_degree + 1`` rows are used.
    points : ndarray
        List of evaluation points for mesh.
    polynomial_degree : int
        Polynomial degree.
    basis : ndarray
        Basis vector for calculation. Each entry is evaluated via
        ``subs(x, point)``, so it is expected to be a SymPy expression in the
        module-level symbol ``x``.

    Returns
    -------
    ndarray
        Float array of shape ``(1, num_cells * len(points))`` containing the
        approximate evaluation, ordered cell by cell (all points of the first
        cell, then all points of the second cell, ...).

    """
    num_degrees = polynomial_degree + 1

    # Evaluate every basis function at every point once. ``float`` forces
    # numeric evaluation: ``subs`` may leave unevaluated SymPy expressions
    # (e.g. ``0.5*sqrt(2)``), which would otherwise propagate into an
    # object-dtype result that NumPy ufuncs cannot handle.
    # The explicit reshape keeps a 2-D shape even when there are no points
    # or no degrees.
    basis_matrix = np.array(
        [[float(basis[degree].subs(x, point)) for point in points]
         for degree in range(num_degrees)],
        dtype=float).reshape(num_degrees, len(points))

    coefficients = np.asarray(projection, dtype=float)[:num_degrees]

    # approx[cell, point] = sum_k coefficients[k, cell] * basis_matrix[k, point]
    approx = coefficients.T @ basis_matrix

    # Row-major flattening reproduces the cell-by-cell ordering.
    return np.reshape(approx, (1, approx.size))


def calculate_exact_solution(
        mesh: ndarray, cell_len: float, wave_speed: float, final_time:
        float, interval_len: float, quadrature: Quadrature, init_cond:
        InitialCondition) -> Tuple[ndarray, ndarray]:
    """Calculate exact solution.

    Every quadrature point of every cell is shifted back by the distance the
    wave travelled (``wave_speed * final_time``), and the initial condition is
    evaluated there. Whole periods of length ``interval_len`` are added back
    so that the shifted point stays close to the original interval.

    Parameters
    ----------
    mesh : ndarray
        List of mesh evaluation points.
    cell_len : float
        Length of a cell in mesh.
    wave_speed : float
        Speed of wave in rightward direction.
    final_time : float
        Final time for which approximation is calculated.
    interval_len : float
        Length of the interval between left and right boundary.
    quadrature : Quadrature object
        Quadrature for evaluation.
    init_cond : InitialCondition object
        Initial condition for evaluation.

    Returns
    -------
    grid : ndarray
        Array of shape ``(1, num_cells * num_eval_points)`` containing the
        evaluation grid, ordered cell by cell. Shape ``(1, 0)`` for an empty
        mesh.
    exact : ndarray
        Array of the same shape containing the exact evaluation at ``grid``.

    """
    num_periods = np.floor(wave_speed * final_time / interval_len)

    # These quantities do not depend on the cell or point, so compute them
    # once. The per-point arithmetic below keeps the original evaluation
    # order (point - travelled + period_offset) to avoid rounding changes.
    travelled = wave_speed * final_time
    period_offset = num_periods * interval_len
    scaled_nodes = cell_len / 2 * quadrature.get_eval_points()

    grid = []
    exact = []
    for cell_center in mesh:
        eval_points = cell_center + scaled_nodes
        grid.append(eval_points)
        exact.append([init_cond.calculate(eval_point - travelled
                                          + period_offset)
                      for eval_point in eval_points])

    # Reshape by total size: indexing ``[0]`` would fail for an empty mesh.
    exact_array = np.array(exact)
    grid_array = np.array(grid)
    exact_array = np.reshape(exact_array, (1, exact_array.size))
    grid_array = np.reshape(grid_array, (1, grid_array.size))

    return grid_array, exact_array