From f42c5cdc95dd04744f5cf1a68b380d0a3424d1a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Fri, 10 Mar 2023 08:59:59 +0100
Subject: [PATCH] Enforced boundary condition on Boxplot folds with decorator.

---
 Snakefile                             |  2 +-
 scripts/tcd/Boundary_Condition.py     | 11 +++++++++++
 scripts/tcd/Troubled_Cell_Detector.py | 11 ++++++-----
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/Snakefile b/Snakefile
index 0bf224c..d74be81 100644
--- a/Snakefile
+++ b/Snakefile
@@ -26,7 +26,7 @@ Urgent:
 TODO: Refactor Boxplot class -> Done
 TODO: Enforce periodic boundary condition for projection with decorator -> Done
 TODO: Enforce Boxplot bounds with decorator -> Done
-TODO: Enforce Boxplot folds with decorator
+TODO: Enforce Boxplot folds with decorator -> Done
 TODO: Enforce boundary for initial condition in exact solution only
 TODO: Adapt number of ghost cells based on ANN stencil
 TODO: Ensure exact solution is calculated in Equation class
diff --git a/scripts/tcd/Boundary_Condition.py b/scripts/tcd/Boundary_Condition.py
index 72c429a..8375491 100644
--- a/scripts/tcd/Boundary_Condition.py
+++ b/scripts/tcd/Boundary_Condition.py
@@ -77,3 +77,14 @@ def enforce_boxplot_boundaries(func):
         else:
             raise Exception('Not implemented!')
     return boxplot_boundary
+
+
+def enforce_fold_boundaries(func):
+    def fold_boundaries(self, *args, **kwargs):
+        folds = np.array(func(self, *args, **kwargs))
+        if self._mesh.boundary == 'periodic':
+            return folds % self._mesh.num_cells
+        else:
+            raise Exception('Not implemented!')
+    return fold_boundaries
+
diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py
index d78fec7..8be6c16 100644
--- a/scripts/tcd/Troubled_Cell_Detector.py
+++ b/scripts/tcd/Troubled_Cell_Detector.py
@@ -10,7 +10,8 @@ from numpy import ndarray
 import torch
 
 from . import ANN_Model
-from .Boundary_Condition import enforce_boxplot_boundaries
+from .Boundary_Condition import enforce_boxplot_boundaries, \
+    enforce_fold_boundaries
 from .Mesh import Mesh
 
 
@@ -357,6 +358,7 @@ class Boxplot(WaveletDetector):
         num_overlapping_cells = config.pop('num_overlapping_cells', 1)
         self._fold_indices = self._compute_folds(num_overlapping_cells)
 
+    @enforce_fold_boundaries
     def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
         """Compute indices for all folds used in Boxplot calculation.
 
@@ -375,10 +377,9 @@ class Boxplot(WaveletDetector):
         fold_indices = np.zeros([num_folds, self._fold_len + 2 *
                                  num_overlapping_cells]).astype(np.int32)
         for fold in range(num_folds):
-            fold_indices[fold] = np.array(
-                [i % self._mesh.num_cells for i in range(
-                    fold * self._fold_len - num_overlapping_cells,
-                    (fold+1) * self._fold_len + num_overlapping_cells)])
+            fold_indices[fold] = np.array([i for i in range(
+                fold * self._fold_len - num_overlapping_cells,
+                (fold+1) * self._fold_len + num_overlapping_cells)])
         return fold_indices
 
     def get_cells(self, projection):
-- 
GitLab