diff --git a/ANN_Training.py b/ANN_Training.py index 41c71d66800872d5c98b2aa701461401bc9beade..aa51a908037464da87deeab070fa5402298ac2a6 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -71,6 +71,13 @@ class ModelTrainer: self._num_epochs = config.pop('num_epochs', 1000) self._threshold = config.pop('threshold', 1e-5) + # Set random seed + seed = config.pop('random_seed', None) + if seed is not None: + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(True) + model = config.pop('model', 'ThreeLayerReLu') model_config = config.pop('model_config', {}) loss_function = config.pop('loss_function', 'BCELoss') diff --git a/Snakefile b/Snakefile index 9c31fbd19a5db9ff51bda0dd10c9904b151334f0..4b60aec8440658bfac51fcb105a85ed7e913d3b0 100644 --- a/Snakefile +++ b/Snakefile @@ -1,10 +1,6 @@ -import numpy as np - configfile: 'config.yaml' DIR = 'workflows' -if config['random_seed'] is not None: - np.random.seed(config['random_seed']) module_config = {'data_dir': config['data_dir'], 'random_seed': config['random_seed']} diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index f4ccfc40753da3178ffecfdf141e03dd11c30384..25cb9dd999a362f4423441b2e398b6dde982fd56 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -1,5 +1,4 @@ import sys -import numpy as np import ANN_Training from ANN_Training import * @@ -9,8 +8,6 @@ configfile: 'config.yaml' DIR = config['data_dir'] MODELS = config['models'] -if config['random_seed'] is not None: - np.random.seed(config['random_seed']) rule all: input: @@ -30,15 +27,9 @@ rule plot_test_results: log: DIR+'/log/plot_test_results.log' run: - models = {} with open(str(log), 'w') as logfile: sys.stdout = logfile sys.stderr = logfile - for model in MODELS: - trainer= ANN_Training.ModelTrainer( - {'model_name': model, 'dir': DIR, 'model_dir': DIR, - **MODELS[model]}) - models[model] = trainer plot_evaluation_results(evaluation_file=input.json_file, directory=DIR, colors=params.colors) @@ -52,7 +43,8 @@ rule test_model: protected(DIR+'/model evaluation/'+'_'.join(MODELS.keys())+'.json') params: num_iterations = config['num_iterations'], - compare_normalization = config['compare_normalization'] + compare_normalization = config['compare_normalization'], + random_seed = config['random_seed'] log: DIR+'/log/test_model.log' run: @@ -63,6 +55,7 @@ rule test_model: for model in MODELS: trainer= ANN_Training.ModelTrainer( {'model_name': model, 'dir': DIR, 'model_dir': DIR, + 'random_seed': params.random_seed, **MODELS[model]}) models[model] = trainer evaluate_models(models=models, directory=DIR, @@ -77,6 +70,8 @@ rule train_model: output: protected(DIR+'/trained models/{model}.model.pt'), protected(DIR+'/trained models/{model}.loss.pt') + params: + random_seed = config['random_seed'] log: DIR+'/log/train_model_{model}.log' run: @@ -84,6 +79,7 @@ rule train_model: sys.stdout = logfile training_data = read_training_data(DIR) trainer= ANN_Training.ModelTrainer( - config={**MODELS[wildcards.model]}) + config={'random_seed': params.random_seed, + **MODELS[wildcards.model]}) trainer.epoch_training(dataset=training_data) trainer.save_model(directory=DIR, model_name=wildcards.model) diff --git a/workflows/approximation.smk b/workflows/approximation.smk index cee3af3b90d115fd67a79acc1872a600d77e6bd0..46713ce1af0da5eae664a2f9991d4020b3407b60 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -1,20 +1,16 @@ import sys import time -import numpy as np import Initial_Condition import Quadrature from DG_Approximation import DGScheme from Plotting import plot_approximation_results -from projection_utils import Mesh configfile: 'config.yaml' PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube'] DIR = config['data_dir'] SCHEMES = config['schemes'] -if config['random_seed'] is not None: - np.random.seed(config['random_seed']) rule all: input: