# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
import functools

import numpy as np
from numpy import ndarray


def periodic_boundary(projection: ndarray, config: dict) -> ndarray:
    """Enforce boundary condition.

    Adjust ghost cells to ensure periodic boundary condition.

    The left ghost cells are overwritten with copies of the last
    ``num_ghost_cells`` interior columns, and the right ghost cells with
    copies of the first ``num_ghost_cells`` interior columns. The array is
    modified in place and also returned. With zero ghost cells the array is
    left unchanged.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree (2-D; columns are
        cells including ghost cells). Modified in place.
    config : dict
        Configuration dictionary. Must contain ``'num_ghost_cells'``.

    Returns
    -------
    projection : ndarray
        Matrix of projection for each polynomial degree.

    """
    num_ghost_cells = config['num_ghost_cells']
    num_columns = projection.shape[1]
    # Use positive indices: with negative ones, ``-0:`` would select every
    # column when num_ghost_cells is 0 and the assignment would fail.
    right_start = num_columns - num_ghost_cells

    projection[:, :num_ghost_cells] = \
        projection[:, right_start - num_ghost_cells:right_start]
    projection[:, right_start:] = \
        projection[:, num_ghost_cells:2 * num_ghost_cells]
    return projection


def dirichlet_boundary(projection: ndarray, config: dict) -> ndarray:
    """Enforce boundary condition.

    Adjust ghost cells to ensure Dirichlet boundary condition.

    Every left ghost cell is set to the same column vector of length
    ``polynomial_degree + 1``, whose only non-zero entry is the degree-0
    coefficient; likewise for the right ghost cells. The array is modified
    in place and also returned. With zero ghost cells the array is left
    unchanged.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree (2-D, with
        ``polynomial_degree + 1`` rows; columns are cells including ghost
        cells). Modified in place.
    config : dict
        Configuration dictionary. Must contain ``'polynomial_degree'``,
        ``'final_time'``, ``'num_ghost_cells'``, ``'left_factor'`` and
        ``'right_factor'``.

    Returns
    -------
    projection : ndarray
        Matrix of projection for each polynomial degree.

    """
    polynomial_degree = config['polynomial_degree']
    final_time = config['final_time']
    num_ghost_cells = config['num_ghost_cells']
    left_factor = config['left_factor']
    right_factor = config['right_factor']

    projection_values_left = np.zeros(polynomial_degree + 1)
    projection_values_right = np.zeros(polynomial_degree + 1)
    # NOTE(review): the meaning of the ``final_time > 1`` switch is not
    # visible here; kept exactly as originally implemented.
    if final_time > 1:
        projection_values_left[0] = -np.sqrt(2) / final_time
        projection_values_right[0] = np.sqrt(2) / final_time
    else:
        projection_values_left[0] = np.sqrt(2) * left_factor
        projection_values_right[0] = np.sqrt(2) * right_factor

    # Positive indices avoid the ``-0:`` pitfall (which would select every
    # column when num_ghost_cells is 0); the column vectors broadcast over
    # all ghost cells on each side.
    right_start = projection.shape[1] - num_ghost_cells
    projection[:, :num_ghost_cells] = projection_values_left[:, np.newaxis]
    projection[:, right_start:] = projection_values_right[:, np.newaxis]
    return projection


def enforce_boundary(num_ghost_cells=None):
    """Build a decorator that applies the mesh's boundary condition.

    The decorated method must be defined on an object exposing
    ``self._mesh.boundary_config`` (a dict with at least a ``'type'`` key).
    Its return value is treated as a projection array whose ghost cells are
    then filled by ``periodic_boundary`` or ``dirichlet_boundary``, depending
    on the configured boundary type.

    Parameters
    ----------
    num_ghost_cells : int, optional
        If given, overrides ``'num_ghost_cells'`` from the mesh's boundary
        configuration for both periodic and Dirichlet boundaries. The mesh's
        own configuration is never modified.

    Returns
    -------
    callable
        Decorator wrapping such a method.

    Raises
    ------
    NotImplementedError
        (From the wrapped call) if the boundary type is neither
        ``'periodic'`` nor ``'dirichlet'``.

    """
    def _enforce_boundary(func):
        @functools.wraps(func)
        def boundary(self, *args, **kwargs):
            projection = func(self, *args, **kwargs)
            boundary_type = self._mesh.boundary_config['type']
            if boundary_type not in ('periodic', 'dirichlet'):
                raise NotImplementedError('Not implemented!')
            # Copy so the override does not leak into the mesh's config.
            config = self._mesh.boundary_config.copy()
            if num_ghost_cells is not None:
                config['num_ghost_cells'] = num_ghost_cells
            if boundary_type == 'periodic':
                return periodic_boundary(projection=projection, config=config)
            return dirichlet_boundary(projection=projection, config=config)
        return boundary
    return _enforce_boundary


def enforce_boxplot_boundaries(func):
    """Decorate a method so its boxplot bounds respect the boundary condition.

    The decorated method must be defined on an object exposing
    ``self._mesh.boundary_config``. Its result is converted with
    ``np.array``, its ghost cells are filled by ``periodic_boundary`` using
    the mesh's boundary configuration, and the array is returned as a tuple
    over its first axis.

    Parameters
    ----------
    func : callable
        Method returning array-like bounds.

    Returns
    -------
    callable
        Wrapped method.

    Raises
    ------
    NotImplementedError
        (From the wrapped call) if the boundary type is not ``'periodic'``.

    """
    @functools.wraps(func)
    def boxplot_boundary(self, *args, **kwargs):
        bounds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary_config['type'] != 'periodic':
            raise NotImplementedError('Not implemented!')
        return tuple(periodic_boundary(projection=bounds,
                                       config=self._mesh.boundary_config))
    return boxplot_boundary


def enforce_fold_boundaries(func):
    """Decorate a method so its fold indices wrap around a periodic mesh.

    The decorated method must be defined on an object exposing
    ``self._mesh.boundary_config`` and ``self._mesh.num_cells``. Its result is
    converted with ``np.array`` and reduced modulo ``num_cells``; NumPy's
    ``%`` takes the sign of the divisor, so negative indices wrap to the end.

    Parameters
    ----------
    func : callable
        Method returning array-like integer fold indices.

    Returns
    -------
    callable
        Wrapped method.

    Raises
    ------
    NotImplementedError
        (From the wrapped call) if the boundary type is not ``'periodic'``.

    """
    @functools.wraps(func)
    def fold_boundaries(self, *args, **kwargs):
        folds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary_config['type'] != 'periodic':
            raise NotImplementedError('Not implemented!')
        return folds % self._mesh.num_cells
    return fold_boundaries