From da5e745a9d7e1d745c5ff5e2704a22c1c982f4e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Mon, 14 Mar 2022 14:07:34 +0100
Subject: [PATCH] Moved plotting into separate rule.

---
 DG_Approximation.py         | 111 +++++++++++++++++++-----------------
 Troubled_Cell_Detector.py   |  40 +++++++------
 workflows/approximation.smk |  73 ++++++++++++++++++++++--
 3 files changed, 149 insertions(+), 75 deletions(-)

diff --git a/DG_Approximation.py b/DG_Approximation.py
index 09050d1..44dc745 100644
--- a/DG_Approximation.py
+++ b/DG_Approximation.py
@@ -4,7 +4,8 @@
 
 Urgent:
 TODO: Move plotting into separate function -> Done
-TODO: Move plotting into separate rule
+TODO: Move plotting into separate rule -> Done
+TODO: Extract object initialization from DGScheme
 TODO: Adapt TCD from Soraya
     (Dropbox->...->TEST_troubled-cell-detector->Troubled_Cell_Detector)
 TODO: Add verbose output
@@ -189,18 +190,23 @@ class DGScheme:
         # Replace the string names with the actual class instances
         # (and add the instance variables for the quadrature)
         self._init_cond = getattr(Initial_Condition, self._init_cond)(
-            self._left_bound, self._right_bound, self._init_config)
-        self._limiter = getattr(Limiter, self._limiter)(self._limiter_config)
+            left_bound=self._left_bound, right_bound=self._right_bound,
+            config=self._init_config)
+        self._limiter = getattr(Limiter, self._limiter)(
+            config=self._limiter_config)
         self._quadrature = getattr(Quadrature, self._quadrature)(
-            self._quadrature_config)
+            config=self._quadrature_config)
         self._detector = getattr(Troubled_Cell_Detector, self._detector)(
-            self._detector_config, self._mesh, self._wave_speed,
-            self._polynomial_degree, self._num_grid_cells, self._final_time,
-            self._left_bound, self._right_bound, self._basis,
-            self._init_cond, self._quadrature)
+            config=self._detector_config, mesh=self._mesh,
+            wave_speed=self._wave_speed, num_grid_cells=self._num_grid_cells,
+            polynomial_degree=self._polynomial_degree,
+            final_time=self._final_time, left_bound=self._left_bound,
+            right_bound=self._right_bound, basis=self._basis,
+            init_cond=self._init_cond, quadrature=self._quadrature)
         self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
-            self._polynomial_degree, self._num_grid_cells, self._detector,
-            self._limiter)
+            polynomial_degree=self._polynomial_degree,
+            num_grid_cells=self._num_grid_cells, detector=self._detector,
+            limiter=self._limiter)
 
     def approximate(self, data_file):
         """Approximates projection.
@@ -256,48 +262,6 @@ class DGScheme:
                 as json_file:
             json_file.write(json.dumps(approx_stats))
 
-    def plot_approximation_results(self, data_file, directory, plot_name):
-        """Plots given approximation results.
-
-        Generates plots based on given data, sets plot directory if not
-        already existing, and saves plots.
-
-        Parameters
-        ----------
-        data_file: str
-            Path to data file for plotting.
-        directory: str
-            Path to directory in which plots will be saved.
-        plot_name : str
-            Name of plot.
-
-        """
-        # Read approximation results
-        with open(data_file + '.json') as json_file:
-            approx_stats = json.load(json_file)
-
-        # Decode all ndarrays by converting lists
-        approx_stats = {key: decode_ndarray(approx_stats[key])
-                        for key in approx_stats.keys()}
-
-        # Plot exact/approximate results, errors, shock tubes,
-        # and any detector-dependant plots
-        self._detector.plot_results(**approx_stats)
-
-        # Set paths for plot files if not existing already
-        if not os.path.exists(directory):
-            os.makedirs(directory)
-
-        # Save plots
-        for identifier in plt.get_figlabels():
-            # Set path for figure directory if not existing already
-            if not os.path.exists(directory + '/' + identifier):
-                os.makedirs(directory + '/' + identifier)
-
-            plt.figure(identifier)
-            plt.savefig(directory + '/' + identifier + '/' +
-                        plot_name + '.pdf')
-
     def _reset(self):
         """Resets instance variables."""
         # Set additional necessary instance variables
@@ -409,3 +373,46 @@ class DGScheme:
         return self._detector.calculate_cell_average(projection[:, 1:-1],
                                                      stencil_length,
                                                      add_reconstructions)
+
+
+def plot_approximation_results(detector, data_file, directory, plot_name):
+    """Plots given approximation results.
+
+    Generates plots based on given data, sets plot directory if not
+    already existing, and saves plots.
+
+    Parameters
+    ----------
+    data_file: str
+        Path to data file for plotting.
+    directory: str
+        Path to directory in which plots will be saved.
+    plot_name : str
+        Name of plot.
+
+    """
+    # Read approximation results
+    with open(data_file + '.json') as json_file:
+        approx_stats = json.load(json_file)
+
+    # Decode all ndarrays by converting lists
+    approx_stats = {key: decode_ndarray(approx_stats[key])
+                    for key in approx_stats.keys()}
+
+    # Plot exact/approximate results, errors, shock tubes,
+    # and any detector-dependant plots
+    detector.plot_results(**approx_stats)
+
+    # Set paths for plot files if not existing already
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+
+    # Save plots
+    for identifier in plt.get_figlabels():
+        # Set path for figure directory if not existing already
+        if not os.path.exists(directory + '/' + identifier):
+            os.makedirs(directory + '/' + identifier)
+
+        plt.figure(identifier)
+        plt.savefig(directory + '/' + identifier + '/' +
+                    plot_name + '.pdf')
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 1537699..1a2dd0b 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -49,33 +49,35 @@ class TroubledCellDetector:
         Plots results and troubled cells of a projection.
 
     """
-    def __init__(self, config, mesh, wave_speed, polynomial_degree,
-                 num_grid_cells, final_time, left_bound, right_bound, basis,
-                 init_cond, quadrature):
+    def __init__(self, config, init_cond, quadrature, basis, mesh,
+                 wave_speed=1, polynomial_degree=2, num_grid_cells=64,
+                 final_time=1, left_bound=-1, right_bound=1):
         """Initializes TroubledCellDetector.
 
         Parameters
         ----------
-        mesh : ndarray
-            List of mesh valuation points.
-        wave_speed : float
-            Speed of wave in rightward direction.
-        polynomial_degree : int
-            Polynomial degree.
-        num_grid_cells : int
-            Number of cells in the mesh. Usually exponential of 2.
-        final_time : float
-            Final time for which approximation is calculated.
-        left_bound : float
-            Left boundary of interval.
-        right_bound : float
-            Right boundary of interval.
-        basis : Basis object
-            Basis for calculation.
+        config : dict
+            Additional parameters for detector.
         init_cond : InitialCondition object
             Initial condition for evaluation.
         quadrature : Quadrature object
             Quadrature for evaluation.
+        basis : Basis object
+            Basis for calculation.
+        mesh : ndarray
+            List of mesh valuation points.
+        wave_speed : float, optional
+            Speed of wave in rightward direction. Default: 1.
+        polynomial_degree : int, optional
+            Polynomial degree. Default: 2.
+        num_grid_cells : int, optional
+            Number of cells in the mesh. Usually exponential of 2. Default: 64.
+        final_time : float, optional
+            Final time for which approximation is calculated. Default: 1.
+        left_bound : float, optional
+            Left boundary of interval. Default: -1.
+        right_bound : float, optional
+            Right boundary of interval. Default: 1.
 
         """
         self._mesh = mesh
diff --git a/workflows/approximation.smk b/workflows/approximation.smk
index 13ae4e7..0a6c035 100644
--- a/workflows/approximation.smk
+++ b/workflows/approximation.smk
@@ -2,7 +2,11 @@ import sys
 import time
 import numpy as np
 
-from DG_Approximation import DGScheme
+import Troubled_Cell_Detector
+import Initial_Condition
+import Quadrature
+from DG_Approximation import DGScheme, plot_approximation_results
+from Basis_Function import OrthonormalLegendre
 
 configfile: 'config.yaml'
 
@@ -29,7 +33,7 @@ rule approximate_solution:
     input:
         get_ANN_model
     output:
-        expand(config['plot_dir'] + '/{plot}/{{scheme}}.pdf', plot=PLOTS)
+        config['plot_dir']+'/{scheme}.json'
     params:
         dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
         plot_dir=config['plot_dir']
@@ -50,8 +54,69 @@ rule approximate_solution:
 
             dg_scheme.approximate(
                 data_file=params.plot_dir+'/'+wildcards.scheme)
-            dg_scheme.plot_approximation_results(directory=params.plot_dir,
-                plot_name=wildcards.scheme,
+
+            toc = time.perf_counter()
+            print(f'Time: {toc - tic:0.4f}s')
+
+rule plot_approximation_results:
+    input:
+        config['plot_dir']+'/{scheme}.json'
+    output:
+        expand(config['plot_dir'] + '/{plot}/{{scheme}}.pdf', plot=PLOTS)
+    params:
+        dg_params=lambda wildcards: config['schemes'][wildcards.scheme],
+        plot_dir=config['plot_dir']
+    log:
+        DIR+'/log/plot_approximation_results_{scheme}.log'
+    run:
+        with open(str(log), 'w') as logfile:
+            sys.stdout = logfile
+            sys.stderr = logfile
+
+            tic = time.perf_counter()
+
+            if config['schemes'][wildcards.scheme]['detector'] == \
+                    'ArtificialNeuralNetwork':
+                params.dg_params['detector_config']['model_state'] = \
+                    DIR + '/trained models/' + config['schemes'][
+                        wildcards.scheme]['detector_config']['model_state']
+
+            detector_dict = params.dg_params.copy()
+
+            left_bound = detector_dict.pop('left_bound', -1)
+            right_bound = detector_dict.pop('right_bound', 1)
+            init_cond = getattr(Initial_Condition, detector_dict.pop(
+                'init_cond', 'Sine'))(left_bound=left_bound,
+                right_bound=right_bound,
+                config=detector_dict.pop('init_config', {}))
+            quadrature = getattr(Quadrature, detector_dict.pop(
+                'quadrature', 'Gauss'))(detector_dict.pop(
+                'quadrature_config', {}))
+            basis = OrthonormalLegendre(detector_dict.pop(
+                'polynomial_degree', 2))
+            cell_len = (right_bound - left_bound)\
+                       / params.dg_params.pop('num_grid_cells', 64)
+            mesh = np.arange(left_bound - (3/2*cell_len),
+                right_bound + (5/2*cell_len), cell_len)
+
+            detector_dict.pop('cfl_number', None)
+            detector_dict.pop('verbose', None)
+            detector_dict.pop('history_threshold', None)
+            detector_dict.pop('detector', None)
+            detector_dict.pop('limiter', None)
+            detector_dict.pop('limiter_config', None)
+            detector_dict.pop('update_scheme', None)
+
+            detector_dict['config'] = detector_dict.pop(
+                'detector_config', {})
+
+            detector = getattr(Troubled_Cell_Detector,
+                params.dg_params['detector'])(left_bound=left_bound,
+                right_bound=right_bound, init_cond=init_cond, mesh=mesh,
+                quadrature=quadrature, basis=basis, **detector_dict)
+
+            plot_approximation_results(detector=detector,
+                directory=params.plot_dir, plot_name=wildcards.scheme,
                 data_file=params.plot_dir+'/'+wildcards.scheme)
 
             toc = time.perf_counter()
-- 
GitLab