diff --git a/scripts/approximate_solution.py b/scripts/approximate_solution.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4086d30c92b1a6e29e1b91bde64708085a1143 --- /dev/null +++ b/scripts/approximate_solution.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Script to approximate solution with Discontinuous Galerkin. + +@author: Laura C. Kühle +""" +import sys +import time + +from tcd.DG_Approximation import DGScheme + + +def main() -> None: + """Approximate solution.""" + with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile: + sys.stdout = logfile + sys.stderr = logfile + + tic = time.perf_counter() + + if len(snakemake.input) > 0: + snakemake.params['dg_params']['detector_config']['model_state'] = \ + snakemake.input[0] + + print(snakemake.params['dg_params']) + dg_scheme = DGScheme(**snakemake.params['dg_params']) + + dg_scheme.approximate( + data_file=snakemake.params['plot_dir'] + '/' + + snakemake.wildcards['scheme']) + + toc = time.perf_counter() + print(f'Time: {toc-tic:0.4f}s') + + +if __name__ == '__main__': + if "snakemake" in locals(): + main() + else: + print('Not Defined.') diff --git a/workflows/approximation.smk b/workflows/approximation.smk index ad8d8c2bffeaa459daa4fedbcea233fc76e1b478..5613a69722915b2d547481457dedad8f5f4c79f4 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -3,21 +3,22 @@ import time from scripts.tcd import Initial_Condition from scripts.tcd import Quadrature -from scripts.tcd.DG_Approximation import DGScheme from scripts.tcd.Plotting import plot_approximation_results configfile: 'config.yaml' + PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube'] DIR = config['data_dir'] SCHEMES = config['schemes'] + rule all: input: - expand(DIR + '/fig/{plot}/{scheme}.pdf', plot=PLOTS, - scheme=SCHEMES) + expand(DIR + '/fig/{plot}/{scheme}.pdf', plot=PLOTS, scheme=SCHEMES) default_target: True + def get_ANN_model(wildcards): if config['schemes'][wildcards.scheme]['detector'] == \ 'ArtificialNeuralNetwork': @@ -25,34 +26,20 @@ def get_ANN_model(wildcards): wildcards.scheme]['detector_config']['model_state'] return [] + rule approximate_solution: input: get_ANN_model output: - DIR +'/fig/{scheme}.json' + DIR+'/fig/{scheme}.json' params: - dg_params=lambda wildcards: config['schemes'][wildcards.scheme], - plot_dir=DIR + '/fig' + dg_params = lambda wildcards: config['schemes'][wildcards.scheme], + plot_dir = DIR + '/fig' log: DIR+'/log/approximate_solution/{scheme}.log' - run: - with open(str(log), 'w') as logfile: - sys.stdout = logfile - sys.stderr = logfile - - tic = time.perf_counter() + script: + '../scripts/approximate_solution.py' - if len(input) > 0: - params.dg_params['detector_config']['model_state'] = input - - print(params.dg_params) - dg_scheme = DGScheme(**params.dg_params) - - dg_scheme.approximate( - data_file=params.plot_dir+'/'+wildcards.scheme) - - toc = time.perf_counter() - print(f'Time: {toc - tic:0.4f}s') rule plot_approximation_results: input: