diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index 837a358deb91581ade2a037d25225e277d8e5e5d..d3dabe3585e99d5952869d7d13439fe41418a43b 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -6,6 +6,7 @@ TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.) TODO: Give detailed description of wavelet detection TODO: Load ANN state and config in reset -> Done +TODO: Fix bug in output calculation -> Done """ import numpy as np @@ -303,9 +304,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector): for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)])) # Determine troubled cells - model_output = torch.round(self._model(input_data.float())) + model_output = torch.argmax(self._model(input_data.float()), dim=1) return [cell for cell in range(len(model_output)) - if model_output[cell, 0] == torch.tensor([1])] + if model_output[cell] == torch.tensor([0])] class WaveletDetector(TroubledCellDetector):