diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py
index dab248add77ad31d436f195ab40b7920bc91a39c..422db3a4b1ce8d753d561ecbfe28a242d274a96f 100644
--- a/scripts/tcd/Troubled_Cell_Detector.py
+++ b/scripts/tcd/Troubled_Cell_Detector.py
@@ -186,9 +186,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
                               len(projection[0])-num_ghost_cells)]))
 
         # Determine troubled cells
-        model_output = torch.argmax(self._model(input_data.float()), dim=1)
-        return [cell for cell in range(len(model_output))
-                if model_output[cell] == torch.tensor([1])]
+        model_output = torch.argmax(self._model(input_data.float()),
+                                    dim=1)
+        return np.flatnonzero(model_output.numpy()).tolist()
 
 
 class WaveletDetector(TroubledCellDetector):