From bfcddd48ee4e10f2c50e17b4cc7f2a56e256fcf6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 5 Oct 2022 23:11:13 +0200
Subject: [PATCH] Outsourced run command for SM rule 'generate_data' into
 script.

---
 scripts/generate_data.py | 45 ++++++++++++++++++++++++++++++++++++++++
 workflows/ANN_data.smk   | 32 ++++++----------------------
 2 files changed, 51 insertions(+), 26 deletions(-)
 create mode 100644 scripts/generate_data.py

diff --git a/scripts/generate_data.py b/scripts/generate_data.py
new file mode 100644
index 0000000..427ec58
--- /dev/null
+++ b/scripts/generate_data.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+"""Script to generate training data for ANN models.
+
+@author: Laura C. Kühle
+"""
+import sys
+import numpy as np
+
+from tcd import Initial_Condition
+from tcd import ANN_Data_Generator
+
+
+def main() -> None:
+    """Generate training data for ANN models."""
+    # Set random seed if given
+    if snakemake.params['seed'] is not None:
+        np.random.seed(snakemake.params['seed'])
+
+    # Determine list of initial conditions
+    init_cond_list = []
+    for function in snakemake.params['functions']:
+        init_cond_list.append({
+            'function': getattr(Initial_Condition, function)({}),
+            'config': {} if snakemake.config['functions'][function] is None
+            else snakemake.config['functions'][function]
+        })
+
+    # Generate training data
+    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
+        sys.stdout = logfile
+        generator = ANN_Data_Generator.TrainingDataGenerator()
+        _ = generator.build_training_data(
+            balance=snakemake.params['balance'],
+            init_cond_list=init_cond_list,
+            directory=snakemake.params['directory'],
+            num_samples=snakemake.params['sample_number'],
+            add_reconstructions=snakemake.params['reconstruction_flag'],
+            stencil_len=snakemake.params['stencil_len'])
+
+
+if __name__ == '__main__':
+    if "snakemake" in locals():
+        main()
+    else:
+        print('Not Defined.')
diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk
index 9d38b01..d11164d 100644
--- a/workflows/ANN_data.smk
+++ b/workflows/ANN_data.smk
@@ -1,15 +1,8 @@
-import sys
-import time
-import numpy as np
-
-from scripts.tcd import Initial_Condition
-from scripts.tcd import ANN_Data_Generator
-
 configfile: 'config.yaml'
 
+
 DIR = config['data_dir']
-if config['random_seed'] is not None:
-    np.random.seed(config['random_seed'])
+
 
 rule generate_data:
     output:
@@ -22,23 +15,10 @@ rule generate_data:
         stencil_len = config['stencil_length'],
         sample_number = config['sample_number'],
         reconstruction_flag = config['add_reconstructions'],
+        directory = DIR,
+        seed = config['random_seed'],
         functions = expand('{FUNCTION}', FUNCTION=config['functions'])
     log:
         DIR+'/log/generate_data.log'
-    run:
-        init_cond_list = []
-        for function in params.functions:
-            init_cond_list.append({
-                'function': getattr(Initial_Condition, function)({}),
-                'config': {} if config['functions'][function] is None
-                else config['functions'][function]
-            })
-
-        with open(str(log), 'w') as logfile:
-            sys.stdout = logfile
-            generator = ANN_Data_Generator.TrainingDataGenerator()
-            data = generator.build_training_data(balance=params.balance,
-                init_cond_list=init_cond_list, directory=DIR,
-                num_samples=params.sample_number,
-                add_reconstructions=params.reconstruction_flag,
-                stencil_len=params.stencil_len)
\ No newline at end of file
+    script:
+        '../scripts/generate_data.py'
-- 
GitLab