# -*- coding: utf-8 -*-
"""Script to train ANN models.

@author: Laura C. Kühle
"""
import contextlib
import sys

from tcd.ANN_Training import ModelTrainer, read_training_data


def main() -> None:
    """Train one ANN model and save it.

    Opens ``snakemake.log[0]`` for writing and sends everything printed to
    stdout during training into it. Then:

    1. reads the training data from ``snakemake.params['directory']``;
    2. builds a ``ModelTrainer`` whose config is the random seed from
       ``snakemake.params['random_seed']`` merged with the model config
       selected by the ``model`` wildcard (keys from the model config win
       on collision, since they are unpacked last);
    3. trains it with ``epoch_training``;
    4. saves it to the same directory under the wildcard's model name.
    """
    # redirect_stdout restores the previous sys.stdout when the block exits
    # (also on exceptions). Plain reassignment of sys.stdout would leave it
    # pointing at the log file after that file has been closed.
    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile, \
            contextlib.redirect_stdout(logfile):
        # Read training data
        training_data = read_training_data(snakemake.params['directory'])

        # Train model
        trainer = ModelTrainer(config={
            'random_seed': snakemake.params['random_seed'],
            **snakemake.params['models'][snakemake.wildcards['model']]})
        trainer.epoch_training(dataset=training_data)

        # Save trained model
        trainer.save_model(directory=snakemake.params['directory'],
                           model_name=snakemake.wildcards['model'])


if __name__ == '__main__':
    # The `snakemake` object only exists when Snakemake executes this file
    # (presumably through a rule's `script:` directive). At module scope,
    # globals() is the same namespace that locals() returns.
    if 'snakemake' not in globals():
        print('Not Defined.')
    else:
        main()