From be18f1bea6c1f5c6c36b6e68f1be22d34e6006b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Wed, 19 Jan 2022 15:48:17 +0100 Subject: [PATCH] Generalize output for 'approximate_solution' rule. --- DG_Approximation.py | 10 ++-------- Snakefile | 36 ++++++++---------------------------- config.yaml | 3 ++- 3 files changed, 12 insertions(+), 37 deletions(-) diff --git a/DG_Approximation.py b/DG_Approximation.py index 5615d19..6336f78 100644 --- a/DG_Approximation.py +++ b/DG_Approximation.py @@ -206,19 +206,13 @@ class DGScheme(object): # Plot exact/approximate results, errors, shock tubes and any detector-dependant plots self._detector.plot_results(projection, troubled_cell_history, time_history) - def save_plots(self): + def save_plots(self, plot_name): """Saves plotted results. Sets plot directory, if not already existing, and saves plots generated during the last approximation. """ - name = self._init_cond.get_name() + '__' + self._detector.get_name() + '__' \ - + self._limiter.get_name() + '__' + self._update_scheme.get_name() + '__' \ - + self._quadrature.get_name() + '__final_time_' + str(self._final_time) \ - + '__wave_speed_' + str(self._wave_speed) + '__number_of_cells_' \ - + str(self._num_grid_cells) + '__polynomial_degree_' + str(self._polynomial_degree) - # Set paths for plot files if not existing already if not os.path.exists(self._plot_dir): os.makedirs(self._plot_dir) @@ -230,7 +224,7 @@ class DGScheme(object): os.makedirs(self._plot_dir + '/' + identifier) plt.figure(identifier) - plt.savefig(self._plot_dir + '/' + identifier + '/' + name + '.pdf') + plt.savefig(self._plot_dir + '/' + identifier + '/' + plot_name + '.pdf') def _reset(self): """Resets instance variables.""" diff --git a/Snakefile b/Snakefile index a51fd69..e9cc261 100644 --- a/Snakefile +++ b/Snakefile @@ -8,6 +8,7 @@ import numpy as np configfile: 'config.yaml' +PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube'] DIR = config['data_directory'] MODELS = config['models'] if config['random_seed'] is not None: @@ -17,38 +18,18 @@ rule all: input: expand(DIR+'/trained models/model__{model}.pt', model=MODELS), DIR+'/model evaluation/classification_accuracy/' + '_'.join(MODELS.keys()) + '.pdf', - config['dg_parameter']['plot_dir'] + '/error/' + 'Sine__ArtificialNeuralNetwork__' + - 'ModifiedMinMod0__SSPRK3__Gauss12__final_time_1__wave_speed_1__' + - 'number_of_cells_32__polynomial_degree_2.pdf' + expand(config['plot_dir'] + '/{plot}/' + config['plot_name'] + '.pdf', plot=PLOTS) rule approximate_solution: input: config['dg_parameter']['detector_config']['model_state'] if config['dg_parameter']['detector'] == 'ArtificialNeuralNetwork' else '' output: - error=config['dg_parameter']['plot_dir'] + '/error/' + 'Sine__ArtificialNeuralNetwork__' + - 'ModifiedMinMod0__SSPRK3__Gauss12__final_time_1__wave_speed_1__' + - 'number_of_cells_32__polynomial_degree_2.pdf' + expand(config['plot_dir'] + '/{plot}/' + config['plot_name'] + '.pdf', plot=PLOTS) params: - # plot_dir=config['plot_dir'], - # wave_speed=config['wave_speed'], - # polynomial_degree=config['polynomial_degree'], - # cfl_number=config['cfl_number'], - # num_grid_cells=config['num_grid_cells'], - # final_time=config['final_time'], - # left_bound=config['left_bound'], - # right_bound=config['right_bound'], - # verbose=config['verbose'], - # detector=config['detector'], - # detector_config=config['detector_config'], - # init_cond=config['init_cond'], - # init_config=config['init_config'], - # limiter=config['limiter'], - # limiter_config=config['limiter_config'], - # quadrature=config['quadrature'], - # quadrature_config=config['quadrature_config'], - # update_scheme=config['update_scheme'] - dg_params=config['dg_parameter'] + dg_params=config['dg_parameter'], + plot_dir=config['plot_dir'], + plot_name=config['plot_name'] log: DIR+'/log/approximate_solution.log' run: @@ -56,11 +37,10 @@ rule approximate_solution: tic = timeit.default_timer() print(params.dg_params) - # dg_scheme = DGScheme(**params) - dg_scheme = DGScheme(**params.dg_params) + dg_scheme = DGScheme(plot_dir=params.plot_dir, **params.dg_params) dg_scheme.approximate() - dg_scheme.save_plots() + dg_scheme.save_plots(params.plot_name) toc = timeit.default_timer() print('Time:',toc-tic) diff --git a/config.yaml b/config.yaml index 080ec40..6374b4c 100644 --- a/config.yaml +++ b/config.yaml @@ -2,8 +2,9 @@ data_directory: "Snakemake-Test" random_seed: 1234 # Parameter for Approximation with Troubled Cell Detection +plot_name: 'DG_Test' +plot_dir: 'testing' dg_parameter: - plot_dir: 'testing' wave_speed: 1 polynomial_degree: 2 cfl_number: 0.2 -- GitLab