From 5effb4e866a302023799c9944230bcfcf6278daa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 5 Oct 2022 23:32:25 +0200
Subject: [PATCH] Outsourced run command for SM rule 'train_model' into script.

---
 scripts/train_model.py     | 34 ++++++++++++++++++++++++++++++++++
 workflows/ANN_training.smk | 15 +++++----------
 2 files changed, 39 insertions(+), 10 deletions(-)
 create mode 100644 scripts/train_model.py

diff --git a/scripts/train_model.py b/scripts/train_model.py
new file mode 100644
index 0000000..492d0f7
--- /dev/null
+++ b/scripts/train_model.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+"""Script to train ANN models.
+
+@author: Laura C. Kühle
+"""
+import sys
+
+from tcd.ANN_Training import ModelTrainer, read_training_data
+
+
+def main() -> None:
+    """Train ANN models."""
+    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
+        sys.stdout = logfile
+
+        # Read training data
+        training_data = read_training_data(snakemake.params['directory'])
+
+        # Train model
+        trainer = ModelTrainer(config={
+            'random_seed': snakemake.params['random_seed'],
+            **snakemake.params['models'][snakemake.wildcards['model']]})
+        trainer.epoch_training(dataset=training_data)
+
+        # Save trained model
+        trainer.save_model(directory=snakemake.params['directory'],
+                           model_name=snakemake.wildcards['model'])
+
+
+if __name__ == '__main__':
+    if "snakemake" in locals():
+        main()
+    else:
+        print('Not Defined.')
diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk
index 47c3dc2..fd03f39 100644
--- a/workflows/ANN_training.smk
+++ b/workflows/ANN_training.smk
@@ -71,15 +71,10 @@ rule train_model:
         protected(DIR+'/trained models/{model}.model.pt'),
         protected(DIR+'/trained models/{model}.loss.pt')
     params:
-        random_seed = config['random_seed']
+        random_seed = config['random_seed'],
+        directory = DIR,
+        models = config['models']
     log:
         DIR+'/log/train_model/{model}.log'
-    run:
-        with open(str(log), 'w') as logfile:
-            sys.stdout = logfile
-            training_data = read_training_data(DIR)
-            trainer= ANN_Training.ModelTrainer(
-                config={'random_seed': params.random_seed,
-                        **MODELS[wildcards.model]})
-            trainer.epoch_training(dataset=training_data)
-            trainer.save_model(directory=DIR, model_name=wildcards.model)
+    script:
+        '../scripts/train_model.py'
-- 
GitLab