Skip to content
Snippets Groups Projects
Commit 5effb4e8 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Outsourced run command for SM rule 'train_model' into script.

parent bfcddd48
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""Script to train ANN models.
@author: Laura C. Kühle
"""
import sys
from tcd.ANN_Training import ModelTrainer, read_training_data
def main() -> None:
"""Train ANN models."""
with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
sys.stdout = logfile
# Read training data
training_data = read_training_data(snakemake.params['directory'])
# Train model
trainer = ModelTrainer(config={
'random_seed': snakemake.params['random_seed'],
**snakemake.params['models'][snakemake.wildcards['model']]})
trainer.epoch_training(dataset=training_data)
# Save trained model
trainer.save_model(directory=snakemake.params['directory'],
model_name=snakemake.wildcards['model'])
if __name__ == '__main__':
if "snakemake" in locals():
main()
else:
print('Not Defined.')
......@@ -71,15 +71,10 @@ rule train_model:
protected(DIR+'/trained models/{model}.model.pt'),
protected(DIR+'/trained models/{model}.loss.pt')
params:
random_seed = config['random_seed']
random_seed = config['random_seed'],
directory = DIR,
models = config['models']
log:
DIR+'/log/train_model/{model}.log'
run:
with open(str(log), 'w') as logfile:
sys.stdout = logfile
training_data = read_training_data(DIR)
trainer= ANN_Training.ModelTrainer(
config={'random_seed': params.random_seed,
**MODELS[wildcards.model]})
trainer.epoch_training(dataset=training_data)
trainer.save_model(directory=DIR, model_name=wildcards.model)
script:
'../scripts/train_model.py'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment