import sys

import ANN_Training
from ANN_Training import *
from Plotting import plot_evaluation_results

# Workflow settings; every `config[...]` lookup in this file reads from here.
configfile: 'config.yaml'

# Root directory for the input .npy data, trained models, model evaluation
# output and per-rule log files.
DIR = config['data_dir']
# Model name -> settings dict. The keys name output files and fill the
# {model} wildcard; the values are unpacked into the config dict given to
# ANN_Training.ModelTrainer.
MODELS = config['models']

rule all:
    # Final targets: one trained model file per configured model, plus the
    # evaluation bar plot that compares all configured models together.
    input:
        expand('{}/trained models/{{model}}.model.pt'.format(DIR),
               model=MODELS),
        '{}/model evaluation/{}.barplot.pdf'.format(DIR, '_'.join(MODELS))
    default_target: True

rule plot_test_results:
    # Draw the evaluation bar plot from the combined JSON results that
    # rule test_model produces for all configured models.
    input:
        json_file=DIR+'/model evaluation/'+ '_'.join(MODELS.keys()) + '.json'
    output:
        DIR+'/model evaluation/'+'_'.join(MODELS.keys())
        +'.barplot.pdf'
    params:
        colors = config['classification_colors']
    log:
        DIR+'/log/plot_test_results.log'
    run:
        # `run:` code executes inside the Snakemake process, so stdout/stderr
        # must only be redirected for the duration of this block. Rebinding
        # sys.stdout/sys.stderr permanently would leave them pointing at the
        # closed log file afterwards, breaking any later output.
        from contextlib import redirect_stderr, redirect_stdout
        with open(log[0], 'w') as logfile:
            with redirect_stdout(logfile), redirect_stderr(logfile):
                # NOTE(review): the output path is not passed in; this
                # presumably relies on plot_evaluation_results writing the
                # .barplot.pdf under DIR with the same name - confirm.
                plot_evaluation_results(evaluation_file=input.json_file,
                    directory=DIR, colors=params.colors)


rule test_model:
    # Build one ModelTrainer per configured model and evaluate them together,
    # writing a single JSON file named after all model names.
    # NOTE(review): the trained .model.pt files are not declared as inputs.
    # If evaluate_models loads them (rather than training internally), they
    # should be listed here so Snakemake orders the rules correctly - confirm.
    input:
        DIR+'/input_data.raw.npy',
        DIR+'/input_data.normalized.npy',
        DIR+'/output_data.npy'
    output:
        protected(DIR+'/model evaluation/'+'_'.join(MODELS.keys())+'.json')
    params:
        num_iterations = config['num_iterations'],
        compare_normalization = config['compare_normalization'],
        random_seed = config['random_seed']
    log:
        DIR+'/log/test_model.log'
    run:
        # Redirect only for the duration of this block: `run:` code executes
        # inside the Snakemake process, and a permanent sys.stdout/sys.stderr
        # rebinding would outlive (and point at) the closed log file.
        from contextlib import redirect_stderr, redirect_stdout
        with open(log[0], 'w') as logfile:
            with redirect_stdout(logfile), redirect_stderr(logfile):
                # One trainer per model. The per-model settings from the
                # config are unpacked last, so they can override the shared
                # keys set here.
                models = {
                    model: ANN_Training.ModelTrainer(
                        {'model_name': model, 'dir': DIR, 'model_dir': DIR,
                         'random_seed': params.random_seed,
                         **MODELS[model]})
                    for model in MODELS
                }
                evaluate_models(models=models, directory=DIR,
                    num_iterations=params.num_iterations,
                    compare_normalization=params.compare_normalization)

rule train_model:
    # Train the model named by the {model} wildcard on the data in DIR and
    # save it under DIR/trained models/.
    input:
        DIR+'/input_data.raw.npy',
        DIR+'/input_data.normalized.npy',
        DIR+'/output_data.npy'
    output:
        protected(DIR+'/trained models/{model}.model.pt'),
        protected(DIR+'/trained models/{model}.loss.pt')
    params:
        random_seed = config['random_seed']
    log:
        DIR+'/log/train_model_{model}.log'
    run:
        # Scoped redirection (stdout and stderr, as in the other rules):
        # `run:` code executes inside the Snakemake process, so a permanent
        # sys.stdout rebinding would point at the closed log file afterwards.
        from contextlib import redirect_stderr, redirect_stdout
        with open(log[0], 'w') as logfile:
            with redirect_stdout(logfile), redirect_stderr(logfile):
                training_data = read_training_data(DIR)
                trainer = ANN_Training.ModelTrainer(
                    config={'random_seed': params.random_seed,
                            **MODELS[wildcards.model]})
                trainer.epoch_training(dataset=training_data)
                # NOTE(review): presumably writes both the .model.pt and
                # .loss.pt outputs declared above - confirm in ModelTrainer.
                trainer.save_model(directory=DIR, model_name=wildcards.model)