Select Git revision
generate_data.py
Laura Christine Kühle authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
generate_data.py 1.42 KiB
# -*- coding: utf-8 -*-
"""Script to generate training data for ANN models.
@author: Laura C. Kühle
"""
import sys
from contextlib import redirect_stdout

import numpy as np

from tcd import Initial_Condition
from tcd import ANN_Data_Generator
def main() -> None:
    """Generate training data for ANN models.

    All settings are read from the global ``snakemake`` object, which is not
    defined in this file (it is expected to be injected when the script is
    run through Snakemake's ``script`` directive):

    * ``params['seed']``: seed for NumPy's global random state; seeding is
      skipped when it is ``None``.
    * ``params['functions']``: names of the attributes of
      ``Initial_Condition`` to instantiate as initial conditions.
    * ``config['functions'][name]``: configuration stored next to each
      initial condition; ``None`` is replaced by an empty dict.
    * ``params['balance']``, ``params['directory']``,
      ``params['sample_number']``, ``params['reconstruction_flag']`` and
      ``params['stencil_len']``: forwarded unchanged to
      ``TrainingDataGenerator.build_training_data``.
    * ``log[0]``: path of the log file that receives everything written to
      standard output while the data is generated.
    """
    # Set random seed if given
    seed = snakemake.params['seed']
    if seed is not None:
        np.random.seed(seed)

    # Determine list of initial conditions
    init_cond_list = []
    for function in snakemake.params['functions']:
        # Each initial condition is constructed with an empty config; the
        # actual config is stored alongside it in the list entry.
        initial_condition = getattr(Initial_Condition, function)({})
        config = snakemake.config['functions'][function]
        init_cond_list.append({
            'function': initial_condition,
            'config': {} if config is None else config,
        })

    # Generate training data
    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
        # Redirect stdout only for the duration of the generation, so that
        # sys.stdout is restored (even on error) before the log file closes
        # instead of being left pointing at a closed file.
        with redirect_stdout(logfile):
            generator = ANN_Data_Generator.TrainingDataGenerator()
            generator.build_training_data(
                balance=snakemake.params['balance'],
                init_cond_list=init_cond_list,
                directory=snakemake.params['directory'],
                num_samples=snakemake.params['sample_number'],
                add_reconstructions=snakemake.params['reconstruction_flag'],
                stencil_len=snakemake.params['stencil_len'])
if __name__ == '__main__':
    # main() reads a global `snakemake` object that this file never defines;
    # only run when that name is present in the module namespace.
    if 'snakemake' not in globals():
        print('Not Defined.')
    else:
        main()