From bc1150e6fea204f3aeca9a8f0c9a00f3395916e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 19 Jan 2022 20:46:28 +0100
Subject: [PATCH] Moved loading of ANN state to reset().

---
 Troubled_Cell_Detector.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 3028247..837a358 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -5,7 +5,7 @@
 TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of every second)
 TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
 TODO: Give detailed description of wavelet detection
-TODO: Load ANN state and config in reset
+TODO: Load ANN state and config in reset -> Done
 
 """
 import numpy as np
@@ -266,13 +266,16 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
         self._model_config = config.pop('model_config', {
             'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4,
             'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
-        self._model_state = config.pop(
-            'model_state', 'Snakemake-Test/trained models/model__Adam.pt')
+        model_state = config.pop('model_state', 'Snakemake-Test/trained models/model__Adam.pt')
 
         if not hasattr(ANN_Model, self._model):
             raise ValueError('Invalid model: "%s"' % self._model)
         self._model = getattr(ANN_Model, self._model)(self._model_config)
 
+        # Load the model state and set it to evaluation mode
+        self._model.load_state_dict(torch.load(str(model_state)))
+        self._model.eval()
+
     def get_cells(self, projection):
         """Calculates troubled cells in a given projection.
 
@@ -299,11 +302,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
             projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
             for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
 
-        # Evaluate troubled cell probabilities
-        self._model.load_state_dict(torch.load(self._model_state))
-        self._model.eval()
-
-        # Return troubled cells
+        # Determine troubled cells
         model_output = torch.round(self._model(input_data.float()))
         return [cell for cell in range(len(model_output))
                 if model_output[cell, 0] == torch.tensor([1])]
-- 
GitLab