diff --git a/ANN_Training.py b/ANN_Training.py
index 64d7597bfc451917fe3d5a7a40c995e3e2fa2912..a2ce6923f8acd9db0d5fd98f2cde367fe595939d 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -6,9 +6,16 @@ Code-Style: E226, W503
 Docstring-Style: D200, D400
 
 TODO: Test new ANN set-up with Soraya
-TODO: Remove object set-up (for more flexibility)
+TODO: Remove object set-up (for more flexibility) -> Decided against (object set-up kept for easy test set-up)
 TODO: Add documentation
-TODO: Allow multiple approximations in one config
+TODO: Allow multiple approximations in one config -> Done
+TODO: Split workflow into multiple modules -> Done
+TODO: Remove unnecessary instance variables -> Done
+TODO: Add README for ANN training
+TODO: Fix random seed
+TODO: Discuss whether to outsource scripts into separate directory
+TODO: Discuss whether comparison between datasets is wanted
+TODO: Discuss whether default model state is useful
 
 """
 import numpy as np
@@ -32,45 +39,41 @@ class ModelTrainer(object):
         self._reset(config)
 
     def _reset(self, config):
-        self._dir = config.pop('dir', 'test_data')
-        self._model_name = config.pop('model_name', '0')
-        self._training_data = read_training_data(self._dir)
-
-        self._batch_size = config.pop('batch_size', min(len(self._training_data)//2, 500))
+        self._batch_size = config.pop('batch_size', 500)
         self._num_epochs = config.pop('num_epochs', 1000)
         self._threshold = config.pop('threshold', 1e-5)
-        self._model = config.pop('model', 'ThreeLayerReLu')
-        self._model_config = config.pop('model_config', {})
-        self._loss_function = config.pop('loss_function', 'BCELoss')
-        self._loss_config = config.pop('loss_config', {})
-        self._optimizer = config.pop('optimizer', 'Adam')
-        self._optimizer_config = config.pop('optimizer_config', {})
+
+        model = config.pop('model', 'ThreeLayerReLu')
+        model_config = config.pop('model_config', {})
+        loss_function = config.pop('loss_function', 'BCELoss')
+        loss_config = config.pop('loss_config', {})
+        optimizer = config.pop('optimizer', 'Adam')
+        optimizer_config = config.pop('optimizer_config', {})
 
         # Set learning rate
-        self._learning_rate = config.pop('learning_rate', 1e-2)
-        self._optimizer_config['lr'] = self._learning_rate
-
-        if not hasattr(ANN_Model, self._model):
-            raise ValueError('Invalid model: "%s"' % self._model)
-        if not hasattr(torch.nn.modules.loss, self._loss_function):
-            raise ValueError('Invalid loss function: "%s"' % self._loss_function)
-        if not hasattr(torch.optim, self._optimizer):
-            raise ValueError('Invalid optimizer: "%s"' % self._optimizer)
-
-        self._model = getattr(ANN_Model, self._model)(self._model_config)
-        self._loss_function = getattr(torch.nn.modules.loss, self._loss_function)(
-            **self._loss_config)
-        self._optimizer = getattr(torch.optim, self._optimizer)(
-            self._model.parameters(), **self._optimizer_config)
+        learning_rate = config.pop('learning_rate', 1e-2)
+        optimizer_config['lr'] = learning_rate
+
+        if not hasattr(ANN_Model, model):
+            raise ValueError('Invalid model: "%s"' % model)
+        if not hasattr(torch.nn.modules.loss, loss_function):
+            raise ValueError('Invalid loss function: "%s"' % loss_function)
+        if not hasattr(torch.optim, optimizer):
+            raise ValueError('Invalid optimizer: "%s"' % optimizer)
+
+        self._model = getattr(ANN_Model, model)(model_config)
+        self._loss_function = getattr(torch.nn.modules.loss, loss_function)(
+            **loss_config)
+        self._optimizer = getattr(torch.optim, optimizer)(
+            self._model.parameters(), **optimizer_config)
         self._validation_loss = torch.zeros(self._num_epochs//10)
 
-    def epoch_training(self, dataset=None, num_epochs=None, verbose=True):
+    def epoch_training(self, dataset, num_epochs=None, verbose=True):
         tic = time.perf_counter()
-        # Split data into training and validation set
-        if dataset is None:
-            dataset = self._training_data
         if num_epochs is None:
             num_epochs = self._num_epochs
+
+        # Split data into training and validation set
         num_samples = len(dataset)
         if verbose:
             print('Splitting data randomly into training and validation set.')
@@ -137,17 +140,15 @@ class ModelTrainer(object):
                 'F-Score_Smooth': f_score[0], 'F-Score_Troubled': f_score[1],
                 'Accuracy': accuracy, 'AUROC': auroc}
 
-    def save_model(self):
-        # Saving Model
-        name = self._model_name
-
+    def save_model(self, directory, model_name='test_model'):
         # Set paths for files if not existing already
-        model_dir = self._dir + '/trained models'
+        model_dir = directory + '/trained models'
         if not os.path.exists(model_dir):
             os.makedirs(model_dir)
 
-        torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
-        torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')
+        # Save model and loss
+        torch.save(self._model.state_dict(), model_dir + '/model__' + model_name + '.pt')
+        torch.save(self._validation_loss, model_dir + '/loss__' + model_name + '.pt')
 
 
 def read_training_data(directory, normalized=True):
diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk
index 24abd452287cfd3b90b2f60a648c30a566a4032d..32640a3e025ab35f58790c0b84f6a8c89a91d0ea 100644
--- a/workflows/ANN_training.smk
+++ b/workflows/ANN_training.smk
@@ -1,7 +1,7 @@
 import sys
 
 import ANN_Training
-from ANN_Training import evaluate_models
+from ANN_Training import evaluate_models, read_training_data
 
 configfile: 'config.yaml'
 
@@ -53,7 +53,7 @@ rule train_model:
     run:
         with open(str(log), 'w') as logfile:
             sys.stdout = logfile
-            trainer= ANN_Training.ModelTrainer(config={'model_name': wildcards.model, 'dir': DIR,
-                                            'model_dir': DIR, **MODELS[wildcards.model]})
-            trainer.epoch_training()
-            trainer.save_model()
+            training_data = read_training_data(DIR)
+            trainer = ANN_Training.ModelTrainer(config={**MODELS[wildcards.model]})
+            trainer.epoch_training(dataset=training_data)
+            trainer.save_model(directory=DIR, model_name=wildcards.model)