From bb74ef62a3b6ae5800782428674f26292133db7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Tue, 1 Mar 2022 17:54:21 +0100 Subject: [PATCH] Fixed issue that random seed was not set in each workflow module. --- Snakefile | 9 ++++++--- workflows/ANN_data.smk | 9 +++++---- workflows/ANN_training.smk | 5 +++-- workflows/approximation.smk | 5 +++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/Snakefile b/Snakefile index 75ddd1e..9c31fbd 100644 --- a/Snakefile +++ b/Snakefile @@ -6,21 +6,24 @@ DIR = 'workflows' if config['random_seed'] is not None: np.random.seed(config['random_seed']) +module_config = {'data_dir': config['data_dir'], + 'random_seed': config['random_seed']} + module ann_data: snakefile: DIR + '/ANN_data.smk' - config: {**config['ANN_Data'], 'data_dir': config['data_dir']} + config: {**config['ANN_Data'], **module_config} use rule * from ann_data as ANN_* module ann_training: snakefile: DIR + '/ANN_training.smk' - config: {**config['ANN_Training'], 'data_dir':config['data_dir']} + config: {**config['ANN_Training'], **module_config} use rule * from ann_training as ANN_* module approximation: snakefile: DIR + '/approximation.smk' - config: {**config['Approximation'], 'data_dir': config['data_dir']} + config: {**config['Approximation'], **module_config} use rule * from approximation as DG_* diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk index ef0bc3b..9e29805 100644 --- a/workflows/ANN_data.smk +++ b/workflows/ANN_data.smk @@ -1,13 +1,14 @@ import sys import time - -configfile: 'config.yaml' +import numpy as np import ANN_Data_Generator, Initial_Condition +configfile: 'config.yaml' + DIR = config['data_dir'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule generate_data: output: diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index 5842c6a..af8085e 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -1,4 +1,5 @@ import sys +import numpy as np import ANN_Training from ANN_Training import * @@ -7,8 +8,8 @@ configfile: 'config.yaml' DIR = config['data_dir'] MODELS = config['models'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule all: input: diff --git a/workflows/approximation.smk b/workflows/approximation.smk index dd47186..293d0f4 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -1,5 +1,6 @@ import sys import time +import numpy as np from DG_Approximation import DGScheme @@ -8,8 +9,8 @@ configfile: 'config.yaml' PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube'] DIR = config['data_dir'] SCHEMES = config['schemes'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule all: input: -- GitLab