Skip to content
Snippets Groups Projects
Commit 1aa08945 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Separated model training in Snakefile by using wildcards.

parent 36a09681
Branches
Tags 1.0.5 v1.0.5
No related merge requests found
......@@ -66,17 +66,16 @@ rule train_model:
input:
DIR+'/input_data.npy',
DIR+'/output_data.npy'
params:
models = MODELS
# params:
# # parameter = MODELS[{model}],
# lambda wildcards: [MODELS[wildcards.model]]
log:
DIR+'/log/train_model.log'
DIR+'/log/train_model_{model}.log'
output:
expand(DIR+'/trained models/model__{model}.pt', model=MODELS),
expand(DIR+'/trained models/loss__{model}.pt', model=MODELS)
DIR+'/trained models/model__{model}.pt',
DIR+'/trained models/loss__{model}.pt'
run:
for model in params.models:
print(model)
trainer= ANN_Training.ModelTrainer({'model_name': model, 'dir': DIR,
'model_dir': DIR, **params.models[model]})
trainer= ANN_Training.ModelTrainer({'model_name': wildcards.model, 'dir': DIR,
'model_dir': DIR, **MODELS[wildcards.model]})
trainer.epoch_training()
trainer.save_model()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment