From f42c5cdc95dd04744f5cf1a68b380d0a3424d1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Fri, 10 Mar 2023 08:59:59 +0100 Subject: [PATCH] Enforced boundary condition on Boxplot folds with decorator. --- Snakefile | 2 +- scripts/tcd/Boundary_Condition.py | 11 +++++++++++ scripts/tcd/Troubled_Cell_Detector.py | 11 ++++++----- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/Snakefile b/Snakefile index 0bf224c..d74be81 100644 --- a/Snakefile +++ b/Snakefile @@ -26,7 +26,7 @@ Urgent: TODO: Refactor Boxplot class -> Done TODO: Enforce periodic boundary condition for projection with decorator -> Done TODO: Enforce Boxplot bounds with decorator -> Done -TODO: Enforce Boxplot folds with decorator +TODO: Enforce Boxplot folds with decorator -> Done TODO: Enforce boundary for initial condition in exact solution only TODO: Adapt number of ghost cells based on ANN stencil TODO: Ensure exact solution is calculated in Equation class diff --git a/scripts/tcd/Boundary_Condition.py b/scripts/tcd/Boundary_Condition.py index 72c429a..8375491 100644 --- a/scripts/tcd/Boundary_Condition.py +++ b/scripts/tcd/Boundary_Condition.py @@ -77,3 +77,14 @@ def enforce_boxplot_boundaries(func): else: raise Exception('Not implemented!') return boxplot_boundary + + +def enforce_fold_boundaries(func): + def fold_boundaries(self, *args, **kwargs): + folds = np.array(func(self, *args, **kwargs)) + if self._mesh.boundary == 'periodic': + return folds % self._mesh.num_cells + else: + raise Exception('Not implemented!') + return fold_boundaries + diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py index d78fec7..8be6c16 100644 --- a/scripts/tcd/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -10,7 +10,8 @@ from numpy import ndarray import torch from . import ANN_Model -from .Boundary_Condition import enforce_boxplot_boundaries +from .Boundary_Condition import enforce_boxplot_boundaries, \ + enforce_fold_boundaries from .Mesh import Mesh @@ -357,6 +358,7 @@ class Boxplot(WaveletDetector): num_overlapping_cells = config.pop('num_overlapping_cells', 1) self._fold_indices = self._compute_folds(num_overlapping_cells) + @enforce_fold_boundaries def _compute_folds(self, num_overlapping_cells: int) -> ndarray: """Compute indices for all folds used in Boxplot calculation. @@ -375,10 +377,9 @@ class Boxplot(WaveletDetector): fold_indices = np.zeros([num_folds, self._fold_len + 2 * num_overlapping_cells]).astype(np.int32) for fold in range(num_folds): - fold_indices[fold] = np.array( - [i % self._mesh.num_cells for i in range( - fold * self._fold_len - num_overlapping_cells, - (fold+1) * self._fold_len + num_overlapping_cells)]) + fold_indices[fold] = np.array([i for i in range( + fold * self._fold_len - num_overlapping_cells, + (fold+1) * self._fold_len + num_overlapping_cells)]) return fold_indices def get_cells(self, projection): -- GitLab