From be18f1bea6c1f5c6c36b6e68f1be22d34e6006b7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 19 Jan 2022 15:48:17 +0100
Subject: [PATCH] Generalize output for 'approximate_solution' rule.

---
 DG_Approximation.py | 10 ++--------
 Snakefile           | 36 ++++++++----------------------------
 config.yaml         |  3 ++-
 3 files changed, 12 insertions(+), 37 deletions(-)

diff --git a/DG_Approximation.py b/DG_Approximation.py
index 5615d19..6336f78 100644
--- a/DG_Approximation.py
+++ b/DG_Approximation.py
@@ -206,19 +206,13 @@ class DGScheme(object):
         # Plot exact/approximate results, errors, shock tubes and any detector-dependant plots
         self._detector.plot_results(projection, troubled_cell_history, time_history)
 
-    def save_plots(self):
+    def save_plots(self, plot_name):
         """Saves plotted results.
 
         Sets plot directory, if not already existing, and saves plots generated during the last
         approximation.
 
         """
-        name = self._init_cond.get_name() + '__' + self._detector.get_name() + '__' \
-            + self._limiter.get_name() + '__' + self._update_scheme.get_name() + '__' \
-            + self._quadrature.get_name() + '__final_time_' + str(self._final_time) \
-            + '__wave_speed_' + str(self._wave_speed) + '__number_of_cells_' \
-            + str(self._num_grid_cells) + '__polynomial_degree_' + str(self._polynomial_degree)
-
         # Set paths for plot files if not existing already
         if not os.path.exists(self._plot_dir):
             os.makedirs(self._plot_dir)
@@ -230,7 +224,7 @@ class DGScheme(object):
                 os.makedirs(self._plot_dir + '/' + identifier)
 
             plt.figure(identifier)
-            plt.savefig(self._plot_dir + '/' + identifier + '/' + name + '.pdf')
+            plt.savefig(self._plot_dir + '/' + identifier + '/' + plot_name + '.pdf')
 
     def _reset(self):
         """Resets instance variables."""
diff --git a/Snakefile b/Snakefile
index a51fd69..e9cc261 100644
--- a/Snakefile
+++ b/Snakefile
@@ -8,6 +8,7 @@ import numpy as np
 
 configfile: 'config.yaml'
 
+PLOTS = ['error', 'exact_and_approx', 'semilog_error', 'shock_tube']
 DIR = config['data_directory']
 MODELS = config['models']
 if config['random_seed'] is not None:
@@ -17,38 +18,18 @@ rule all:
     input:
         expand(DIR+'/trained models/model__{model}.pt', model=MODELS),
         DIR+'/model evaluation/classification_accuracy/' + '_'.join(MODELS.keys()) + '.pdf',
-        config['dg_parameter']['plot_dir'] + '/error/' + 'Sine__ArtificialNeuralNetwork__' +
-            'ModifiedMinMod0__SSPRK3__Gauss12__final_time_1__wave_speed_1__' +
-            'number_of_cells_32__polynomial_degree_2.pdf'
+        expand(config['plot_dir'] + '/{plot}/' + config['plot_name'] + '.pdf', plot=PLOTS)
 
 rule approximate_solution:
     input:
         config['dg_parameter']['detector_config']['model_state']
         if config['dg_parameter']['detector'] == 'ArtificialNeuralNetwork' else ''
     output:
-        error=config['dg_parameter']['plot_dir'] + '/error/' + 'Sine__ArtificialNeuralNetwork__' +
-              'ModifiedMinMod0__SSPRK3__Gauss12__final_time_1__wave_speed_1__' +
-              'number_of_cells_32__polynomial_degree_2.pdf'
+        expand(config['plot_dir'] + '/{plot}/' + config['plot_name'] + '.pdf', plot=PLOTS)
     params:
-        # plot_dir=config['plot_dir'],
-        # wave_speed=config['wave_speed'],
-        # polynomial_degree=config['polynomial_degree'],
-        # cfl_number=config['cfl_number'],
-        # num_grid_cells=config['num_grid_cells'],
-        # final_time=config['final_time'],
-        # left_bound=config['left_bound'],
-        # right_bound=config['right_bound'],
-        # verbose=config['verbose'],
-        # detector=config['detector'],
-        # detector_config=config['detector_config'],
-        # init_cond=config['init_cond'],
-        # init_config=config['init_config'],
-        # limiter=config['limiter'],
-        # limiter_config=config['limiter_config'],
-        # quadrature=config['quadrature'],
-        # quadrature_config=config['quadrature_config'],
-        # update_scheme=config['update_scheme']
-        dg_params=config['dg_parameter']
+        dg_params=config['dg_parameter'],
+        plot_dir=config['plot_dir'],
+        plot_name=config['plot_name']
     log:
         DIR+'/log/approximate_solution.log'
     run:
@@ -56,11 +37,10 @@ rule approximate_solution:
             tic = timeit.default_timer()
 
             print(params.dg_params)
-            # dg_scheme = DGScheme(**params)
-            dg_scheme = DGScheme(**params.dg_params)
+            dg_scheme = DGScheme(plot_dir=params.plot_dir, **params.dg_params)
 
             dg_scheme.approximate()
-            dg_scheme.save_plots()
+            dg_scheme.save_plots(params.plot_name)
 
             toc = timeit.default_timer()
             print('Time:',toc-tic)
diff --git a/config.yaml b/config.yaml
index 080ec40..6374b4c 100644
--- a/config.yaml
+++ b/config.yaml
@@ -2,8 +2,9 @@ data_directory: "Snakemake-Test"
 random_seed: 1234
 
 # Parameter for Approximation with Troubled Cell Detection
+plot_name: 'DG_Test'
+plot_dir: 'testing'
 dg_parameter:
-    plot_dir: 'testing'
     wave_speed: 1
     polynomial_degree: 2
     cfl_number: 0.2
-- 
GitLab