From d05acbc2e2632135983898c841d3b87e0fe63eea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Thu, 9 Mar 2023 20:16:46 +0100
Subject: [PATCH] Refactored Boxplot class.

---
 Snakefile                             |  7 +-
 scripts/tcd/Troubled_Cell_Detector.py | 95 +++++++++++++++++++--------
 2 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/Snakefile b/Snakefile
index fd059c4..615f4a7 100644
--- a/Snakefile
+++ b/Snakefile
@@ -23,7 +23,12 @@ TODO: Contemplate allowing vector input for ICs
 TODO: Discuss how wavelet details should be plotted
 
 Urgent:
-TODO: Move boundary condition to Mesh class
+TODO: Refactor Boxplot class -> Done
+TODO: Enforce boundary condition with decorator
+TODO: Enforce Boxplot bounds with decorator
+TODO: Enforce Boxplot folds with decorator
+TODO: Enforce boundary for initial condition in exact solution only
+TODO: Adapt number of ghost cells based on ANN stencil
 TODO: Ensure exact solution is calculated in Equation class
 TODO: Extract objects from UpdateScheme
 TODO: Enforce num_ghost_cells to be positive integer for DG (not training)
diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py
index 805fdd8..69c2162 100644
--- a/scripts/tcd/Troubled_Cell_Detector.py
+++ b/scripts/tcd/Troubled_Cell_Detector.py
@@ -3,8 +3,10 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
 
 """
+from typing import Tuple
 from abc import ABC, abstractmethod
 import numpy as np
+from numpy import ndarray
 import torch
 
 from . import ANN_Model
@@ -344,25 +346,39 @@ class Boxplot(WaveletDetector):
         super()._reset(config)
 
         # Unpack necessary configurations
-        self._fold_len = config.pop('fold_len', 16)
+        self._fold_len = min(self._mesh.num_cells,
+                             config.pop('fold_len', 16))
         self._whisker_len = config.pop('whisker_len', 3)
         self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
         self._extreme_outlier_only = config.pop('extreme_outlier_only', True)
 
-        if self._mesh.num_cells < self._fold_len:
-            self._fold_len = self._mesh.num_cells
-
         self._quantile_method = config.pop('quantile_method', 'weibull')
         num_overlapping_cells = config.pop('num_overlapping_cells', 1)
+        self._fold_indices = self._compute_folds(num_overlapping_cells)
+
+    def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
+        """Compute indices for all folds used in Boxplot calculation.
+
+        Parameters
+        ----------
+        num_overlapping_cells : int
+            Number of cells overlapping between adjacent folds.
+
+        Returns
+        -------
+        fold_indices : ndarray
+            Array of projection indices in each fold.
+
+        """
         num_folds = self._mesh.num_cells//self._fold_len
-        self._fold_indices = np.zeros([num_folds,
-                                       self._fold_len + 2 *
-                                       num_overlapping_cells]).astype(np.int32)
+        fold_indices = np.zeros([num_folds, self._fold_len + 2 *
+                                 num_overlapping_cells]).astype(np.int32)
         for fold in range(num_folds):
-            self._fold_indices[fold] = np.array(
+            fold_indices[fold] = np.array(
                 [i % self._mesh.num_cells for i in range(
                     fold * self._fold_len - num_overlapping_cells,
                     (fold+1) * self._fold_len + num_overlapping_cells)])
+        return fold_indices
 
     def get_cells(self, projection):
         """Calculate troubled cells in a given projection.
@@ -378,29 +394,10 @@ class Boxplot(WaveletDetector):
             List of indices for all detected troubled cells.
 
         """
-        # Determine quartiles of folds
+        # Determine bounds
         coeffs = self._select_degree(self._calculate_wavelet_coeffs(
             projection))
-        folds = coeffs[self._fold_indices]
-        first_quartiles = np.quantile(folds, 0.25, axis=1,
-                                      method=self._quantile_method)
-        third_quartiles = np.quantile(folds, 0.75, axis=1,
-                                      method=self._quantile_method)
-
-        # Determine bounds based on quartiles of a boxplot
-        lower_bounds = np.zeros(len(first_quartiles) + 2)
-        upper_bounds = np.zeros(len(first_quartiles) + 2)
-
-        lower_bounds[1:-1] = first_quartiles - self._whisker_len * (
-                third_quartiles-first_quartiles)
-        upper_bounds[1:-1] = third_quartiles + self._whisker_len * (
-                third_quartiles-first_quartiles)
-
-        # Adjust bounds to capture periodic boundary
-        lower_bounds[0] = lower_bounds[-2]
-        lower_bounds[-1] = lower_bounds[1]
-        upper_bounds[0] = upper_bounds[-2]
-        upper_bounds[-1] = upper_bounds[1]
+        lower_bounds, upper_bounds = self._compute_bounds(coeffs)
 
         # Adjust outer fences if flag is set
         if self._adjust_outer_fences:
@@ -426,6 +423,46 @@ class Boxplot(WaveletDetector):
 
         return troubled_cells
 
+    def _compute_bounds(self, coeffs: ndarray) -> Tuple[ndarray, ndarray]:
+        """Compute lower and upper bound for Boxplot outliers.
+
+        Parameters
+        ----------
+        coeffs : ndarray
+            Matrix of multiwavelet coefficients of projection.
+
+        Returns
+        -------
+        lower_bounds : ndarray
+            Array of lower bounds for outlier.
+        upper_bounds : ndarray
+            Array of upper bounds for outlier.
+
+        """
+        # Determine quartiles of folds
+        folds = coeffs[self._fold_indices]
+        first_quartiles = np.quantile(folds, 0.25, axis=1,
+                                      method=self._quantile_method)
+        third_quartiles = np.quantile(folds, 0.75, axis=1,
+                                      method=self._quantile_method)
+
+        # Determine bounds based on quartiles of a boxplot
+        lower_bounds = np.zeros(len(first_quartiles) + 2)
+        upper_bounds = np.zeros(len(first_quartiles) + 2)
+
+        lower_bounds[1:-1] = first_quartiles - self._whisker_len * (
+                third_quartiles-first_quartiles)
+        upper_bounds[1:-1] = third_quartiles + self._whisker_len * (
+                third_quartiles-first_quartiles)
+
+        # Adjust bounds to capture periodic boundary
+        lower_bounds[0] = lower_bounds[-2]
+        lower_bounds[-1] = lower_bounds[1]
+        upper_bounds[0] = upper_bounds[-2]
+        upper_bounds[-1] = upper_bounds[1]
+
+        return lower_bounds, upper_bounds
+
 
 class Theoretical(WaveletDetector):
     """Class for troubled-cell detection based on theoretical thresholding.
-- 
GitLab