# Snakemake workflow: train every configured ANN model, evaluate the set of
# models, and plot the evaluation results.
#
# Settings come from config.yaml; the keys read here are 'data_dir', 'models',
# 'classification_colors', 'num_iterations', 'compare_normalization' and
# 'random_seed'.

import contextlib
import sys

import ANN_Training
from ANN_Training import *
from Plotting import plot_evaluation_results


configfile: 'config.yaml'

# Root directory for the workflow's data, trained models, evaluation results and logs.
DIR = config['data_dir']
# Mapping of model name -> settings unpacked into the ANN_Training.ModelTrainer config.
MODELS = config['models']
# Common path stem of the evaluation outputs (.json / .barplot.pdf): all model
# names joined with '_'.
EVALUATION_STEM = DIR + '/model evaluation/' + '_'.join(MODELS)
# Data files declared as inputs of both the training and the evaluation rule.
TRAINING_DATA = [
    DIR + '/input_data.raw.npy',
    DIR + '/input_data.normalized.npy',
    DIR + '/output_data.npy',
]


@contextlib.contextmanager
def _redirect_output(log_path):
    """Write stdout and stderr to *log_path* while the block runs.

    The original streams are restored on exit (also when the block raises)
    before the log file is closed. That way the Snakemake process is never
    left writing to a closed file.

    NOTE: sys.stdout / sys.stderr are process-global. If several ``run:`` jobs
    execute concurrently in threads, their output can still interleave.
    """
    with open(log_path, 'w') as logfile, \
            contextlib.redirect_stdout(logfile), \
            contextlib.redirect_stderr(logfile):
        yield


rule all:
    input:
        expand(DIR + '/trained models/{model}.model.pt', model=MODELS),
        EVALUATION_STEM + '.barplot.pdf'
    default_target: True


rule plot_test_results:
    input:
        json_file=EVALUATION_STEM + '.json'
    output:
        EVALUATION_STEM + '.barplot.pdf'
    params:
        colors=config['classification_colors']
    log:
        DIR + '/log/plot_test_results.log'
    run:
        with _redirect_output(str(log)):
            plot_evaluation_results(evaluation_file=input.json_file,
                                    directory=DIR, colors=params.colors)


# NOTE(review): the inputs below do not include the trained model files, so this
# rule is not ordered after train_model. Confirm whether evaluate_models trains
# its own models or loads them via 'model_dir'. In the latter case, add
# expand(DIR + '/trained models/{model}.model.pt', model=MODELS) to the inputs.
rule test_model:
    input:
        TRAINING_DATA
    output:
        protected(EVALUATION_STEM + '.json')
    params:
        num_iterations=config['num_iterations'],
        compare_normalization=config['compare_normalization'],
        random_seed=config['random_seed']
    log:
        DIR + '/log/test_model.log'
    run:
        with _redirect_output(str(log)):
            # One trainer per configured model, keyed by model name. The trainers
            # are created inside the redirect so any output they print is logged.
            models = {
                model: ANN_Training.ModelTrainer(
                    {'model_name': model, 'dir': DIR, 'model_dir': DIR,
                     'random_seed': params.random_seed, **MODELS[model]})
                for model in MODELS
            }
            evaluate_models(models=models, directory=DIR,
                            num_iterations=params.num_iterations,
                            compare_normalization=params.compare_normalization)


rule train_model:
    input:
        TRAINING_DATA
    output:
        protected(DIR + '/trained models/{model}.model.pt'),
        protected(DIR + '/trained models/{model}.loss.pt')
    params:
        random_seed=config['random_seed']
    log:
        DIR + '/log/train_model_{model}.log'
    run:
        # As in the other rules, both stdout and stderr go to the per-model log.
        with _redirect_output(str(log)):
            training_data = read_training_data(DIR)
            trainer = ANN_Training.ModelTrainer(
                config={'random_seed': params.random_seed,
                        **MODELS[wildcards.model]})
            trainer.epoch_training(dataset=training_data)
            trainer.save_model(directory=DIR, model_name=wildcards.model)