Skip to content
Snippets Groups Projects
Commit bc1150e6 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Moved loading of ANN state to reset().

parent e6a0ef81
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@
TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of every second)
TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection
TODO: Load ANN state and config in reset
TODO: Load ANN state and config in reset -> Done
"""
import numpy as np
......@@ -266,13 +266,16 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
self._model_config = config.pop('model_config', {
'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4,
'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
self._model_state = config.pop(
'model_state', 'Snakemake-Test/trained models/model__Adam.pt')
model_state = config.pop('model_state', 'Snakemake-Test/trained models/model__Adam.pt')
if not hasattr(ANN_Model, self._model):
raise ValueError('Invalid model: "%s"' % self._model)
self._model = getattr(ANN_Model, self._model)(self._model_config)
# Load the model state and set it to evaluation mode
self._model.load_state_dict(torch.load(str(model_state)))
self._model.eval()
def get_cells(self, projection):
"""Calculates troubled cells in a given projection.
......@@ -299,11 +302,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
# Evaluate troubled cell probabilities
self._model.load_state_dict(torch.load(self._model_state))
self._model.eval()
# Return troubled cells
# Determine troubled cells
model_output = torch.round(self._model(input_data.float()))
return [cell for cell in range(len(model_output))
if model_output[cell, 0] == torch.tensor([1])]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment