diff --git a/Snakefile b/Snakefile index 75ddd1e4a4809b1790f7732c90597a2081e68629..9c31fbd19a5db9ff51bda0dd10c9904b151334f0 100644 --- a/Snakefile +++ b/Snakefile @@ -6,21 +6,24 @@ DIR = 'workflows' if config['random_seed'] is not None: np.random.seed(config['random_seed']) +module_config = {'data_dir': config['data_dir'], + 'random_seed': config['random_seed']} + module ann_data: snakefile: DIR + '/ANN_data.smk' - config: {**config['ANN_Data'], 'data_dir': config['data_dir']} + config: {**config['ANN_Data'], **module_config} use rule * from ann_data as ANN_* module ann_training: snakefile: DIR + '/ANN_training.smk' - config: {**config['ANN_Training'], 'data_dir':config['data_dir']} + config: {**config['ANN_Training'], **module_config} use rule * from ann_training as ANN_* module approximation: snakefile: DIR + '/approximation.smk' - config: {**config['Approximation'], 'data_dir': config['data_dir']} + config: {**config['Approximation'], **module_config} use rule * from approximation as DG_* diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk index ef0bc3be22b14a50ee998e0c29b2bb4f317c3691..9e29805a8ad5b12cf1106278120043e074759797 100644 --- a/workflows/ANN_data.smk +++ b/workflows/ANN_data.smk @@ -1,13 +1,14 @@ import sys import time - -configfile: 'config.yaml' +import numpy as np import ANN_Data_Generator, Initial_Condition +configfile: 'config.yaml' + DIR = config['data_dir'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule generate_data: output: diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index 5842c6afbbb8c0a00e049fddd4b2752d1d7b490f..af8085efed9435e0011c0b1a8017a7430f7e8f8f 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -1,4 +1,5 @@ import sys +import numpy as np import ANN_Training from ANN_Training import * @@ -7,8 +8,8 @@ configfile: 'config.yaml' DIR = config['data_dir'] MODELS = config['models'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule all: input: diff --git a/workflows/approximation.smk b/workflows/approximation.smk index dd47186f1121bd2f5b8fc17c89adbcce3a42218a..293d0f48752d54ae0fa280f6702efe3202acd83c 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -1,5 +1,6 @@ import sys import time +import numpy as np from DG_Approximation import DGScheme @@ -8,8 +9,8 @@ configfile: 'config.yaml' PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube'] DIR = config['data_dir'] SCHEMES = config['schemes'] -# if config['random_seed'] is not None: -# np.random.seed(config['random_seed']) +if config['random_seed'] is not None: + np.random.seed(config['random_seed']) rule all: input: