From 4370bdddbcb10637d87963d8cbfa91a739125274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Fri, 28 Oct 2022 13:08:14 +0200 Subject: [PATCH] Vectorized 'get_cells()'in ANN detector. --- scripts/tcd/Troubled_Cell_Detector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py index dab248a..422db3a 100644 --- a/scripts/tcd/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -186,9 +186,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector): len(projection[0])-num_ghost_cells)])) # Determine troubled cells - model_output = torch.argmax(self._model(input_data.float()), dim=1) - return [cell for cell in range(len(model_output)) - if model_output[cell] == torch.tensor([1])] + model_output = torch.argmax(self._model(input_data.float()), + dim=1) + return np.flatnonzero(model_output.numpy()).tolist() class WaveletDetector(TroubledCellDetector): -- GitLab