Skip to content
Snippets Groups Projects
Commit 43bcfcc3 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Outsourced run command for SM rule 'test_model' into script.

parent 5effb4e8
Branches
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""Script to test ANN models.
@author: Laura C. Kühle
"""
import sys
from tcd.ANN_Training import ModelTrainer, evaluate_models
def main() -> None:
"""Test ANN models."""
models = {}
with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
sys.stdout = logfile
sys.stderr = logfile
# Initialize models to be evaluated
for model in snakemake.params['models']:
trainer = ModelTrainer(
{'model_name': model, 'dir': snakemake.params['directory'],
'model_dir': snakemake.params['directory'],
'random_seed': snakemake.params['random_seed'],
**snakemake.params['models'][model]})
models[model] = trainer
# Evaluate models
evaluate_models(
models=models, directory=snakemake.params['directory'],
num_iterations=snakemake.params['num_iterations'],
compare_normalization=snakemake.params['compare_normalization'])
if __name__ == '__main__':
if "snakemake" in locals():
main()
else:
print('Not Defined.')
import sys
from scripts.tcd import ANN_Training
from scripts.tcd.ANN_Training import *
from scripts.tcd.Plotting import plot_evaluation_results
configfile: 'config.yaml'
......@@ -44,23 +42,14 @@ rule test_model:
params:
num_iterations = config['num_iterations'],
compare_normalization = config['compare_normalization'],
random_seed = config['random_seed']
random_seed = config['random_seed'],
directory = DIR,
models = config['models']
log:
DIR+'/log/test_model.log'
run:
models = {}
with open(str(log), 'w') as logfile:
sys.stdout = logfile
sys.stderr = logfile
for model in MODELS:
trainer= ANN_Training.ModelTrainer(
{'model_name': model, 'dir': DIR, 'model_dir': DIR,
'random_seed': params.random_seed,
**MODELS[model]})
models[model] = trainer
evaluate_models(models=models, directory=DIR,
num_iterations=params.num_iterations,
compare_normalization=params.compare_normalization)
script:
'../scripts/test_model.py'
rule train_model:
input:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment