From 7b32def237a3910bf787cf2bec03e4e5ec99a08c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Mon, 14 Nov 2022 00:14:39 +0100 Subject: [PATCH] Vectorized 'get_cells()' for ANN detector. --- Snakefile | 13 +++++++------ scripts/tcd/Basis_Function.py | 13 ++++++------- scripts/tcd/Troubled_Cell_Detector.py | 17 +++++++++-------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/Snakefile b/Snakefile index e4fbf82..171ac8a 100644 --- a/Snakefile +++ b/Snakefile @@ -21,22 +21,20 @@ TODO: Discuss descriptions (matrices, cfl number, right-hand side, TODO: Discuss referencing info on SSPRK3 TODO: Discuss name for quadrature mesh (now: grid) TODO: Contemplate using lambdify for basis +TODO: Contemplate allowing vector input for ICs TODO: Discuss how wavelet details should be plotted Urgent: TODO: Vectorize 'plot_details()' -> Done TODO: Remove unnecessary dimension during cell average calculation -> Done -TODO: Replace loops/list comprehension with vectorization if feasible -TODO: Replace loops with list comprehension if feasible -TODO: Rework ICs to allow vector input +TODO: Vectorize 'get_cells()' for ANN detector -> Done +TODO: Replace loops/list comprehension with vectorization if feasible -> Done +TODO: Replace loops with list comprehension if feasible -> Done TODO: Check whether 'projection' is always a ndarray TODO: Check whether ghost cells are handled/set correctly TODO: Enforce even number of ghost cells on each side on fine mesh (?) TODO: Check whether all instance variables are sensible -TODO: Investigate g-mesh(?) -TODO: Create g-mesh with Mesh class TODO: Combine ANN workflows if feasible -TODO: Investigate profiling for speed up TODO: Make sure that the cell indices are the same over all TCDs TODO: Make sure TCs are reported as ndarray @@ -49,6 +47,8 @@ TODO: Allow comparison between ANN training datasets TODO: Use full path for ANN model state TODO: Add a default model state TODO: Look into validators for variable checks +TODO: Investigate g-mesh(?) +TODO: Create g-mesh with Mesh class Not feasible yet or doc-related: TODO: Move plot_approximation_results() into plotting script @@ -74,6 +74,7 @@ TODO: Check whether documentation style is correct TODO: Check whether all types in doc are correct TODO: Add type annotations to function heads TODO: Clean up docstrings +TODO: Investigate profiling for speed up """ diff --git a/scripts/tcd/Basis_Function.py b/scripts/tcd/Basis_Function.py index 2b3c809..4144fb9 100644 --- a/scripts/tcd/Basis_Function.py +++ b/scripts/tcd/Basis_Function.py @@ -507,13 +507,12 @@ class OrthonormalLegendre(Legendre): left_reconstructions, right_reconstructions = \ self._calculate_reconstructions( projection[:, middle_idx:middle_idx+1]) - return np.array(list(map( - np.float64, zip(cell_averages[:middle_idx], - left_reconstructions, - cell_averages[middle_idx:middle_idx+1], - right_reconstructions, - cell_averages[middle_idx+1:])))) - return np.array(list(map(np.float64, cell_averages))) + return np.hstack([cell_averages[:middle_idx].T, + left_reconstructions, + cell_averages[middle_idx:middle_idx+1].T, + right_reconstructions, + cell_averages[middle_idx+1:].T]) + return cell_averages def _calculate_reconstructions(self, projection: ndarray) \ -> Tuple[list, list]: diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py index d33a51a..175fd9c 100644 --- a/scripts/tcd/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -135,6 +135,11 @@ class ArtificialNeuralNetwork(TroubledCellDetector): self._stencil_len = config.pop('stencil_len', 3) self._add_reconstructions = config.pop('add_reconstructions', True) + + # Set mask for stencil sliding window + self._window_mask = np.arange(self._mesh.num_cells)[None, :] + \ + np.arange(self._stencil_len)[:, None] + self._model = config.pop('model', 'ThreeLayerReLu') num_datapoints = self._stencil_len if self._add_reconstructions: @@ -176,14 +181,10 @@ class ArtificialNeuralNetwork(TroubledCellDetector): projection[:, :num_ghost_cells]), axis=1) # Calculate input data depending on stencil length - input_data = torch.from_numpy(np.vstack([ - self._basis.calculate_cell_average( - projection=projection[ - :, cell-num_ghost_cells:cell+num_ghost_cells+1], - stencil_len=self._stencil_len, - add_reconstructions=self._add_reconstructions) - for cell in range(num_ghost_cells, - len(projection[0])-num_ghost_cells)])) + projection_window = projection[:, self._window_mask] + input_data = torch.from_numpy(self._basis.calculate_cell_average( + projection=projection_window, stencil_len=self._stencil_len, + add_reconstructions=self._add_reconstructions)) # Determine troubled cells model_output = torch.argmax(self._model(input_data.float()), -- GitLab