# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
from functools import wraps

import numpy as np
from numpy import ndarray


def periodic_boundary(projection: ndarray, num_ghost_cells: int) -> ndarray:
    """Enforce boundary condition.

    Adjust ghost cells to ensure periodic boundary condition.

    The first and last `num_ghost_cells` columns are overwritten with the
    columns found at the opposite end of the array, just inside of its
    ghost columns. The array is modified in place and the same object
    is returned.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree. Indexed along
        its second axis, so it needs at least two dimensions.
    num_ghost_cells : int
        Number of ghost cells to be adjusted on each side.

    Returns
    -------
    projection : ndarray
        Matrix of projection for each polynomial degree (the input
        array, adjusted in place).

    """
    if num_ghost_cells == 0:
        # Without ghost cells there is nothing to adjust. The early exit
        # is required because the slice `-0:` selects the whole axis
        # instead of an empty range, which makes the assignment fail.
        return projection

    # Left ghost cells mirror the last non-ghost cells.
    projection[:, :num_ghost_cells] = \
        projection[:, -2 * num_ghost_cells:-num_ghost_cells]
    # Right ghost cells mirror the first non-ghost cells.
    projection[:, -num_ghost_cells:] = \
        projection[:, num_ghost_cells:2 * num_ghost_cells]
    return projection


def dirichlet_boundary(projection: ndarray, num_ghost_cells: int) -> ndarray:
    """Enforce boundary condition (placeholder).

    Intended to adjust ghost cells to ensure Dirichlet boundary
    condition. The adjustment is not implemented yet: the arguments are
    ignored, nothing is modified, and None is returned.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree. Currently
        unused.
    num_ghost_cells : int
        Number of ghost cells to be adjusted. Currently unused.

    Returns
    -------
    None
        No projection is returned until the condition is implemented.

    """
    return None


def enforce_boundary(num_ghost_cells=None):
    """Create decorator that enforces the boundary condition of the mesh.

    The decorated method is called first; the projection it returns is
    then adjusted according to `self._mesh.boundary`. Only the value
    'periodic' is supported.

    Parameters
    ----------
    num_ghost_cells : int, optional
        Number of ghost cells to be adjusted. If None, the value of
        `self._mesh.num_ghost_cells` is read on every call.
        Default: None.

    Returns
    -------
    function
        Decorator for methods of objects with a `_mesh` attribute.

    Raises
    ------
    NotImplementedError
        Raised by the decorated method (after it has been evaluated) if
        the boundary of the mesh is not 'periodic'.

    """
    def _enforce_boundary(func):
        # Keep name and docstring of the decorated method.
        @wraps(func)
        def boundary(self, *args, **kwargs):
            projection = func(self, *args, **kwargs)
            if self._mesh.boundary == 'periodic':
                # An explicitly given number of ghost cells takes
                # precedence over the one stored in the mesh.
                ghost_cells = num_ghost_cells \
                    if num_ghost_cells is not None \
                    else self._mesh.num_ghost_cells
                return periodic_boundary(projection=projection,
                                         num_ghost_cells=ghost_cells)
            # Subclass of Exception, so existing handlers still apply.
            raise NotImplementedError('Not implemented!')
        return boundary
    return _enforce_boundary


def enforce_boxplot_boundaries(func):
    """Enforce boundary condition on the bounds returned by a method.

    The result of the decorated method is converted to an array and,
    for a periodic mesh, adjusted with exactly one ghost cell on each
    side. The adjusted array is returned as tuple of its rows.

    Parameters
    ----------
    func : function
        Method of an object with a `_mesh` attribute that returns the
        bounds as array-like.

    Returns
    -------
    function
        Decorated method returning the adjusted bounds as tuple.

    Raises
    ------
    NotImplementedError
        Raised by the decorated method (after it has been evaluated) if
        the boundary of the mesh is not 'periodic'.

    """
    # Keep name and docstring of the decorated method.
    @wraps(func)
    def boxplot_boundary(self, *args, **kwargs):
        bounds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary == 'periodic':
            return tuple(periodic_boundary(projection=bounds,
                                           num_ghost_cells=1))
        # Subclass of Exception, so existing handlers still apply.
        raise NotImplementedError('Not implemented!')
    return boxplot_boundary


def enforce_fold_boundaries(func):
    """Enforce boundary condition on the folds returned by a method.

    The result of the decorated method is converted to an array and,
    for a periodic mesh, wrapped around by taking every entry modulo
    `self._mesh.num_cells`.

    Parameters
    ----------
    func : function
        Method of an object with a `_mesh` attribute that returns the
        folds as array-like.

    Returns
    -------
    function
        Decorated method returning the wrapped folds as ndarray.

    Raises
    ------
    NotImplementedError
        Raised by the decorated method (after it has been evaluated) if
        the boundary of the mesh is not 'periodic'.

    """
    # Keep name and docstring of the decorated method.
    @wraps(func)
    def fold_boundaries(self, *args, **kwargs):
        folds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary == 'periodic':
            return folds % self._mesh.num_cells
        # Subclass of Exception, so existing handlers still apply.
        raise NotImplementedError('Not implemented!')
    return fold_boundaries