# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

TODO: Give option to select plotting color
TODO: Add documentation to plot_boxplot()
TODO: Adjust documentation for plot_classification_accuracy()

"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sympy import Symbol


# Module-level sympy symbols: ``x`` is substituted in the basis functions and
# ``z`` in the wavelet functions used by the helpers below.
x, z = Symbol('x'), Symbol('z')
# Apply seaborn's default plotting style to all subsequent matplotlib figures.
sns.set()


def plot_solution_and_approx(grid, exact, approx, color_exact, color_approx):
    """Plot approximate and exact solution against each other.

    Both curves are drawn into the figure named 'exact_and_approx'. Only the
    first row of ``grid``, ``exact`` and ``approx`` is plotted.

    Parameters
    ----------
    grid : np.array
        List of mesh evaluation points; only ``grid[0]`` is used.
    exact : np.array
        Array containing exact evaluation of a function; only ``exact[0]``
        is used.
    approx : np.array
        Array containing approximate evaluation of a function; only
        ``approx[0]`` is used.
    color_exact : str
        Format string passed to ``plt.plot`` for the exact solution.
    color_approx : str
        Format string passed to ``plt.plot`` for the approximate solution.

    """
    plt.figure('exact_and_approx')
    plt.plot(grid[0], exact[0], color_exact)
    plt.plot(grid[0], approx[0], color_approx)
    plt.xlabel('x')
    plt.ylabel('u(x,t)')
    plt.title('Solution and Approximation')


def plot_semilog_error(grid, pointwise_error):
    """Plot the pointwise error on a logarithmic y-axis.

    The first row of ``grid`` and ``pointwise_error`` is drawn into the
    figure named 'semilog_error'.

    Parameters
    ----------
    grid : np.array
        Mesh evaluation points; only ``grid[0]`` is used.
    pointwise_error : np.array
        Pointwise difference between exact and approximate solution; only
        ``pointwise_error[0]`` is used.

    """
    axes = plt.figure('semilog_error').gca()
    axes.semilogy(grid[0], pointwise_error[0])
    axes.set_xlabel('x')
    axes.set_ylabel('|u(x,t)-uh(x,t)|')
    axes.set_title('Semilog Error plotted at Evaluation points')


def plot_error(grid, exact, approx):
    """Plot the signed error between exact and approximate solution.

    Draws ``exact[0] - approx[0]`` over ``grid[0]`` into the figure named
    'error'.

    Parameters
    ----------
    grid : np.array
        Mesh evaluation points; only ``grid[0]`` is used.
    exact : np.array
        Exact evaluation of a function; only ``exact[0]`` is used.
    approx : np.array
        Approximate evaluation of a function; only ``approx[0]`` is used.

    """
    plt.figure('error')
    difference = exact[0] - approx[0]
    ax = plt.gca()
    ax.plot(grid[0], difference)
    ax.set_xlabel('X')
    ax.set_ylabel('u(x,t)-uh(x,t)')
    ax.set_title('Errors')


def plot_shock_tube(num_grid_cells, troubled_cell_history, time_history):
    """Plot shock tube.

    Plots detected troubled cells over time to depict the evolution of shocks
    as shock tubes. Every troubled cell of a time step becomes one black dot
    at (cell, time) in the figure named 'shock_tube'.

    Parameters
    ----------
    num_grid_cells : int
        Number of cells in the mesh. Usually a power of 2.
    troubled_cell_history : list
        List of detected troubled cells for each time step.
    time_history : list
        List of value of each time step. Paired entry-wise with
        ``troubled_cell_history``; extra entries in the longer list are
        ignored.

    """
    plt.figure('shock_tube')

    # Gather all (cell, time) pairs first and draw them with a single call.
    # One plt.plot per troubled cell creates one artist per point, which is
    # slow for long histories.
    cells = []
    times = []
    for time, current_cells in zip(time_history, troubled_cell_history):
        for cell in current_cells:
            cells.append(cell)
            times.append(time)
    if cells:
        plt.plot(cells, times, 'k.')

    # NOTE(review): the x-axis only spans half of num_grid_cells. Confirm this
    # matches how troubled-cell indices are numbered by the caller.
    plt.xlim((0, num_grid_cells // 2))
    plt.xlabel('Cell')
    plt.ylabel('Time')
    plt.title('Shock Tubes')


def plot_details(fine_projection, fine_mesh, coarse_projection, basis, wavelet, multiwavelet_coeffs,
                 num_coarse_grid_cells, polynomial_degree):
    """Plot details of projection to coarser mesh.

    Draws the difference between the fine and the coarse projection together
    with the contribution of the multiwavelet coefficients into the figure
    named 'coeff_details'. Each coarse cell contributes two entries, obtained
    at the points -0.5 and 0.5, so both curves can be plotted over
    ``fine_mesh``.

    Parameters
    ----------
    fine_projection, coarse_projection : np.array
        Matrix of projection for each polynomial degree.
    fine_mesh : np.array
        List of evaluation points for fine mesh.
    basis : np.array
        Basis vector for calculation; entries are sympy expressions in ``x``.
    wavelet : np.array
        Wavelet vector for calculation; entries are sympy expressions in ``z``.
    multiwavelet_coeffs : np.array
        Matrix of multiwavelet coefficients.
    num_coarse_grid_cells : int
        Number of cells in the coarse mesh (half the cells of the fine mesh).
        Usually a power of 2.
    polynomial_degree : int
        Polynomial degree.

    """
    degrees = range(polynomial_degree + 1)

    # Evaluate each symbolic basis/wavelet function once per degree instead of
    # once per cell. Sympy substitution is comparatively expensive and its
    # result does not depend on the cell.
    basis_at_halves = [[basis[degree].subs(x, value) for value in [-0.5, 0.5]]
                       for degree in degrees]
    wavelet_at_half = [wavelet[degree].subs(z, 0.5) for degree in degrees]

    averaged_projection = [[coarse_projection[degree][cell] * basis_value
                            for cell in range(num_coarse_grid_cells)
                            for basis_value in basis_at_halves[degree]]
                           for degree in degrees]

    wavelet_projection = [[multiwavelet_coeffs[degree][cell] * wavelet_at_half[degree] * sign
                           for cell in range(num_coarse_grid_cells)
                           for sign in [(-1) ** (polynomial_degree + degree + 1), 1]]
                          for degree in degrees]

    projected_coarse = np.sum(averaged_projection, axis=0)
    projected_fine = np.sum([fine_projection[degree] * basis[degree].subs(x, 0)
                             for degree in degrees], axis=0)
    projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)

    plt.figure('coeff_details')
    plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
    plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
    plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
    plt.xlabel('X')
    plt.ylabel('Detail Coefficients')
    plt.title('Wavelet Coefficients')


def calculate_approximate_solution(projection, points, polynomial_degree, basis):
    """Calculate approximate solution.

    For every cell, evaluates
    ``sum over degree of projection[degree][cell] * basis[degree](point)``
    at each entry of ``points``. The per-cell results are then concatenated
    in cell order.

    Parameters
    ----------
    projection : np.array
        Matrix of projection for each polynomial degree;
        ``projection[degree][cell]`` is the coefficient of ``basis[degree]``
        in ``cell``.
    points : np.array
        List of evaluation points, shared by all cells.
    polynomial_degree : int
        Polynomial degree.
    basis : np.array
        Basis vector for calculation; entries are evaluated via
        ``.subs(x, point)``.

    Returns
    -------
    np.array
        Array of shape (1, num_cells * len(points)) containing approximate
        evaluation of a function.

    """
    degrees = range(polynomial_degree + 1)
    num_points = len(points)
    num_cells = len(projection[0])

    # basis_matrix[degree][i] holds basis[degree] evaluated at points[i];
    # computed once and reused for every cell.
    basis_matrix = [[basis[degree].subs(x, point) for point in points]
                    for degree in degrees]

    approx = [[sum(projection[degree][cell] * basis_matrix[degree][point_idx]
                   for degree in degrees)
               for point_idx in range(num_points)]
              for cell in range(num_cells)]

    return np.reshape(np.array(approx), (1, len(approx) * num_points))


def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_len, quadrature,
                             init_cond):
    """Calculate exact solution.

    Each evaluation point ``p`` is obtained as
    ``mesh[cell] + cell_len / 2 * quadrature.get_eval_points()``.
    The exact value at ``p`` is
    ``init_cond.calculate(p - wave_speed * final_time + num_periods * interval_len)``,
    where ``num_periods`` is the number of whole intervals travelled by the
    wave up to ``final_time``.

    Parameters
    ----------
    mesh : array
        List of mesh evaluation points; the quadrature points of each cell
        are placed relative to one entry.
    cell_len : float
        Length of a cell in mesh.
    wave_speed : float
        Speed of wave in rightward direction.
    final_time : float
        Final time for which approximation is calculated.
    interval_len : float
        Length of the interval between left and right boundary.
    quadrature : Quadrature object
        Quadrature for evaluation; must provide ``get_eval_points()``.
    init_cond : InitialCondition object
        Initial condition for evaluation; must provide ``calculate(value)``.

    Returns
    -------
    grid : np.array
        Array of shape (1, len(mesh) * num_quadrature_points) containing the
        evaluation points.
    exact : np.array
        Array of the same shape containing exact evaluation of a function.

    """
    num_periods = np.floor(wave_speed * final_time / interval_len)
    travel_distance = wave_speed * final_time
    period_offset = num_periods * interval_len

    # The quadrature points do not depend on the cell: fetch and scale them once.
    scaled_quad_points = cell_len / 2 * quadrature.get_eval_points()

    grid = []
    exact = []
    for mesh_point in mesh:
        eval_points = mesh_point + scaled_quad_points
        grid.append(eval_points)
        exact.append([init_cond.calculate(point - travel_distance + period_offset)
                      for point in eval_points])

    # (1, -1) lets NumPy infer the length, which also works for an empty mesh.
    exact = np.reshape(np.array(exact), (1, -1))
    grid = np.reshape(np.array(grid), (1, -1))

    return grid, exact


def plot_classification_accuracy(evaluation_dict, colors):
    """Plot classification accuracy.

    Plots the evaluation measures (e.g. accuracy, precision, recall) of every
    model as grouped bars in the figure named 'classification_accuracy'.

    Parameters
    ----------
    evaluation_dict : dict
        Nested dictionary mapping each measure name to a dictionary that maps
        model names to the scalar score of that model. The y-axis is fixed to
        roughly [0, 1]. All measures are expected to cover the same models.
    colors : dict
        Dictionary mapping each measure name in ``evaluation_dict`` to the
        color of its bars. It may contain additional measures.

    """
    # The first measure defines which models are shown and in which order.
    model_names = list(next(iter(evaluation_dict.values())).keys())
    font_size = 16 - (len(max(model_names, key=len)) // 3)
    pos = np.arange(len(model_names))
    width = 1 / (3 * len(model_names))
    fig = plt.figure('classification_accuracy')
    # Manual axes placement, presumably to leave room for the rotated labels.
    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
    step_len = 1
    # NOTE(review): the starting offset depends on the number of models, not on
    # the number of measures. Bar groups are only roughly centered on the ticks
    # when both counts are similar. Confirm this is intended.
    adjustment = -(len(model_names) // 2) * step_len
    for measure, model_scores in evaluation_dict.items():
        # Index by model_names so each bar lines up with its tick label even if
        # a measure's inner dict was filled in a different order.
        model_eval = [model_scores[model] for model in model_names]
        ax.bar(pos + adjustment * width, model_eval, width, label=measure, color=colors[measure])
        adjustment += step_len
    ax.set_xticks(pos)
    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
    ax.set_ylabel('Classification (%)')
    ax.set_ylim(-0.02, 1.02)
    ax.set_title('Classification Evaluation (Barplot)')
    ax.legend(loc='upper right')


def plot_boxplot(evaluation_dict, colors):
    """Plot classification evaluation as grouped boxplots.

    For every measure, draws one boxplot per model, with mean lines shown.
    All boxplots go into the figure named 'boxplot_accuracy', and every box
    is filled with the color of its measure.

    Parameters
    ----------
    evaluation_dict : dict
        Nested dictionary mapping each measure name to a dictionary that maps
        model names to a sequence of scores for that model (presumably one
        score per evaluation run). All measures are expected to cover the same
        models.
    colors : dict
        Dictionary mapping each measure name in ``evaluation_dict`` to the
        face color of its boxes. It may contain additional measures.

    """
    # The first measure defines which models are shown and in which order.
    model_names = list(next(iter(evaluation_dict.values())).keys())
    font_size = 16 - (len(max(model_names, key=len)) // 3)
    fig = plt.figure('boxplot_accuracy')
    # Manual axes placement, presumably to leave room for the rotated labels.
    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
    step_len = 1.5
    boxplots = []
    # NOTE(review): the starting offset depends on the number of models, not on
    # the number of measures. Confirm this is intended.
    adjustment = -(len(model_names) // 2) * step_len
    pos = np.arange(len(model_names))
    width = 1 / (5 * len(model_names))
    for measure, model_scores in evaluation_dict.items():
        # Index by model_names so each box lines up with its tick label even if
        # a measure's inner dict was filled in a different order.
        model_eval = [model_scores[model] for model in model_names]
        boxplot = ax.boxplot(model_eval, positions=pos + adjustment * width, widths=width,
                             meanline=True, showmeans=True, patch_artist=True)
        for patch in boxplot['boxes']:
            patch.set(facecolor=colors[measure])
        boxplots.append(boxplot)
        adjustment += step_len

    # Overrides the ticks that boxplot() placed at the shifted positions.
    ax.set_xticks(pos)
    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
    ax.set_ylim(-0.02, 1.02)
    ax.set_ylabel('Classification (%)')
    ax.set_title('Classification Evaluation (Boxplot)')
    # One legend entry per measure, represented by its first box.
    ax.legend([bp['boxes'][0] for bp in boxplots], list(evaluation_dict), loc='upper right')