From 488a1366348cb162518a9c9e590c19535e3fbc43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Thu, 26 May 2022 00:22:47 +0200
Subject: [PATCH] Replaced mesh with Mesh class.

---
 DG_Approximation.py       | 33 ++++++++++++++++++++++++++-------
 Plotting.py               | 34 +++++++++++++++++++---------------
 Troubled_Cell_Detector.py | 20 ++++++++++++--------
 projection_utils.py       |  8 ++++----
 4 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/DG_Approximation.py b/DG_Approximation.py
index 85fd0f7..b41fdfb 100644
--- a/DG_Approximation.py
+++ b/DG_Approximation.py
@@ -7,16 +7,32 @@ TODO: Contemplate saving 5-CV split and evaluating models separately
 TODO: Contemplate separating cell average and reconstruction calculations
     completely
 TODO: Contemplate removing Methods section from class docstring
+TODO: Ask whether there is a difference between grid and mesh -> Done
+    (same, keep mesh)
+TODO: Contemplate containing the quadrature application for plots in Mesh
+TODO: Contemplate containing coarse mesh generation in Mesh
 
 Urgent:
 TODO: Introduce Mesh class
-    (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.)
+    (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) -> Done
+TODO: Add property attribute for non-ghost cells in Mesh -> Done
+TODO: Replace mesh with Mesh class -> Done
+TODO: Put basis initialization for plots in function
+TODO: Contain cell length in mesh
+TODO: Contain bounds in mesh
+TODO: Contain number of grid cells in mesh
+TODO: Contain interval length in mesh
+TODO: Create data dict for mesh separately
 TODO: Check whether ghost cells are handled/set correctly
+TODO: Ensure uniform use of mesh and grid
+TODO: Check whether eval_point in initial projection is set correctly -> Done
 TODO: Replace getter with property attributes for quadrature
 TODO: Remove use of DGScheme from ANN_Data_Generator
 TODO: Find error in centering for ANN training
 TODO: Adapt TCD from Soraya
     (Dropbox->...->TEST_troubled-cell-detector->Troubled_Cell_Detector)
+TODO: Add TC condition to only flag cell if left-adjacent one is flagged as
+    well
 TODO: Add verbose output
 TODO: Improve file naming (e.g. use '.' instead of '__')
 TODO: Combine ANN workflows
@@ -67,6 +83,7 @@ import Quadrature
 import Update_Scheme
 from Basis_Function import OrthonormalLegendre
 from encoding_utils import encode_ndarray
+from projection_utils import Mesh
 
 x = Symbol('x')
 sns.set()
@@ -86,8 +103,8 @@ class DGScheme:
         Length of a cell in mesh.
     basis : Basis object
         Basis for calculation.
-    mesh : ndarray
-        List of mesh valuation points.
+    mesh : Mesh
+        Mesh for calculation.
     inv_mass : ndarray
         Inverse mass matrix.
 
@@ -281,10 +298,12 @@ class DGScheme:
         # Set additional necessary config parameters
         self._limiter_config['cell_len'] = self._cell_len
 
-        # Set mesh with one ghost point on each side
-        self._mesh = np.arange(self._left_bound - (3/2*self._cell_len),
-                               self._right_bound + (5/2*self._cell_len),
-                               self._cell_len)  # +3/2
+        # Initialize mesh with two ghost cells on each side
+        self._mesh = Mesh(num_grid_cells=self._num_grid_cells,
+                          num_ghost_cells=2, left_bound=self._left_bound,
+                          right_bound=self._right_bound)
+        print(len(self._mesh.cells))
+        print(type(self._mesh.cells))
 
     def build_training_data(self, adjustment, stencil_length,
                             add_reconstructions, initial_condition=None):
diff --git a/Plotting.py b/Plotting.py
index bd104cd..4d00a8f 100644
--- a/Plotting.py
+++ b/Plotting.py
@@ -20,7 +20,7 @@ from Quadrature import Quadrature
 from Initial_Condition import InitialCondition
 from Basis_Function import Basis
 from projection_utils import calculate_exact_solution,\
-    calculate_approximate_solution
+    calculate_approximate_solution, Mesh
 from encoding_utils import decode_ndarray
 
 
@@ -124,7 +124,7 @@ def plot_shock_tube(num_grid_cells: int, troubled_cell_history: list,
     plt.title('Shock Tubes')
 
 
-def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
+def plot_details(fine_projection: ndarray, fine_mesh: Mesh, basis: Basis,
                  coarse_projection: ndarray, multiwavelet_coeffs: ndarray,
                  num_coarse_grid_cells: int) -> None:
     """Plots details of projection to coarser mesh.
@@ -133,8 +133,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
     ----------
     fine_projection, coarse_projection : ndarray
         Matrix of projection for each polynomial degree.
-    fine_mesh : ndarray
-        List of evaluation points for fine mesh.
+    fine_mesh : Mesh
+        Fine mesh for evaluation.
     basis: Basis object
         Basis used for calculation.
     multiwavelet_coeffs : ndarray
@@ -165,8 +165,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
     projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)
 
     plt.figure('coeff_details')
-    plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
-    plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
+    plt.plot(fine_mesh.non_ghost_cells, projected_fine-projected_coarse, 'm-.')
+    plt.plot(fine_mesh.non_ghost_cells, projected_wavelet_coeffs, 'y')
     plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
     plt.xlabel('X')
     plt.ylabel('Detail Coefficients')
@@ -349,6 +349,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
     approx_stats = {key: decode_ndarray(approx_stats[key])
                     for key in approx_stats.keys()}
     approx_stats.pop('polynomial_degree')
+    approx_stats['mesh'] = Mesh(**approx_stats['mesh'])
 
     # Plot exact/approximate results, errors, shock tubes,
     # and any detector-dependant plots
@@ -371,7 +372,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
 
 
 def plot_results(projection: ndarray, troubled_cell_history: list,
-                 time_history: list, mesh: ndarray, num_grid_cells: int,
+                 time_history: list, mesh: Mesh, num_grid_cells: int,
                  wave_speed: float, final_time: float,
                  left_bound: float, right_bound: float, basis: Basis,
                  quadrature: Quadrature, init_cond: InitialCondition,
@@ -393,8 +394,8 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
         List of detected troubled cells for each time step.
     time_history : list
         List of value of each time step.
-    mesh : ndarray
-        List of mesh valuation points.
+    mesh : Mesh
+        Mesh for calculation.
     num_grid_cells : int
         Number of cells in the mesh. Usually exponential of 2.
     wave_speed : float
@@ -434,7 +435,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
 
     # Determine exact and approximate solution
     grid, exact = calculate_exact_solution(
-        mesh[2:-2], cell_len, wave_speed,
+        mesh, cell_len, wave_speed,
         final_time, interval_len, quadrature,
         init_cond)
     approx = calculate_approximate_solution(
@@ -444,13 +445,16 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
     # Plot multiwavelet solution (fine and coarse grid)
     if coarse_projection is not None:
         coarse_cell_len = 2*cell_len
-        coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len),
-                                right_bound + (1.5*coarse_cell_len),
-                                coarse_cell_len)
+        coarse_mesh = Mesh(num_grid_cells=num_grid_cells//2,
+                           num_ghost_cells=1, left_bound=left_bound,
+                           right_bound=right_bound)
+        # coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len),
+        #                         right_bound + (1.5*coarse_cell_len),
+        #                         coarse_cell_len)
 
         # Plot exact and approximate solutions for coarse mesh
         coarse_grid, coarse_exact = calculate_exact_solution(
-            coarse_mesh[1:-1], coarse_cell_len, wave_speed,
+            coarse_mesh, coarse_cell_len, wave_speed,
             final_time, interval_len, quadrature,
             init_cond)
         coarse_approx = calculate_approximate_solution(
@@ -462,7 +466,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
 
         # Plot multiwavelet details
         num_coarse_grid_cells = num_grid_cells//2
-        plot_details(projection[:, 1:-1], mesh[2:-2], basis, coarse_projection,
+        plot_details(projection[:, 1:-1], mesh, basis, coarse_projection,
                      multiwavelet_coeffs, num_coarse_grid_cells)
 
         plot_solution_and_approx(grid, exact, approx,
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 237d00f..807e5ab 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -13,6 +13,7 @@ import numpy as np
 import torch
 
 import ANN_Model
+from projection_utils import Mesh
 
 
 class TroubledCellDetector(ABC):
@@ -33,8 +34,6 @@ class TroubledCellDetector(ABC):
         Returns string of class name.
     get_cells(projection)
         Calculates troubled cells in a given projection.
-    plot_results(projection, troubled_cell_history, time_history)
-        Plots results and troubled cells of a projection.
 
     """
     def __init__(self, config, init_cond, quadrature, basis, mesh,
@@ -52,8 +51,8 @@ class TroubledCellDetector(ABC):
             Quadrature for evaluation.
         basis : Basis object
             Basis for calculation.
-        mesh : ndarray
-            List of mesh valuation points.
+        mesh : Mesh
+            Mesh for calculation.
         wave_speed : float, optional
             Speed of wave in rightward direction. Default: 1.
         num_grid_cells : int, optional
@@ -109,10 +108,15 @@ class TroubledCellDetector(ABC):
 
     def create_data_dict(self, projection):
         return {'projection': projection, 'wave_speed': self._wave_speed,
-                'num_grid_cells': self._num_grid_cells, 'mesh': self._mesh,
-                'final_time': self._final_time, 'left_bound':
-                    self._left_bound, 'right_bound': self._right_bound,
-                'polynomial_degree': self._basis.polynomial_degree
+                'num_grid_cells': self._num_grid_cells,
+                'final_time': self._final_time,
+                'left_bound': self._left_bound,
+                'right_bound': self._right_bound,
+                'polynomial_degree': self._basis.polynomial_degree,
+                'mesh': {'num_grid_cells': self._num_grid_cells,
+                         'left_bound': self._left_bound,
+                         'right_bound': self._right_bound,
+                         'num_ghost_cells': 2}
                 }
 
 
diff --git a/projection_utils.py b/projection_utils.py
index 8f8598c..41d6654 100644
--- a/projection_utils.py
+++ b/projection_utils.py
@@ -131,15 +131,15 @@ def calculate_approximate_solution(
 
 
 def calculate_exact_solution(
-        mesh: ndarray, cell_len: float, wave_speed: float, final_time:
+        mesh: Mesh, cell_len: float, wave_speed: float, final_time:
         float, interval_len: float, quadrature: Quadrature, init_cond:
         InitialCondition) -> Tuple[ndarray, ndarray]:
     """Calculate exact solution.
 
     Parameters
     ----------
-    mesh : ndarray
-        List of mesh evaluation points.
+    mesh : Mesh
+        Mesh for evaluation.
     cell_len : float
         Length of a cell in mesh.
     wave_speed : float
@@ -165,7 +165,7 @@ def calculate_exact_solution(
     exact = []
     num_periods = np.floor(wave_speed * final_time / interval_len)
 
-    for cell_center in mesh:
+    for cell_center in mesh.non_ghost_cells:
         eval_points = cell_center+cell_len / 2 * quadrature.get_eval_points()
 
         eval_values = []
-- 
GitLab