From bb74ef62a3b6ae5800782428674f26292133db7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 1 Mar 2022 17:54:21 +0100
Subject: [PATCH] Fixed issue that random seed was not set in each workflow
 module.

---
 Snakefile                   | 9 ++++++---
 workflows/ANN_data.smk      | 9 +++++----
 workflows/ANN_training.smk  | 5 +++--
 workflows/approximation.smk | 5 +++--
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/Snakefile b/Snakefile
index 75ddd1e..9c31fbd 100644
--- a/Snakefile
+++ b/Snakefile
@@ -6,21 +6,24 @@ DIR = 'workflows'
 if config['random_seed'] is not None:
     np.random.seed(config['random_seed'])
 
+module_config = {'data_dir': config['data_dir'],
+                 'random_seed': config['random_seed']}
+
 module ann_data:
     snakefile: DIR + '/ANN_data.smk'
-    config: {**config['ANN_Data'], 'data_dir': config['data_dir']}
+    config: {**config['ANN_Data'], **module_config}
 
 use rule * from ann_data as ANN_*
 
 module ann_training:
     snakefile: DIR + '/ANN_training.smk'
-    config: {**config['ANN_Training'], 'data_dir':config['data_dir']}
+    config: {**config['ANN_Training'], **module_config}
 
 use rule * from ann_training as ANN_*
 
 module approximation:
     snakefile: DIR + '/approximation.smk'
-    config: {**config['Approximation'], 'data_dir': config['data_dir']}
+    config: {**config['Approximation'], **module_config}
 
 use rule * from approximation as DG_*
 
diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk
index ef0bc3b..9e29805 100644
--- a/workflows/ANN_data.smk
+++ b/workflows/ANN_data.smk
@@ -1,13 +1,14 @@
 import sys
 import time
-
-configfile: 'config.yaml'
+import numpy as np
 
 import ANN_Data_Generator, Initial_Condition
 
+configfile: 'config.yaml'
+
 DIR = config['data_dir']
-# if config['random_seed'] is not None:
-#     np.random.seed(config['random_seed'])
+if config['random_seed'] is not None:
+    np.random.seed(config['random_seed'])
 
 rule generate_data:
     output:
diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk
index 5842c6a..af8085e 100644
--- a/workflows/ANN_training.smk
+++ b/workflows/ANN_training.smk
@@ -1,4 +1,5 @@
 import sys
+import numpy as np
 
 import ANN_Training
 from ANN_Training import *
@@ -7,8 +8,8 @@ configfile: 'config.yaml'
 
 DIR = config['data_dir']
 MODELS = config['models']
-# if config['random_seed'] is not None:
-#     np.random.seed(config['random_seed'])
+if config['random_seed'] is not None:
+    np.random.seed(config['random_seed'])
 
 rule all:
     input:
diff --git a/workflows/approximation.smk b/workflows/approximation.smk
index dd47186..293d0f4 100644
--- a/workflows/approximation.smk
+++ b/workflows/approximation.smk
@@ -1,5 +1,6 @@
 import sys
 import time
+import numpy as np
 
 from DG_Approximation import DGScheme
 
@@ -8,8 +9,8 @@ configfile: 'config.yaml'
 PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube']
 DIR = config['data_dir']
 SCHEMES = config['schemes']
-# if config['random_seed'] is not None:
-#     np.random.seed(config['random_seed'])
+if config['random_seed'] is not None:
+    np.random.seed(config['random_seed'])
 
 rule all:
     input:
-- 
GitLab