Skip to content
Snippets Groups Projects
Commit 4370bddd authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Vectorized 'get_cells()'in ANN detector.

parent 76e3c1a5
No related branches found
No related tags found
No related merge requests found
......@@ -186,9 +186,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
len(projection[0])-num_ghost_cells)]))
# Determine troubled cells
model_output = torch.argmax(self._model(input_data.float()), dim=1)
return [cell for cell in range(len(model_output))
if model_output[cell] == torch.tensor([1])]
model_output = torch.argmax(self._model(input_data.float()),
dim=1)
return np.flatnonzero(model_output.numpy()).tolist()
class WaveletDetector(TroubledCellDetector):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment