Skip to content
Snippets Groups Projects
Select Git revision
  • bfcddd48ee4e10f2c50e17b4cc7f2a56e256fcf6
  • master default protected
2 results

generate_data.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    generate_data.py 1.42 KiB
    # -*- coding: utf-8 -*-
    """Script to generate training data for ANN models.
    
    @author: Laura C. Kühle
    """
    import contextlib
    import sys

    import numpy as np

    from tcd import Initial_Condition
    from tcd import ANN_Data_Generator
    
    
    def main() -> None:
        """Generate training data for ANN models.

        All settings are read from the global ``snakemake`` object:

        * ``snakemake.params['seed']`` -- seed for NumPy's global RNG; the
          RNG is left unseeded when this is ``None``.
        * ``snakemake.params['functions']`` -- names of attributes of
          ``tcd.Initial_Condition``; each is called with an empty dict to
          build an initial condition.
        * ``snakemake.config['functions'][name]`` -- per-function config
          dict; a ``None`` entry is replaced by ``{}``.
        * ``snakemake.params`` keys ``'balance'``, ``'directory'``,
          ``'sample_number'``, ``'reconstruction_flag'`` and
          ``'stencil_len'`` -- forwarded to
          ``TrainingDataGenerator.build_training_data``.
        * ``snakemake.log[0]`` -- file that receives everything printed to
          stdout while the data is generated (the file is overwritten).
        """
        params = snakemake.params

        # Set random seed if given
        if params['seed'] is not None:
            np.random.seed(params['seed'])

        # Determine list of initial conditions. A function listed in the
        # config without settings (value None) gets an empty config dict.
        init_cond_list = []
        for function in params['functions']:
            initial_condition = getattr(Initial_Condition, function)({})
            config = snakemake.config['functions'][function]
            init_cond_list.append({
                'function': initial_condition,
                'config': {} if config is None else config,
            })

        # Generate training data, sending stdout to the log file.
        # redirect_stdout restores sys.stdout when the block exits (also on
        # error), so sys.stdout is never left pointing at the closed file.
        with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile:
            with contextlib.redirect_stdout(logfile):
                generator = ANN_Data_Generator.TrainingDataGenerator()
                generator.build_training_data(
                    balance=params['balance'],
                    init_cond_list=init_cond_list,
                    directory=params['directory'],
                    num_samples=params['sample_number'],
                    add_reconstructions=params['reconstruction_flag'],
                    stencil_len=params['stencil_len'])
    
    
    if __name__ == '__main__':
        # main() reads everything from a global `snakemake` object, which is
        # presumably injected by Snakemake's `script:` directive. Without it
        # there is nothing to run, so fail loudly (message on stderr, exit
        # status 1) instead of printing to stdout and reporting success.
        if 'snakemake' in globals():
            main()
        else:
            sys.exit('Not Defined: the global `snakemake` object is missing; '
                     'run this script through Snakemake.')