Skip to content
Snippets Groups Projects
Commit bfcddd48 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Outsourced run command for SM rule 'generate_data' into script.

parent 7dc4358d
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""Script to generate training data for ANN models.
@author: Laura C. Kühle
"""
import sys
import numpy as np
from tcd import Initial_Condition
from tcd import ANN_Data_Generator
def main() -> None:
"""Generate training data for ANN models."""
# Set random seed if given
if snakemake.params['seed'] is not None:
np.random.seed(snakemake.params['seed'])
# Determine list of initial conditions
init_cond_list = []
for function in snakemake.params['functions']:
init_cond_list.append({
'function': getattr(Initial_Condition, function)({}),
'config': {} if snakemake.config['functions'][function] is None
else snakemake.config['functions'][function]
})
# Generate training data
with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
sys.stdout = logfile
generator = ANN_Data_Generator.TrainingDataGenerator()
_ = generator.build_training_data(
balance=snakemake.params['balance'],
init_cond_list=init_cond_list,
directory=snakemake.params['directory'],
num_samples=snakemake.params['sample_number'],
add_reconstructions=snakemake.params['reconstruction_flag'],
stencil_len=snakemake.params['stencil_len'])
if __name__ == '__main__':
if "snakemake" in locals():
main()
else:
print('Not Defined.')
import sys
import time
import numpy as np
from scripts.tcd import Initial_Condition
from scripts.tcd import ANN_Data_Generator
configfile: 'config.yaml' configfile: 'config.yaml'
DIR = config['data_dir'] DIR = config['data_dir']
if config['random_seed'] is not None:
np.random.seed(config['random_seed'])
rule generate_data: rule generate_data:
output: output:
...@@ -22,23 +15,10 @@ rule generate_data: ...@@ -22,23 +15,10 @@ rule generate_data:
stencil_len = config['stencil_length'], stencil_len = config['stencil_length'],
sample_number = config['sample_number'], sample_number = config['sample_number'],
reconstruction_flag = config['add_reconstructions'], reconstruction_flag = config['add_reconstructions'],
directory = DIR,
seed = config['random_seed'],
functions = expand('{FUNCTION}', FUNCTION=config['functions']) functions = expand('{FUNCTION}', FUNCTION=config['functions'])
log: log:
DIR+'/log/generate_data.log' DIR+'/log/generate_data.log'
run: script:
init_cond_list = [] '../scripts/generate_data.py'
for function in params.functions:
init_cond_list.append({
'function': getattr(Initial_Condition, function)({}),
'config': {} if config['functions'][function] is None
else config['functions'][function]
})
with open(str(log), 'w') as logfile:
sys.stdout = logfile
generator = ANN_Data_Generator.TrainingDataGenerator()
data = generator.build_training_data(balance=params.balance,
init_cond_list=init_cond_list, directory=DIR,
num_samples=params.sample_number,
add_reconstructions=params.reconstruction_flag,
stencil_len=params.stencil_len)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment