From bfcddd48ee4e10f2c50e17b4cc7f2a56e256fcf6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 5 Oct 2022 23:11:13 +0200
Subject: [PATCH] Outsourced run command for SM rule 'generate_data' into script.

---
 scripts/generate_data.py | 53 ++++++++++++++++++++++++++++++++++++++++
 workflows/ANN_data.smk   | 32 ++++----------------------
 2 files changed, 59 insertions(+), 26 deletions(-)
 create mode 100644 scripts/generate_data.py

diff --git a/scripts/generate_data.py b/scripts/generate_data.py
new file mode 100644
index 0000000..427ec58
--- /dev/null
+++ b/scripts/generate_data.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+"""Script to generate training data for ANN models.
+
+Executed by Snakemake via the 'script' directive of rule 'generate_data';
+Snakemake injects the global 'snakemake' object before running it.
+
+@author: Laura C. Kühle
+"""
+from contextlib import redirect_stdout
+import numpy as np
+
+from tcd import Initial_Condition
+from tcd import ANN_Data_Generator
+
+
+def main() -> None:
+    """Generate training data for ANN models.
+
+    Reads all settings from the injected 'snakemake' object (params, config,
+    log); anything printed during generation goes to the rule's log file.
+    """
+    # Set random seed if given
+    if snakemake.params['seed'] is not None:
+        np.random.seed(snakemake.params['seed'])
+
+    # Determine list of initial conditions
+    init_cond_list = []
+    for function in snakemake.params['functions']:
+        init_cond_list.append({
+            'function': getattr(Initial_Condition, function)({}),
+            'config': {} if snakemake.config['functions'][function] is None
+            else snakemake.config['functions'][function]
+        })
+
+    # Generate training data; stdout is redirected to the log file only for
+    # the duration of the block and restored afterwards
+    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile, \
+            redirect_stdout(logfile):
+        generator = ANN_Data_Generator.TrainingDataGenerator()
+        _ = generator.build_training_data(
+            balance=snakemake.params['balance'],
+            init_cond_list=init_cond_list,
+            directory=snakemake.params['directory'],
+            num_samples=snakemake.params['sample_number'],
+            add_reconstructions=snakemake.params['reconstruction_flag'],
+            stencil_len=snakemake.params['stencil_len'])
+
+
+if __name__ == '__main__':
+    if "snakemake" in locals():
+        main()
+    else:
+        print('Not Defined.')
diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk
index 9d38b01..d11164d 100644
--- a/workflows/ANN_data.smk
+++ b/workflows/ANN_data.smk
@@ -1,15 +1,8 @@
-import sys
-import time
-import numpy as np
-
-from scripts.tcd import Initial_Condition
-from scripts.tcd import ANN_Data_Generator
-
 configfile: 'config.yaml'
 
+
 DIR = config['data_dir']
-if config['random_seed'] is not None:
-    np.random.seed(config['random_seed'])
+
 
 rule generate_data:
     output:
@@ -22,23 +15,10 @@ rule generate_data:
         stencil_len = config['stencil_length'],
         sample_number = config['sample_number'],
         reconstruction_flag = config['add_reconstructions'],
+        directory = DIR,
+        seed = config['random_seed'],
         functions = expand('{FUNCTION}', FUNCTION=config['functions'])
     log:
         DIR+'/log/generate_data.log'
-    run:
-        init_cond_list = []
-        for function in params.functions:
-            init_cond_list.append({
-                'function': getattr(Initial_Condition, function)({}),
-                'config': {} if config['functions'][function] is None
-                else config['functions'][function]
-            })
-
-        with open(str(log), 'w') as logfile:
-            sys.stdout = logfile
-            generator = ANN_Data_Generator.TrainingDataGenerator()
-            data = generator.build_training_data(balance=params.balance,
-                init_cond_list=init_cond_list, directory=DIR,
-                num_samples=params.sample_number,
-                add_reconstructions=params.reconstruction_flag,
-                stencil_len=params.stencil_len)
\ No newline at end of file
+    script:
+        '../scripts/generate_data.py'
-- 
GitLab