Skip to content
Snippets Groups Projects
Commit 05447a99 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Cleaned up Snakefile.

parent 108a3943
No related branches found
No related tags found
No related merge requests found
configfile: 'config.yaml'
import ANN_Data_Generator, Initial_Condition, ANN_Training import ANN_Data_Generator, Initial_Condition, ANN_Training
from ANN_Training import evaluate_models from ANN_Training import evaluate_models
import numpy as np import numpy as np
def replace_none(list): configfile: 'config.yaml'
return {} if list is None else list
DIR = config['data_directory'] DIR = config['data_directory']
MODELS = config['models'] MODELS = config['models']
if config['random_seed'] is not None: if config['random_seed'] is not None:
np.random.seed(config['random_seed']) np.random.seed(config['random_seed'])
...@@ -34,11 +31,13 @@ rule test_model: ...@@ -34,11 +31,13 @@ rule test_model:
trainer= ANN_Training.ModelTrainer({'model_name': model, 'dir': DIR, trainer= ANN_Training.ModelTrainer({'model_name': model, 'dir': DIR,
'model_dir': DIR, **MODELS[model]}) 'model_dir': DIR, **MODELS[model]})
models[model] = trainer models[model] = trainer
evaluate_models(models, DIR, 2, params.colors, params.compare_normalization) evaluate_models(models=models, directory=DIR, num_iterations=2, colors=params.colors,
compare_normalization=params.compare_normalization)
rule generate_data: rule generate_data:
output: output:
DIR+'/input_data.npy', DIR+'/input_data.npy',
DIR+'/normalized_input_data.npy',
DIR+'/output_data.npy' DIR+'/output_data.npy'
params: params:
left_bound = config['left_boundary'], left_bound = config['left_boundary'],
...@@ -55,28 +54,26 @@ rule generate_data: ...@@ -55,28 +54,26 @@ rule generate_data:
initial_conditions.append({ initial_conditions.append({
'function': getattr(Initial_Condition, function)( 'function': getattr(Initial_Condition, function)(
params.left_bound, params.right_bound, {}), params.left_bound, params.right_bound, {}),
'config': replace_none(config['functions'][function])}) 'config': {} if function in config['functions'] else config['functions'][function]
})
generator = ANN_Data_Generator.TrainingDataGenerator(initial_conditions, generator = ANN_Data_Generator.TrainingDataGenerator(initial_conditions=initial_conditions,
left_bound=params.left_bound, right_bound=params.right_bound, balance=params.balance, left_bound=params.left_bound, right_bound=params.right_bound, balance=params.balance,
stencil_length=params.stencil_length, directory=DIR) stencil_length=params.stencil_length, directory=DIR)
data = generator.build_training_data(params.sample_number) data = generator.build_training_data(num_samples=params.sample_number)
# print(data[0])
rule train_model: rule train_model:
input: input:
DIR+'/input_data.npy', DIR+'/input_data.npy',
DIR+'/normalized_input_data.npy',
DIR+'/output_data.npy' DIR+'/output_data.npy'
# params:
# # parameter = MODELS[{model}],
# lambda wildcards: [MODELS[wildcards.model]]
log: log:
DIR+'/log/train_model_{model}.log' DIR+'/log/train_model_{model}.log'
output: output:
DIR+'/trained models/model__{model}.pt', DIR+'/trained models/model__{model}.pt',
DIR+'/trained models/loss__{model}.pt' DIR+'/trained models/loss__{model}.pt'
run: run:
trainer= ANN_Training.ModelTrainer({'model_name': wildcards.model, 'dir': DIR, trainer= ANN_Training.ModelTrainer(config={'model_name': wildcards.model, 'dir': DIR,
'model_dir': DIR, **MODELS[wildcards.model]}) 'model_dir': DIR, **MODELS[wildcards.model]})
trainer.epoch_training() trainer.epoch_training()
trainer.save_model() trainer.save_model()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment