Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
1 result

ANN_data.smk

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ANN_data.smk 1.67 KiB
    import sys
    import time
    import numpy as np
    
    import ANN_Data_Generator, Initial_Condition
    
    # Load the workflow settings; the values are read below through `config`.
    configfile: 'config.yaml'

    # Base directory used for this workflow's output and log paths.
    DIR = config['data_dir']

    # Seed NumPy's global random number generator only when a seed is given,
    # so runs are reproducible on request and random otherwise.
    _seed = config['random_seed']
    if _seed is not None:
        np.random.seed(_seed)
    
    rule generate_data:
        output:
            protected(DIR+'/input_data.raw.npy'),
            protected(DIR+'/input_data.normalized.npy'),
            protected(DIR+'/output_data.npy')
        default_target: True
        params:
            left_bound = config['left_boundary'],
            right_bound = config['right_boundary'],
            balance = config['smooth_troubled_balance'],
            stencil_length = config['stencil_length'],
            sample_number = config['sample_number'],
            reconstruction_flag = config['add_reconstructions'],
            functions = expand('{FUNCTION}', FUNCTION=config['functions'])
        log:
            DIR+'/log/generate_data.log'
        run:
            initial_conditions = []
            for function in params.functions:
                initial_conditions.append({
                    'function': getattr(Initial_Condition, function)({}),
                    'config': {} if function in config['functions']
                    else config['functions'][function]
                })
    
            with open(str(log), 'w') as logfile:
                sys.stdout = logfile
                generator = ANN_Data_Generator.TrainingDataGenerator(
                    left_bound=params.left_bound, right_bound=params.right_bound)
                data = generator.build_training_data(balance=params.balance,
                    initial_conditions=initial_conditions, directory=DIR,
                    num_samples=params.sample_number,
                    add_reconstructions=params.reconstruction_flag,
                    stencil_length=params.stencil_length)