diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index 30282470b8d54c688d8cf4b5295f323239026a1c..837a358deb91581ade2a037d25225e277d8e5e5d 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -5,7 +5,7 @@ TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of every second) TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.) TODO: Give detailed description of wavelet detection -TODO: Load ANN state and config in reset +TODO: Load ANN state and config in reset -> Done """ import numpy as np @@ -266,13 +266,16 @@ class ArtificialNeuralNetwork(TroubledCellDetector): self._model_config = config.pop('model_config', { 'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4, 'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}}) - self._model_state = config.pop( - 'model_state', 'Snakemake-Test/trained models/model__Adam.pt') + model_state = config.pop('model_state', 'Snakemake-Test/trained models/model__Adam.pt') if not hasattr(ANN_Model, self._model): raise ValueError('Invalid model: "%s"' % self._model) self._model = getattr(ANN_Model, self._model)(self._model_config) + # Load the model state and set it to evaluation mode + self._model.load_state_dict(torch.load(str(model_state))) + self._model.eval() + def get_cells(self, projection): """Calculates troubled cells in a given projection. @@ -299,11 +302,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector): projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len) for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)])) - # Evaluate troubled cell probabilities - self._model.load_state_dict(torch.load(self._model_state)) - self._model.eval() - - # Return troubled cells + # Determine troubled cells model_output = torch.round(self._model(input_data.float())) return [cell for cell in range(len(model_output)) if model_output[cell, 0] == torch.tensor([1])]