From 4370bdddbcb10637d87963d8cbfa91a739125274 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Fri, 28 Oct 2022 13:08:14 +0200
Subject: [PATCH] Vectorized 'get_cells()'in ANN detector.

---
 scripts/tcd/Troubled_Cell_Detector.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py
index dab248a..422db3a 100644
--- a/scripts/tcd/Troubled_Cell_Detector.py
+++ b/scripts/tcd/Troubled_Cell_Detector.py
@@ -186,9 +186,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
                               len(projection[0])-num_ghost_cells)]))
 
         # Determine troubled cells
-        model_output = torch.argmax(self._model(input_data.float()), dim=1)
-        return [cell for cell in range(len(model_output))
-                if model_output[cell] == torch.tensor([1])]
+        model_output = torch.argmax(self._model(input_data.float()),
+                                    dim=1)
+        return np.flatnonzero(model_output.numpy()).tolist()
 
 
 class WaveletDetector(TroubledCellDetector):
-- 
GitLab