From fd22e0cada21ff251aaa7a98e874727c5a29cf09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 19 Jan 2022 21:05:01 +0100
Subject: [PATCH] Fixed bug in ouput evaluation for ANN model.

---
 Troubled_Cell_Detector.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 837a358..d3dabe3 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -6,6 +6,7 @@ TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of
 TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
 TODO: Give detailed description of wavelet detection
 TODO: Load ANN state and config in reset -> Done
+TODO: Fix bug in output calculation -> Done
 
 """
 import numpy as np
@@ -303,9 +304,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
             for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
 
         # Determine troubled cells
-        model_output = torch.round(self._model(input_data.float()))
+        model_output = torch.argmax(self._model(input_data.float()), dim=1)
         return [cell for cell in range(len(model_output))
-                if model_output[cell, 0] == torch.tensor([1])]
+                if model_output[cell] == torch.tensor([0])]
 
 
 class WaveletDetector(TroubledCellDetector):
-- 
GitLab