Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
1 result

approximation.smk

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    approximation.smk 3.05 KiB
    import copy
    import sys
    import time
    from contextlib import redirect_stderr, redirect_stdout

    import Initial_Condition
    import Quadrature
    from DG_Approximation import DGScheme
    from Plotting import plot_approximation_results
    
    configfile: 'config.yaml'

    # Plot types generated for every scheme; each name is used as a
    # sub-directory of config['plot_dir'] holding one PDF per scheme.
    PLOTS = [
        'error',
        'exact_and_approx',
        'semilog_error',
        'shock_tube',
    ]
    # Root data directory: trained models and rule logs are stored below it.
    DIR = config['data_dir']
    # Scheme definitions, looked up by scheme name (the {scheme} wildcard).
    SCHEMES = config['schemes']
    
    rule all:
        # Default target: every plot type in PLOTS for every configured scheme.
        input:
            expand(
                config['plot_dir'] + '/{plot}/{scheme}.pdf',
                plot=PLOTS,
                scheme=SCHEMES,
            )
        default_target: True
    
    def get_ANN_model(wildcards):
        """Return the trained-model file a scheme depends on, if any.

        Used as the input function of the ``approximate_solution`` rule.
        Schemes whose detector is ``'ArtificialNeuralNetwork'`` need the file
        ``<DIR>/trained models/<detector_config.model_state>``; for every other
        detector an empty list is returned, meaning "no extra input".
        """
        scheme_config = config['schemes'][wildcards.scheme]
        if scheme_config['detector'] != 'ArtificialNeuralNetwork':
            return []
        model_name = scheme_config['detector_config']['model_state']
        return DIR + '/trained models/' + model_name
    
    rule approximate_solution:
        # Run the DG scheme configured under config['schemes'][<scheme>] and
        # write its approximation data next to <plot_dir>/<scheme>.json.
        input:
            # Trained ANN model state for ANN detectors, nothing otherwise.
            get_ANN_model
        output:
            config['plot_dir']+'/{scheme}.json'
        params:
            # NOTE: this returns the dict object stored in the global config,
            # not a copy; the run block deep-copies it before modifying it.
            dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
            plot_dir=config['plot_dir']
        log:
            DIR+'/log/approximate_solution_{scheme}.log'
        run:
            with open(str(log), 'w') as logfile:
                # redirect_stdout/redirect_stderr restore the previous streams
                # on exit; plain assignment to sys.stdout/sys.stderr left them
                # bound to the closed log file for the rest of the process.
                # NOTE(review): the redirection is process-wide, so output of
                # run jobs executing concurrently in other threads may still
                # interleave.
                with redirect_stdout(logfile), redirect_stderr(logfile):
                    tic = time.perf_counter()

                    # Private copy: setting the model path (or anything
                    # DGScheme might consume from the nested configs) must not
                    # leak into the shared config read by other rules.
                    dg_params = copy.deepcopy(params.dg_params)
                    if input:
                        # get_ANN_model yields at most one file.
                        dg_params['detector_config']['model_state'] = \
                            str(input[0])

                    print(dg_params)
                    dg_scheme = DGScheme(**dg_params)

                    # data_file carries no extension here; presumably DGScheme
                    # adds '.json' to match `output` (not visible in this file).
                    dg_scheme.approximate(
                        data_file=params.plot_dir+'/'+wildcards.scheme)

                    toc = time.perf_counter()
                    print(f'Time: {toc - tic:0.4f}s')
    
    rule plot_approximation_results:
        # Plot the approximation data of one scheme, producing one PDF per
        # plot type in PLOTS under <plot_dir>/<plot>/<scheme>.pdf.
        input:
            config['plot_dir']+'/{scheme}.json'
        output:
            expand(config['plot_dir'] + '/{plot}/{{scheme}}.pdf', plot=PLOTS)
        params:
            # NOTE: this returns the dict object stored in the global config,
            # not a copy; the run block deep-copies it before use.
            dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
            plot_dir=config['plot_dir']
        log:
            DIR+'/log/plot_approximation_results_{scheme}.log'
        run:
            with open(str(log), 'w') as logfile:
                # redirect_stdout/redirect_stderr restore the previous streams
                # on exit; plain assignment to sys.stdout/sys.stderr left them
                # bound to the closed log file for the rest of the process.
                # NOTE(review): the redirection is process-wide, so output of
                # run jobs executing concurrently in other threads may still
                # interleave.
                with redirect_stdout(logfile), redirect_stderr(logfile):
                    tic = time.perf_counter()

                    # Only the initial condition and quadrature settings are
                    # needed for plotting. The ANN model path that used to be
                    # written here was never read and only mutated the global
                    # config. Deep-copying keeps the nested config dicts
                    # handed to the constructors out of the shared config.
                    scheme_params = copy.deepcopy(params.dg_params)

                    init_cond = getattr(
                        Initial_Condition,
                        scheme_params.get('init_cond', 'Sine'))(
                        config=scheme_params.get('init_config', {}))
                    quadrature = getattr(
                        Quadrature,
                        scheme_params.get('quadrature', 'Gauss'))(
                        scheme_params.get('quadrature_config', {}))

                    plot_approximation_results(directory=params.plot_dir,
                        plot_name=wildcards.scheme,
                        data_file=params.plot_dir+'/'+wildcards.scheme,
                        quadrature=quadrature, init_cond=init_cond)

                    toc = time.perf_counter()
                    print(f'Time: {toc - tic:0.4f}s')