diff --git a/scripts/plot_test_results.py b/scripts/plot_test_results.py new file mode 100644 index 0000000000000000000000000000000000000000..53eb6d0e78b039a7e1ecd513f0c14a236bec2158 --- /dev/null +++ b/scripts/plot_test_results.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Script to plot results of ANN model testing. + +@author: Laura C. Kühle +""" +import sys + +from tcd.Plotting import plot_evaluation_results + + +def main() -> None: + """Plot results of ANN model tests.""" + with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile: + sys.stdout = logfile + sys.stderr = logfile + + # Plot evaluation results + plot_evaluation_results(evaluation_file=snakemake.input['json_file'], + directory=snakemake.params['directory'], + colors=snakemake.params['colors']) + + +if __name__ == '__main__': + if "snakemake" in locals(): + main() + else: + print('Not Defined.') diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index 7cdaa9da51c053021e45bf9e9d5799187ca6c762..f75998ec7431843eced7eeeea772ce4cf59f7413 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -1,12 +1,10 @@ -import sys - -from scripts.tcd.Plotting import plot_evaluation_results - configfile: 'config.yaml' + DIR = config['data_dir'] MODELS = config['models'] + rule all: input: expand(DIR+'/trained models/{model}.model.pt', model=MODELS), @@ -14,22 +12,20 @@ rule all: +'.barplot.pdf' default_target: True + rule plot_test_results: input: - json_file=DIR+'/model evaluation/'+ '_'.join(MODELS.keys()) + '.json' + json_file = DIR+'/model evaluation/'+'_'.join(MODELS.keys()) + '.json' output: DIR+'/model evaluation/'+'_'.join(MODELS.keys()) +'.barplot.pdf' params: - colors = config['classification_colors'] + colors = config['classification_colors'], + directory = DIR log: DIR+'/log/plot_test_results.log' - run: - with open(str(log), 'w') as logfile: - sys.stdout = logfile - sys.stderr = logfile - plot_evaluation_results(evaluation_file=input.json_file, - directory=DIR, colors=params.colors) + script: + '../scripts/plot_test_results.py' rule test_model: