diff --git a/scripts/train_model.py b/scripts/train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..492d0f72ea8ee187bb6f2f155781d028e52772c5 --- /dev/null +++ b/scripts/train_model.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +"""Script to train ANN models. + +@author: Laura C. Kühle +""" +import sys + +from tcd.ANN_Training import ModelTrainer, read_training_data + + +def main() -> None: + """Train ANN models.""" + with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile: + sys.stdout = logfile + + # Read training data + training_data = read_training_data(snakemake.params['directory']) + + # Train model + trainer = ModelTrainer(config={ + 'random_seed': snakemake.params['random_seed'], + **snakemake.params['models'][snakemake.wildcards['model']]}) + trainer.epoch_training(dataset=training_data) + + # Save trained model + trainer.save_model(directory=snakemake.params['directory'], + model_name=snakemake.wildcards['model']) + + +if __name__ == '__main__': + if "snakemake" in locals(): + main() + else: + print('Not Defined.') diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index 47c3dc217ddf692de260a6895744e17df9400b8e..fd03f39fb10ddbe6d638ebcff67100958220927e 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -71,15 +71,10 @@ rule train_model: protected(DIR+'/trained models/{model}.model.pt'), protected(DIR+'/trained models/{model}.loss.pt') params: - random_seed = config['random_seed'] + random_seed = config['random_seed'], + directory = DIR, + models = config['models'] log: DIR+'/log/train_model/{model}.log' - run: - with open(str(log), 'w') as logfile: - sys.stdout = logfile - training_data = read_training_data(DIR) - trainer= ANN_Training.ModelTrainer( - config={'random_seed': params.random_seed, - **MODELS[wildcards.model]}) - trainer.epoch_training(dataset=training_data) - trainer.save_model(directory=DIR, model_name=wildcards.model) + script: + '../scripts/train_model.py'