diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py index dab248add77ad31d436f195ab40b7920bc91a39c..422db3a4b1ce8d753d561ecbfe28a242d274a96f 100644 --- a/scripts/tcd/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -186,9 +186,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector): len(projection[0])-num_ghost_cells)])) # Determine troubled cells - model_output = torch.argmax(self._model(input_data.float()), dim=1) - return [cell for cell in range(len(model_output)) - if model_output[cell] == torch.tensor([1])] + model_output = torch.argmax(self._model(input_data.float()), + dim=1) + return np.flatnonzero(model_output.numpy()).tolist() class WaveletDetector(TroubledCellDetector):