From 7b32def237a3910bf787cf2bec03e4e5ec99a08c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Mon, 14 Nov 2022 00:14:39 +0100
Subject: [PATCH] Vectorized 'get_cells()' for ANN detector.

---
 Snakefile                             | 13 +++++++------
 scripts/tcd/Basis_Function.py         | 13 ++++++-------
 scripts/tcd/Troubled_Cell_Detector.py | 17 +++++++++--------
 3 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/Snakefile b/Snakefile
index e4fbf82..171ac8a 100644
--- a/Snakefile
+++ b/Snakefile
@@ -21,22 +21,20 @@ TODO: Discuss descriptions (matrices, cfl number, right-hand side,
 TODO: Discuss referencing info on SSPRK3
 TODO: Discuss name for quadrature mesh (now: grid)
 TODO: Contemplate using lambdify for basis
+TODO: Contemplate allowing vector input for ICs
 TODO: Discuss how wavelet details should be plotted
 
 Urgent:
 TODO: Vectorize 'plot_details()' -> Done
 TODO: Remove unnecessary dimension during cell average calculation -> Done
-TODO: Replace loops/list comprehension with vectorization if feasible
-TODO: Replace loops with list comprehension if feasible
-TODO: Rework ICs to allow vector input
+TODO: Vectorize 'get_cells()' for ANN detector -> Done
+TODO: Replace loops/list comprehension with vectorization if feasible -> Done
+TODO: Replace loops with list comprehension if feasible -> Done
 TODO: Check whether 'projection' is always a ndarray
 TODO: Check whether ghost cells are handled/set correctly
 TODO: Enforce even number of ghost cells on each side on fine mesh (?)
 TODO: Check whether all instance variables are sensible
-TODO: Investigate g-mesh(?)
-TODO: Create g-mesh with Mesh class
 TODO: Combine ANN workflows if feasible
-TODO: Investigate profiling for speed up
 TODO: Make sure that the cell indices are the same over all TCDs
 TODO: Make sure TCs are reported as ndarray
 
@@ -49,6 +47,8 @@ TODO: Allow comparison between ANN training datasets
 TODO: Use full path for ANN model state
 TODO: Add a default model state
 TODO: Look into validators for variable checks
+TODO: Investigate g-mesh(?)
+TODO: Create g-mesh with Mesh class
 
 Not feasible yet or doc-related:
 TODO: Move plot_approximation_results() into plotting script
@@ -74,6 +74,7 @@ TODO: Check whether documentation style is correct
 TODO: Check whether all types in doc are correct
 TODO: Add type annotations to function heads
 TODO: Clean up docstrings
+TODO: Investigate profiling for speed up
 
 """
 
diff --git a/scripts/tcd/Basis_Function.py b/scripts/tcd/Basis_Function.py
index 2b3c809..4144fb9 100644
--- a/scripts/tcd/Basis_Function.py
+++ b/scripts/tcd/Basis_Function.py
@@ -507,13 +507,12 @@ class OrthonormalLegendre(Legendre):
             left_reconstructions, right_reconstructions = \
                 self._calculate_reconstructions(
                     projection[:, middle_idx:middle_idx+1])
-            return np.array(list(map(
-                np.float64, zip(cell_averages[:middle_idx],
-                                left_reconstructions,
-                                cell_averages[middle_idx:middle_idx+1],
-                                right_reconstructions,
-                                cell_averages[middle_idx+1:]))))
-        return np.array(list(map(np.float64, cell_averages)))
+            return np.hstack([cell_averages[:middle_idx].T,
+                              left_reconstructions,
+                              cell_averages[middle_idx:middle_idx+1].T,
+                              right_reconstructions,
+                              cell_averages[middle_idx+1:].T])
+        return cell_averages
 
     def _calculate_reconstructions(self, projection: ndarray) \
             -> Tuple[list, list]:
diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py
index d33a51a..175fd9c 100644
--- a/scripts/tcd/Troubled_Cell_Detector.py
+++ b/scripts/tcd/Troubled_Cell_Detector.py
@@ -135,6 +135,11 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
 
         self._stencil_len = config.pop('stencil_len', 3)
         self._add_reconstructions = config.pop('add_reconstructions', True)
+
+        # Set mask for stencil sliding window
+        self._window_mask = np.arange(self._mesh.num_cells)[None, :] + \
+                            np.arange(self._stencil_len)[:, None]
+
         self._model = config.pop('model', 'ThreeLayerReLu')
         num_datapoints = self._stencil_len
         if self._add_reconstructions:
@@ -176,14 +181,10 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
                                      projection[:, :num_ghost_cells]), axis=1)
 
         # Calculate input data depending on stencil length
-        input_data = torch.from_numpy(np.vstack([
-            self._basis.calculate_cell_average(
-                projection=projection[
-                           :, cell-num_ghost_cells:cell+num_ghost_cells+1],
-                stencil_len=self._stencil_len,
-                add_reconstructions=self._add_reconstructions)
-            for cell in range(num_ghost_cells,
-                              len(projection[0])-num_ghost_cells)]))
+        projection_window = projection[:, self._window_mask]
+        input_data = torch.from_numpy(self._basis.calculate_cell_average(
+            projection=projection_window, stencil_len=self._stencil_len,
+            add_reconstructions=self._add_reconstructions))
 
         # Determine troubled cells
         model_output = torch.argmax(self._model(input_data.float()),
-- 
GitLab