From 9b0bacf05121ccf823059ce01610d30da4f2ee73 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 5 Oct 2022 23:55:55 +0200
Subject: [PATCH] Outsourced run command for SM rule 'plot_test_results' into
 script.

---
 scripts/plot_test_results.py | 27 +++++++++++++++++++++++++++
 workflows/ANN_training.smk   | 20 ++++++++------------
 2 files changed, 35 insertions(+), 12 deletions(-)
 create mode 100644 scripts/plot_test_results.py

diff --git a/scripts/plot_test_results.py b/scripts/plot_test_results.py
new file mode 100644
index 0000000..53eb6d0
--- /dev/null
+++ b/scripts/plot_test_results.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+"""Script to plot results of ANN model testing.
+
+@author: Laura C. Kühle
+"""
+import sys
+
+from tcd.Plotting import plot_evaluation_results
+
+
+def main() -> None:
+    """Plot results of ANN model tests."""
+    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
+        sys.stdout = logfile
+        sys.stderr = logfile
+
+        # Plot evaluation results
+        plot_evaluation_results(evaluation_file=snakemake.input['json_file'],
+                                directory=snakemake.params['directory'],
+                                colors=snakemake.params['colors'])
+
+
+if __name__ == '__main__':
+    if "snakemake" in locals():
+        main()
+    else:
+        print('Not Defined.')
diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk
index 7cdaa9d..f75998e 100644
--- a/workflows/ANN_training.smk
+++ b/workflows/ANN_training.smk
@@ -1,12 +1,10 @@
-import sys
-
-from scripts.tcd.Plotting import plot_evaluation_results
-
 configfile: 'config.yaml'
 
+
 DIR = config['data_dir']
 MODELS = config['models']
 
+
 rule all:
     input:
         expand(DIR+'/trained models/{model}.model.pt', model=MODELS),
@@ -14,22 +12,20 @@ rule all:
         +'.barplot.pdf'
     default_target: True
 
+
 rule plot_test_results:
     input:
-        json_file=DIR+'/model evaluation/'+ '_'.join(MODELS.keys()) + '.json'
+        json_file = DIR+'/model evaluation/'+'_'.join(MODELS.keys()) + '.json'
     output:
         DIR+'/model evaluation/'+'_'.join(MODELS.keys())
         +'.barplot.pdf'
     params:
-        colors = config['classification_colors']
+        colors = config['classification_colors'],
+        directory = DIR
     log:
         DIR+'/log/plot_test_results.log'
-    run:
-        with open(str(log), 'w') as logfile:
-            sys.stdout = logfile
-            sys.stderr = logfile
-            plot_evaluation_results(evaluation_file=input.json_file,
-                directory=DIR, colors=params.colors)
+    script:
+        '../scripts/plot_test_results.py'
 
 
 rule test_model:
-- 
GitLab