Skip to content
Snippets Groups Projects
Commit f42c5cdc authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Enforced boundary condition on Boxplot folds with decorator.

parent 5b26530e
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ Urgent:
TODO: Refactor Boxplot class -> Done
TODO: Enforce periodic boundary condition for projection with decorator -> Done
TODO: Enforce Boxplot bounds with decorator -> Done
TODO: Enforce Boxplot folds with decorator
TODO: Enforce Boxplot folds with decorator -> Done
TODO: Enforce boundary for initial condition in exact solution only
TODO: Adapt number of ghost cells based on ANN stencil
TODO: Ensure exact solution is calculated in Equation class
......
......@@ -77,3 +77,14 @@ def enforce_boxplot_boundaries(func):
else:
raise Exception('Not implemented!')
return boxplot_boundary
def enforce_fold_boundaries(func):
def fold_boundaries(self, *args, **kwargs):
folds = np.array(func(self, *args, **kwargs))
if self._mesh.boundary == 'periodic':
return folds % self._mesh.num_cells
else:
raise Exception('Not implemented!')
return fold_boundaries
......@@ -10,7 +10,8 @@ from numpy import ndarray
import torch
from . import ANN_Model
from .Boundary_Condition import enforce_boxplot_boundaries
from .Boundary_Condition import enforce_boxplot_boundaries, \
enforce_fold_boundaries
from .Mesh import Mesh
......@@ -357,6 +358,7 @@ class Boxplot(WaveletDetector):
num_overlapping_cells = config.pop('num_overlapping_cells', 1)
self._fold_indices = self._compute_folds(num_overlapping_cells)
@enforce_fold_boundaries
def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
"""Compute indices for all folds used in Boxplot calculation.
......@@ -375,8 +377,7 @@ class Boxplot(WaveletDetector):
fold_indices = np.zeros([num_folds, self._fold_len + 2 *
num_overlapping_cells]).astype(np.int32)
for fold in range(num_folds):
fold_indices[fold] = np.array(
[i % self._mesh.num_cells for i in range(
fold_indices[fold] = np.array([i for i in range(
fold * self._fold_len - num_overlapping_cells,
(fold+1) * self._fold_len + num_overlapping_cells)])
return fold_indices
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment