Skip to content
Snippets Groups Projects
Select Git revision
  • 6ff8cc191479c5cdd2f22922d7eee4ae29247a35
  • develop default protected
  • master protected
  • rodin2
  • rodin3
  • feature/theory_plugin
  • feature/multiview
  • csp
  • feature/newcore
  • feature/csp
  • 3.0.11
  • 3.0.8
  • 3.0.5
  • 2.4.1
  • 2.3.3
  • 2.3.2
  • 2.3.1
  • 2.3.0_fix1
  • 2.3.0
19 results

build.gradle

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ANN_data.smk 1.75 KiB
    import sys
    import time
    import numpy as np
    
    import ANN_Data_Generator, Initial_Condition
    
    configfile: 'config.yaml'

    # Root directory under which this workflow's outputs and logs live.
    DIR = config['data_dir']

    # Seed NumPy's global RNG at workflow-parse time when a seed is
    # configured; a null `random_seed` leaves NumPy unseeded.
    _random_seed = config['random_seed']
    if _random_seed is not None:
        np.random.seed(_random_seed)
    
    rule generate_data:
        output:
            protected(DIR+'/input_data.raw.npy'),
            protected(DIR+'/input_data.normalized.npy'),
            protected(DIR+'/output_data.npy')
        default_target: True
        params:
            left_bound = config['left_boundary'],
            right_bound = config['right_boundary'],
            balance = config['smooth_troubled_balance'],
            stencil_length = config['stencil_length'],
            sample_number = config['sample_number'],
            reconstruction_flag = config['add_reconstructions'],
            functions = expand('{FUNCTION}', FUNCTION=config['functions'])
        log:
            DIR+'/log/generate_data.log'
        run:
            initial_conditions = []
            for function in params.functions:
                initial_conditions.append({
                    'function': getattr(Initial_Condition, function)(
                        params.left_bound, params.right_bound, {}),
                    'config': {} if function in config['functions']
                    else config['functions'][function]
                })
    
            with open(str(log), 'w') as logfile:
                sys.stdout = logfile
                generator = ANN_Data_Generator.TrainingDataGenerator(
                    initial_conditions=initial_conditions,
                    left_bound=params.left_bound, right_bound=params.right_bound,
                    balance=params.balance,
                    stencil_length=params.stencil_length, directory=DIR,
                    add_reconstructions=params.reconstruction_flag)
                data = generator.build_training_data(
                    num_samples=params.sample_number)