From fd22e0cada21ff251aaa7a98e874727c5a29cf09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Wed, 19 Jan 2022 21:05:01 +0100 Subject: [PATCH] Fixed bug in ouput evaluation for ANN model. --- Troubled_Cell_Detector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index 837a358..d3dabe3 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -6,6 +6,7 @@ TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.) TODO: Give detailed description of wavelet detection TODO: Load ANN state and config in reset -> Done +TODO: Fix bug in output calculation -> Done """ import numpy as np @@ -303,9 +304,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector): for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)])) # Determine troubled cells - model_output = torch.round(self._model(input_data.float())) + model_output = torch.argmax(self._model(input_data.float()), dim=1) return [cell for cell in range(len(model_output)) - if model_output[cell, 0] == torch.tensor([1])] + if model_output[cell] == torch.tensor([0])] class WaveletDetector(TroubledCellDetector): -- GitLab