Select Git revision
approximation.smk
-
Laura Christine Kühle authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
approximation.smk 3.05 KiB
import contextlib
import copy
import sys
import time

import Initial_Condition
import Quadrature
from DG_Approximation import DGScheme
from Plotting import plot_approximation_results
# Workflow settings are loaded from config.yaml into the global `config`.
configfile: 'config.yaml'

# Plot types rendered for every scheme; each is used as the '{plot}'
# sub-directory of the plot output paths.
PLOTS = [
    'error',
    'exact_and_approx',
    'semilog_error',
    'shock_tube',
]
# Root directory for logs and trained models.
DIR = config['data_dir']
# Mapping of scheme name -> per-scheme DG parameters.
SCHEMES = config['schemes']
# Default target: one PDF per (plot type, scheme) combination.
rule all:
    input:
        expand(
            config['plot_dir'] + '/{plot}/{scheme}.pdf',
            plot=PLOTS,
            scheme=SCHEMES,
        )
    default_target: True
def get_ANN_model(wildcards):
    """Input function: trained-model file needed by a scheme, if any.

    Looks up the scheme named by ``wildcards.scheme`` in ``config['schemes']``.
    If its detector is 'ArtificialNeuralNetwork', returns the path
    ``DIR + '/trained models/' + <detector_config.model_state>``; otherwise
    returns an empty list, i.e. no additional input file.
    """
    scheme_config = config['schemes'][wildcards.scheme]
    if scheme_config['detector'] != 'ArtificialNeuralNetwork':
        return []
    model_state = scheme_config['detector_config']['model_state']
    return DIR + '/trained models/' + model_state
# Run the DG approximation for one scheme and write '<plot_dir>/<scheme>.json'.
rule approximate_solution:
    input:
        get_ANN_model
    output:
        config['plot_dir'] + '/{scheme}.json'
    params:
        dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
        plot_dir=config['plot_dir']
    log:
        DIR + '/log/approximate_solution_{scheme}.log'
    run:
        with open(str(log), 'w') as logfile:
            # Context managers restore sys.stdout/sys.stderr when the block
            # exits (normally or via an exception), so they never keep
            # pointing at the log file after it has been closed.
            with contextlib.redirect_stdout(logfile), \
                    contextlib.redirect_stderr(logfile):
                tic = time.perf_counter()
                # params.dg_params is the scheme's dict inside the global
                # config (the lambda returns it, not a copy); deep-copy it so
                # the override below does not modify shared config.
                dg_params = copy.deepcopy(params.dg_params)
                if input:
                    # get_ANN_model yields at most one path; pass it as a
                    # plain string rather than the whole InputFiles list.
                    dg_params['detector_config']['model_state'] = str(input[0])
                print(dg_params)
                dg_scheme = DGScheme(**dg_params)
                dg_scheme.approximate(
                    data_file=params.plot_dir + '/' + wildcards.scheme)
                toc = time.perf_counter()
                print(f'Time: {toc - tic:0.4f}s')
# Render all PLOTS for one scheme from its '<plot_dir>/<scheme>.json' results.
rule plot_approximation_results:
    input:
        config['plot_dir'] + '/{scheme}.json'
    output:
        expand(config['plot_dir'] + '/{plot}/{{scheme}}.pdf', plot=PLOTS)
    params:
        dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
        plot_dir=config['plot_dir']
    log:
        DIR + '/log/plot_approximation_results_{scheme}.log'
    run:
        with open(str(log), 'w') as logfile:
            # Context managers restore sys.stdout/sys.stderr when the block
            # exits, instead of leaving them bound to the closed log file.
            with contextlib.redirect_stdout(logfile), \
                    contextlib.redirect_stderr(logfile):
                tic = time.perf_counter()
                # Read-only access: params.dg_params is the scheme's dict in
                # the global config, so nothing here modifies it. Only the
                # initial-condition and quadrature settings are used.
                scheme_config = params.dg_params
                init_cond = getattr(
                    Initial_Condition,
                    scheme_config.get('init_cond', 'Sine'))(
                    config=scheme_config.get('init_config', {}))
                quadrature = getattr(
                    Quadrature,
                    scheme_config.get('quadrature', 'Gauss'))(
                    scheme_config.get('quadrature_config', {}))
                plot_approximation_results(
                    directory=params.plot_dir,
                    plot_name=wildcards.scheme,
                    data_file=params.plot_dir + '/' + wildcards.scheme,
                    quadrature=quadrature,
                    init_cond=init_cond)
                toc = time.perf_counter()
                print(f'Time: {toc - tic:0.4f}s')