diff --git a/scripts/test_model.py b/scripts/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc4920b9a87777f55cb929e25c6320d3c2650052
--- /dev/null
+++ b/scripts/test_model.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+"""Script to test ANN models.
+
+@author: Laura C. Kühle
+"""
+import sys
+
+from tcd.ANN_Training import ModelTrainer, evaluate_models
+
+
+def main() -> None:
+    """Test ANN models."""
+    models = {}
+    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
+        sys.stdout = logfile
+        sys.stderr = logfile
+
+        # Initialize models to be evaluated
+        for model in snakemake.params['models']:
+            trainer = ModelTrainer(
+                {'model_name': model, 'dir': snakemake.params['directory'],
+                 'model_dir': snakemake.params['directory'],
+                 'random_seed': snakemake.params['random_seed'],
+                 **snakemake.params['models'][model]})
+            models[model] = trainer
+
+        # Evaluate models
+        evaluate_models(
+            models=models, directory=snakemake.params['directory'],
+            num_iterations=snakemake.params['num_iterations'],
+            compare_normalization=snakemake.params['compare_normalization'])
+
+
+if __name__ == '__main__':
+    if "snakemake" in locals():
+        main()
+    else:
+        print('Not Defined.')
diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk
index fd03f39fb10ddbe6d638ebcff67100958220927e..7cdaa9da51c053021e45bf9e9d5799187ca6c762 100644
--- a/workflows/ANN_training.smk
+++ b/workflows/ANN_training.smk
@@ -1,7 +1,5 @@
 import sys
 
-from scripts.tcd import ANN_Training
-from scripts.tcd.ANN_Training import *
 from scripts.tcd.Plotting import plot_evaluation_results
 
 configfile: 'config.yaml'
@@ -44,23 +42,14 @@ rule test_model:
     params:
         num_iterations = config['num_iterations'],
        compare_normalization = config['compare_normalization'],
-        random_seed = config['random_seed']
+        random_seed = config['random_seed'],
+        directory = DIR,
+        models = config['models']
     log:
         DIR+'/log/test_model.log'
-    run:
-        models = {}
-        with open(str(log), 'w') as logfile:
-            sys.stdout = logfile
-            sys.stderr = logfile
-            for model in MODELS:
-                trainer= ANN_Training.ModelTrainer(
-                    {'model_name': model, 'dir': DIR, 'model_dir': DIR,
-                     'random_seed': params.random_seed,
-                     **MODELS[model]})
-                models[model] = trainer
-            evaluate_models(models=models, directory=DIR,
-                num_iterations=params.num_iterations,
-                compare_normalization=params.compare_normalization)
+    script:
+        '../scripts/test_model.py'
+
 
 rule train_model:
     input: