Skip to content
Snippets Groups Projects
Commit f42c5cdc authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Enforced boundary condition on Boxplot folds with decorator.

parent 5b26530e
Branches
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ Urgent: ...@@ -26,7 +26,7 @@ Urgent:
TODO: Refactor Boxplot class -> Done TODO: Refactor Boxplot class -> Done
TODO: Enforce periodic boundary condition for projection with decorator -> Done TODO: Enforce periodic boundary condition for projection with decorator -> Done
TODO: Enforce Boxplot bounds with decorator -> Done TODO: Enforce Boxplot bounds with decorator -> Done
TODO: Enforce Boxplot folds with decorator TODO: Enforce Boxplot folds with decorator -> Done
TODO: Enforce boundary for initial condition in exact solution only TODO: Enforce boundary for initial condition in exact solution only
TODO: Adapt number of ghost cells based on ANN stencil TODO: Adapt number of ghost cells based on ANN stencil
TODO: Ensure exact solution is calculated in Equation class TODO: Ensure exact solution is calculated in Equation class
......
...@@ -77,3 +77,14 @@ def enforce_boxplot_boundaries(func): ...@@ -77,3 +77,14 @@ def enforce_boxplot_boundaries(func):
else: else:
raise Exception('Not implemented!') raise Exception('Not implemented!')
return boxplot_boundary return boxplot_boundary
def enforce_fold_boundaries(func):
def fold_boundaries(self, *args, **kwargs):
folds = np.array(func(self, *args, **kwargs))
if self._mesh.boundary == 'periodic':
return folds % self._mesh.num_cells
else:
raise Exception('Not implemented!')
return fold_boundaries
...@@ -10,7 +10,8 @@ from numpy import ndarray ...@@ -10,7 +10,8 @@ from numpy import ndarray
import torch import torch
from . import ANN_Model from . import ANN_Model
from .Boundary_Condition import enforce_boxplot_boundaries from .Boundary_Condition import enforce_boxplot_boundaries, \
enforce_fold_boundaries
from .Mesh import Mesh from .Mesh import Mesh
...@@ -357,6 +358,7 @@ class Boxplot(WaveletDetector): ...@@ -357,6 +358,7 @@ class Boxplot(WaveletDetector):
num_overlapping_cells = config.pop('num_overlapping_cells', 1) num_overlapping_cells = config.pop('num_overlapping_cells', 1)
self._fold_indices = self._compute_folds(num_overlapping_cells) self._fold_indices = self._compute_folds(num_overlapping_cells)
@enforce_fold_boundaries
def _compute_folds(self, num_overlapping_cells: int) -> ndarray: def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
"""Compute indices for all folds used in Boxplot calculation. """Compute indices for all folds used in Boxplot calculation.
...@@ -375,8 +377,7 @@ class Boxplot(WaveletDetector): ...@@ -375,8 +377,7 @@ class Boxplot(WaveletDetector):
fold_indices = np.zeros([num_folds, self._fold_len + 2 * fold_indices = np.zeros([num_folds, self._fold_len + 2 *
num_overlapping_cells]).astype(np.int32) num_overlapping_cells]).astype(np.int32)
for fold in range(num_folds): for fold in range(num_folds):
fold_indices[fold] = np.array( fold_indices[fold] = np.array([i for i in range(
[i % self._mesh.num_cells for i in range(
fold * self._fold_len - num_overlapping_cells, fold * self._fold_len - num_overlapping_cells,
(fold+1) * self._fold_len + num_overlapping_cells)]) (fold+1) * self._fold_len + num_overlapping_cells)])
return fold_indices return fold_indices
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment