diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 30282470b8d54c688d8cf4b5295f323239026a1c..837a358deb91581ade2a037d25225e277d8e5e5d 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -5,7 +5,7 @@
 TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of every second)
 TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
 TODO: Give detailed description of wavelet detection
-TODO: Load ANN state and config in reset
+TODO: Load ANN state and config in reset -> Done
 
 """
 import numpy as np
@@ -266,13 +266,16 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
         self._model_config = config.pop('model_config', {
             'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4,
             'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
-        self._model_state = config.pop(
-            'model_state', 'Snakemake-Test/trained models/model__Adam.pt')
+        model_state = config.pop('model_state', 'Snakemake-Test/trained models/model__Adam.pt')
 
         if not hasattr(ANN_Model, self._model):
             raise ValueError('Invalid model: "%s"' % self._model)
         self._model = getattr(ANN_Model, self._model)(self._model_config)
 
+        # Load the model state and set it to evaluation mode
+        self._model.load_state_dict(torch.load(str(model_state)))
+        self._model.eval()
+
     def get_cells(self, projection):
         """Calculates troubled cells in a given projection.
 
@@ -299,11 +302,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
             projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
             for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
 
-        # Evaluate troubled cell probabilities
-        self._model.load_state_dict(torch.load(self._model_state))
-        self._model.eval()
-
-        # Return troubled cells
+        # Determine troubled cells
         model_output = torch.round(self._model(input_data.float()))
         return [cell for cell in range(len(model_output))
                 if model_output[cell, 0] == torch.tensor([1])]