# -*- coding: utf-8 -*-
"""Module for enforcing boundary conditions on cell-wise data.

@author: Laura C. Kühle
"""
import functools

import numpy as np
from numpy import ndarray


def periodic_boundary(projection: ndarray, num_ghost_cells: int) -> ndarray:
    """Enforce boundary condition.

    Adjust ghost cells to ensure periodic boundary condition.

    The first and last `num_ghost_cells` entries along the second axis
    (columns) of `projection` are treated as ghost cells. The left ghost
    cells are overwritten with the `num_ghost_cells` columns directly
    preceding the right ghost cells. The right ghost cells are overwritten
    with the `num_ghost_cells` columns directly following the left ghost
    cells. The array is modified in place and also returned.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree.
    num_ghost_cells : int
        Number of ghost cells to be adjusted.

    Returns
    -------
    projection : ndarray
        Matrix of projection for each polynomial degree.

    Raises
    ------
    ValueError
        If `num_ghost_cells` is negative.

    """
    if num_ghost_cells < 0:
        raise ValueError('Number of ghost cells must be non-negative, '
                         'got {}.'.format(num_ghost_cells))
    if num_ghost_cells == 0:
        # Nothing to adjust. Without this guard the slice `-0:` below would
        # select ALL columns and fail to broadcast the empty source slice.
        return projection

    projection[:, :num_ghost_cells] = \
        projection[:, -2 * num_ghost_cells:-num_ghost_cells]
    projection[:, -num_ghost_cells:] = \
        projection[:, num_ghost_cells:2 * num_ghost_cells]

    return projection


def dirichlet_boundary(projection: ndarray, num_ghost_cells: int) -> ndarray:
    """Enforce boundary condition.

    Adjust ghost cells to ensure Dirichlet boundary condition.

    NOTE: Not implemented yet. This is currently a stub that leaves
    `projection` untouched and returns None.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree.
    num_ghost_cells : int
        Number of ghost cells to be adjusted.

    Returns
    -------
    projection : ndarray
        Matrix of projection for each polynomial degree.

    """
    pass


def enforce_boundary(num_ghost_cells=None):
    """Create a decorator that enforces the mesh boundary on a projection.

    The decorated method must be defined on an object exposing
    `self._mesh.boundary` and, when `num_ghost_cells` is None,
    `self._mesh.num_ghost_cells`. Its return value is passed to the
    boundary routine matching `self._mesh.boundary`. Use with parentheses:
    ``@enforce_boundary()``.

    Parameters
    ----------
    num_ghost_cells : int, optional
        Number of ghost cells to adjust. If None, the mesh's
        `num_ghost_cells` is used at call time.

    Returns
    -------
    decorator : callable
        Decorator to apply to a method returning a projection ndarray.

    Raises
    ------
    NotImplementedError
        (From the decorated method) if the mesh boundary is not 'periodic'.

    """
    def _enforce_boundary(func):
        @functools.wraps(func)
        def boundary(self, *args, **kwargs):
            projection = func(self, *args, **kwargs)
            if self._mesh.boundary == 'periodic':
                return periodic_boundary(
                    projection=projection,
                    num_ghost_cells=num_ghost_cells
                    if num_ghost_cells is not None
                    else self._mesh.num_ghost_cells)
            # NotImplementedError subclasses Exception, so existing
            # `except Exception` handlers keep working.
            raise NotImplementedError('Not implemented!')
        return boundary
    return _enforce_boundary


def enforce_boxplot_boundaries(func):
    """Decorate a method so its boxplot bounds respect the mesh boundary.

    The decorated method's result is converted with `np.array`. For a
    'periodic' mesh, one ghost cell on each side (along the second axis) is
    adjusted via `periodic_boundary`. The result is returned as a tuple over
    the first axis.

    Raises
    ------
    NotImplementedError
        (From the decorated method) if the mesh boundary is not 'periodic'.

    """
    @functools.wraps(func)
    def boxplot_boundary(self, *args, **kwargs):
        bounds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary == 'periodic':
            return tuple(periodic_boundary(projection=bounds,
                                           num_ghost_cells=1))
        raise NotImplementedError('Not implemented!')
    return boxplot_boundary


def enforce_fold_boundaries(func):
    """Decorate a method so its fold indices respect the mesh boundary.

    The decorated method's result is converted with `np.array`. For a
    'periodic' mesh, every entry is wrapped into the range
    ``[0, self._mesh.num_cells)`` using the modulo operator.

    Raises
    ------
    NotImplementedError
        (From the decorated method) if the mesh boundary is not 'periodic'.

    """
    @functools.wraps(func)
    def fold_boundaries(self, *args, **kwargs):
        folds = np.array(func(self, *args, **kwargs))
        if self._mesh.boundary == 'periodic':
            return folds % self._mesh.num_cells
        raise NotImplementedError('Not implemented!')
    return fold_boundaries