From cbf98d278aa5bd98f4f24d49e7fde6de7f0a76e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 31 May 2022 18:40:51 +0200
Subject: [PATCH] Moved basis initialization for plots into
 'plot_approximation_results()'.

---
 Plotting.py                 | 11 ++++-------
 Troubled_Cell_Detector.py   |  2 +-
 workflows/approximation.smk |  5 +----
 3 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/Plotting.py b/Plotting.py
index 4d00a8f..cfed1ef 100644
--- a/Plotting.py
+++ b/Plotting.py
@@ -18,7 +18,7 @@ from sympy import Symbol
 
 from Quadrature import Quadrature
 from Initial_Condition import InitialCondition
-from Basis_Function import Basis
+from Basis_Function import Basis, OrthonormalLegendre
 from projection_utils import calculate_exact_solution,\
     calculate_approximate_solution, Mesh
 from encoding_utils import decode_ndarray
@@ -318,7 +318,7 @@ def plot_evaluation_results(evaluation_file: str, directory: str,
 
 
 def plot_approximation_results(data_file: str, directory: str, plot_name: str,
-                               basis: Basis, quadrature: Quadrature,
+                               quadrature: Quadrature,
                                init_cond: InitialCondition) -> None:
     """Plots given approximation results.
 
@@ -333,8 +333,6 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
         Path to directory in which plots will be saved.
     plot_name : str
         Name of plot.
-    basis: Basis object
-        Basis used for calculation.
     quadrature: Quadrature object
         Quadrature used for evaluation.
     init_cond : InitialCondition object
@@ -348,13 +346,12 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
     # Decode all ndarrays by converting lists
     approx_stats = {key: decode_ndarray(approx_stats[key])
                     for key in approx_stats.keys()}
-    approx_stats.pop('polynomial_degree')
+    approx_stats['basis'] = OrthonormalLegendre(**approx_stats['basis'])
     approx_stats['mesh'] = Mesh(**approx_stats['mesh'])
 
     # Plot exact/approximate results, errors, shock tubes,
     # and any detector-dependant plots
-    plot_results(quadrature=quadrature, basis=basis,
-                 init_cond=init_cond, **approx_stats)
+    plot_results(quadrature=quadrature, init_cond=init_cond, **approx_stats)
 
     # Set paths for plot files if not existing already
     if not os.path.exists(directory):
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 807e5ab..ed420b9 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -112,7 +112,7 @@ class TroubledCellDetector(ABC):
                 'final_time': self._final_time,
                 'left_bound': self._left_bound,
                 'right_bound': self._right_bound,
-                'polynomial_degree': self._basis.polynomial_degree,
+                'basis': {'polynomial_degree': self._basis.polynomial_degree},
                 'mesh': {'num_grid_cells': self._num_grid_cells,
                          'left_bound': self._left_bound,
                          'right_bound': self._right_bound,
diff --git a/workflows/approximation.smk b/workflows/approximation.smk
index eb8a84c..00b89a3 100644
--- a/workflows/approximation.smk
+++ b/workflows/approximation.smk
@@ -6,7 +6,6 @@ import Initial_Condition
 import Quadrature
 from DG_Approximation import DGScheme
 from Plotting import plot_approximation_results
-from Basis_Function import OrthonormalLegendre
 
 configfile: 'config.yaml'
 
@@ -92,12 +91,10 @@ rule plot_approximation_results:
             quadrature = getattr(Quadrature, detector_dict.pop(
                 'quadrature', 'Gauss'))(detector_dict.pop(
                 'quadrature_config', {}))
-            basis = OrthonormalLegendre(detector_dict.pop(
-                'polynomial_degree', 2))
 
             plot_approximation_results(directory=params.plot_dir,
                 plot_name=wildcards.scheme,
-                data_file=params.plot_dir+'/'+wildcards.scheme, basis=basis,
+                data_file=params.plot_dir+'/'+wildcards.scheme,
                 quadrature=quadrature, init_cond=init_cond)
 
             toc = time.perf_counter()
-- 
GitLab