diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 837a358deb91581ade2a037d25225e277d8e5e5d..d3dabe3585e99d5952869d7d13439fe41418a43b 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -6,6 +6,7 @@ TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of
 TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
 TODO: Give detailed description of wavelet detection
 TODO: Load ANN state and config in reset -> Done
+TODO: Fix bug in output calculation -> Done
 
 """
 import numpy as np
@@ -303,9 +304,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
             for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
 
         # Determine troubled cells
-        model_output = torch.round(self._model(input_data.float()))
+        model_output = torch.argmax(self._model(input_data.float()), dim=1)
         return [cell for cell in range(len(model_output))
-                if model_output[cell, 0] == torch.tensor([1])]
+                if model_output[cell] == torch.tensor([0])]
 
 
 class WaveletDetector(TroubledCellDetector):