Commit 97f6c885 authored by Laura Christine Kühle

Ensured repeatable output through seeds and deterministic models.

parent 1449ed05
@@ -71,6 +71,13 @@ class ModelTrainer:
        self._num_epochs = config.pop('num_epochs', 1000)
        self._threshold = config.pop('threshold', 1e-5)

        # Set random seed
        seed = config.pop('random_seed', None)
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.use_deterministic_algorithms(True)

        model = config.pop('model', 'ThreeLayerReLu')
        model_config = config.pop('model_config', {})
        loss_function = config.pop('loss_function', 'BCELoss')
......
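For context, a minimal standalone sketch of the seeding pattern added in the hunk above (not part of the repository; it assumes only numpy and torch are installed). Seeding both RNGs and forcing deterministic torch kernels is what makes repeated training runs with identical configurations reproducible. Note that torch.use_deterministic_algorithms(True) raises an error when an operation has no deterministic implementation, and on CUDA it additionally requires the CUBLAS_WORKSPACE_CONFIG environment variable to be set.

# Standalone sketch of the reproducibility pattern used in ModelTrainer above.
import numpy as np
import torch

def make_reproducible(seed):
    """Seed numpy and torch and force deterministic torch kernels."""
    np.random.seed(seed)     # numpy RNG (e.g. data shuffling, noise)
    torch.manual_seed(seed)  # torch CPU and CUDA RNGs
    # Errors out on ops without a deterministic implementation; on CUDA this
    # also needs CUBLAS_WORKSPACE_CONFIG (e.g. ':4096:8') in the environment.
    torch.use_deterministic_algorithms(True)

make_reproducible(42)
a = torch.nn.Linear(4, 2)
make_reproducible(42)
b = torch.nn.Linear(4, 2)
# Identical seeds give identical initial weights.
assert all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))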
import numpy as np
configfile: 'config.yaml'
DIR = 'workflows'
if config['random_seed'] is not None:
    np.random.seed(config['random_seed'])

module_config = {'data_dir': config['data_dir'],
                 'random_seed': config['random_seed']}
......
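The seed itself comes from config.yaml, which this commit does not show. Below is a hedged sketch of how the parsed Snakemake config dict might look; only the keys are taken from the hunks in this commit, every value is a placeholder.

# Hypothetical parsed contents of config.yaml as seen by the Snakefiles; the
# keys mirror the hunks in this commit, all values are placeholders.
config = {
    'data_dir': 'workflows/data',    # placeholder path
    'random_seed': 42,               # None skips seeding entirely
    'num_iterations': 100,           # placeholder
    'compare_normalization': True,   # placeholder
    'models': {},                    # per-model ModelTrainer configs
    'schemes': {},                   # DG scheme configs for the approximation workflow
}

With 'random_seed': None, every `if config['random_seed'] is not None` guard above is skipped, so the previous non-deterministic behaviour remains available.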
import sys
import numpy as np
import ANN_Training
from ANN_Training import *
@@ -9,8 +8,6 @@ configfile: 'config.yaml'
DIR = config['data_dir']
MODELS = config['models']
if config['random_seed'] is not None:
    np.random.seed(config['random_seed'])

rule all:
    input:
@@ -30,15 +27,9 @@ rule plot_test_results:
    log:
        DIR+'/log/plot_test_results.log'
    run:
        models = {}
        with open(str(log), 'w') as logfile:
            sys.stdout = logfile
            sys.stderr = logfile
            for model in MODELS:
                trainer = ANN_Training.ModelTrainer(
                    {'model_name': model, 'dir': DIR, 'model_dir': DIR,
                     **MODELS[model]})
                models[model] = trainer
            plot_evaluation_results(evaluation_file=input.json_file,
                                    directory=DIR, colors=params.colors)
@@ -52,7 +43,8 @@ rule test_model:
        protected(DIR+'/model evaluation/'+'_'.join(MODELS.keys())+'.json')
    params:
        num_iterations = config['num_iterations'],
        compare_normalization = config['compare_normalization']
        compare_normalization = config['compare_normalization'],
        random_seed = config['random_seed']
    log:
        DIR+'/log/test_model.log'
    run:
@@ -63,6 +55,7 @@ rule test_model:
            for model in MODELS:
                trainer = ANN_Training.ModelTrainer(
                    {'model_name': model, 'dir': DIR, 'model_dir': DIR,
                     'random_seed': params.random_seed,
                     **MODELS[model]})
                models[model] = trainer
            evaluate_models(models=models, directory=DIR,
@@ -77,6 +70,8 @@ rule train_model:
    output:
        protected(DIR+'/trained models/{model}.model.pt'),
        protected(DIR+'/trained models/{model}.loss.pt')
    params:
        random_seed = config['random_seed']
    log:
        DIR+'/log/train_model_{model}.log'
    run:
@@ -84,6 +79,7 @@ rule train_model:
            sys.stdout = logfile
            training_data = read_training_data(DIR)
            trainer = ANN_Training.ModelTrainer(
                config={**MODELS[wildcards.model]})
                config={'random_seed': params.random_seed,
                        **MODELS[wildcards.model]})
            trainer.epoch_training(dataset=training_data)
            trainer.save_model(directory=DIR, model_name=wildcards.model)
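Read together, the train_model hunks show the full path of the seed: config['random_seed'] is exposed to the rule as params.random_seed and merged into the dict handed to ANN_Training.ModelTrainer, whose __init__ (first hunk) pops it and seeds numpy and torch. A small hedged sketch of that merge follows; the 'Adam' entry and its values are placeholders, and params_random_seed stands in for params.random_seed.

# Sketch of the config merge performed in the run block above.
MODELS = {'Adam': {'model': 'ThreeLayerReLu', 'num_epochs': 500}}  # placeholder entry
params_random_seed = 42

trainer_config = {'random_seed': params_random_seed,
                  **MODELS['Adam']}

# Inside ModelTrainer.__init__ (first hunk) the seed is consumed with a
# default, so leaving the key out simply skips seeding:
seed = trainer_config.pop('random_seed', None)
assert seed == 42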
import sys
import time
import numpy as np
import Initial_Condition
import Quadrature
from DG_Approximation import DGScheme
from Plotting import plot_approximation_results
from projection_utils import Mesh
configfile: 'config.yaml'
PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube']
DIR = config['data_dir']
SCHEMES = config['schemes']
if config['random_seed'] is not None:
    np.random.seed(config['random_seed'])

rule all:
    input:
......
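With the seed fixed, repeatability can be checked by running the workflow twice into different directories and comparing the protected model files written by the train_model rule. A hedged sketch is below, assuming the .model.pt files hold state dicts; 'run1'/'run2' and the model name 'Adam' are placeholders.

# Sketch: verify that two runs with the same random_seed produce identical
# trained models.
import torch

state_a = torch.load('run1/trained models/Adam.model.pt')
state_b = torch.load('run2/trained models/Adam.model.pt')

assert state_a.keys() == state_b.keys()
assert all(torch.equal(state_a[k], state_b[k]) for k in state_a)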