From a557d401d330ef4e279bbc028babe2737c1f9723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Wed, 30 Jun 2021 17:13:37 +0200 Subject: [PATCH] Added explanatory comments to 'get_cells()' for the class 'ArtificialNeuralNetwork'. --- Troubled_Cell_Detector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index d922731..ebc71d2 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -3,6 +3,7 @@ @author: Laura C. Kühle, Soraya Terrab (sorayaterrab) TODO: Fix cell averages and reconstructions to create data with an x-point stencil +TODO: Add comments to get_cells() for ArtificialNeuralNetwork -> Done """ import os @@ -210,19 +211,22 @@ class ArtificialNeuralNetwork(TroubledCellDetector): self._model = getattr(ANN_Model, self._model)(self._model_config) def get_cells(self, projection): + # Reset ghost cells to adjust for stencil length num_ghost_cells = self._stencil_len//2 - projection = projection[:, 1:-1] - # projection = projection[:, :5] projection = np.concatenate((projection[:, -num_ghost_cells:], projection, projection[:, :num_ghost_cells]), axis=1) + + # Calculate input data depending on stencil length input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions( projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1]) for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)])) + # Evaluate troubled cell probabilities self._model.load_state_dict(torch.load(self._model_state)) self._model.eval() + # Return troubled cells model_output = torch.round(self._model(input_data.float())) return [cell for cell in range(len(model_output)) if model_output[cell, 0] == torch.tensor([1])] -- GitLab