# -*- coding: utf-8 -*-
"""Script to train ANN models.

Expects a global ``snakemake`` object (presumably injected by Snakemake when
this file is run via a rule's ``script:`` directive) that provides ``log``,
``params`` and ``wildcards``.

@author: Laura C. Kühle

"""
import sys

from tcd.ANN_Training import ModelTrainer, read_training_data


def main() -> None:
    """Train ANN models.

    Reads the training data from ``snakemake.params['directory']``. It then
    builds a ``ModelTrainer`` whose config merges
    ``snakemake.params['random_seed']`` with the entry of
    ``snakemake.params['models']`` selected by the ``model`` wildcard. It
    trains that model with ``epoch_training`` and saves it into the same
    directory under the wildcard's name.

    Anything written to stdout during this process goes to the first log file
    of the rule (``snakemake.log[0]``). The previous ``sys.stdout`` is
    restored afterwards, even if training raises. Without this, stdout would
    be left pointing at the already-closed log file, and later prints would
    fail with ``ValueError``.
    """
    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
        original_stdout = sys.stdout
        sys.stdout = logfile
        try:
            # Read training data
            training_data = read_training_data(
                snakemake.params['directory'])

            # Train model; the per-model config may override 'random_seed'
            # because it is unpacked after it.
            trainer = ModelTrainer(config={
                'random_seed': snakemake.params['random_seed'],
                **snakemake.params['models'][snakemake.wildcards['model']]})
            trainer.epoch_training(dataset=training_data)

            # Save trained model
            trainer.save_model(directory=snakemake.params['directory'],
                               model_name=snakemake.wildcards['model'])
        finally:
            # Restore stdout before the log file is closed by ``with``.
            sys.stdout = original_stdout


if __name__ == '__main__':
    # At module level locals() is the module namespace, so this checks
    # whether the ``snakemake`` global has been injected.
    if "snakemake" in locals():
        main()
    else:
        print('Not Defined.')