From 488a1366348cb162518a9c9e590c19535e3fbc43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Thu, 26 May 2022 00:22:47 +0200 Subject: [PATCH] Replaced mesh with Mesh class. --- DG_Approximation.py | 33 ++++++++++++++++++++++++++------- Plotting.py | 34 +++++++++++++++++++--------------- Troubled_Cell_Detector.py | 20 ++++++++++++-------- projection_utils.py | 8 ++++---- 4 files changed, 61 insertions(+), 34 deletions(-) diff --git a/DG_Approximation.py b/DG_Approximation.py index 85fd0f7..b41fdfb 100644 --- a/DG_Approximation.py +++ b/DG_Approximation.py @@ -7,16 +7,32 @@ TODO: Contemplate saving 5-CV split and evaluating models separately TODO: Contemplate separating cell average and reconstruction calculations completely TODO: Contemplate removing Methods section from class docstring +TODO: Ask whether there is a difference between grid and mesh -> Done + (same, keep mesh) +TODO: Contemplate containing the quadrature application for plots in Mesh +TODO: Contemplate containing coarse mesh generation in Mesh Urgent: TODO: Introduce Mesh class - (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) + (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) -> Done +TODO: Add property attribute for non-ghost cells in Mesh -> Done +TODO: Replace mesh with Mesh class -> Done +TODO: Put basis initialization for plots in function +TODO: Contain cell length in mesh +TODO: Contain bounds in mesh +TODO: Contain number of grid cells in mesh +TODO: Contain interval length in mesh +TODO: Create data dict for mesh separately TODO: Check whether ghost cells are handled/set correctly +TODO: Ensure uniform use of mesh and grid +TODO: Check whether eval_point in initial projection is set correctly -> Done TODO: Replace getter with property attributes for quadrature TODO: Remove use of DGScheme from ANN_Data_Generator TODO: Find error in centering for ANN training TODO: Adapt TCD from Soraya (Dropbox->...->TEST_troubled-cell-detector->Troubled_Cell_Detector) +TODO: Add TC condition to only flag cell if left-adjacent one is flagged as + well TODO: Add verbose output TODO: Improve file naming (e.g. use '.' instead of '__') TODO: Combine ANN workflows @@ -67,6 +83,7 @@ import Quadrature import Update_Scheme from Basis_Function import OrthonormalLegendre from encoding_utils import encode_ndarray +from projection_utils import Mesh x = Symbol('x') sns.set() @@ -86,8 +103,8 @@ class DGScheme: Length of a cell in mesh. basis : Basis object Basis for calculation. - mesh : ndarray - List of mesh valuation points. + mesh : Mesh + Mesh for calculation. inv_mass : ndarray Inverse mass matrix. @@ -281,10 +298,12 @@ class DGScheme: # Set additional necessary config parameters self._limiter_config['cell_len'] = self._cell_len - # Set mesh with one ghost point on each side - self._mesh = np.arange(self._left_bound - (3/2*self._cell_len), - self._right_bound + (5/2*self._cell_len), - self._cell_len) # +3/2 + # Initialize mesh with two ghost cells on each side + self._mesh = Mesh(num_grid_cells=self._num_grid_cells, + num_ghost_cells=2, left_bound=self._left_bound, + right_bound=self._right_bound) + print(len(self._mesh.cells)) + print(type(self._mesh.cells)) def build_training_data(self, adjustment, stencil_length, add_reconstructions, initial_condition=None): diff --git a/Plotting.py b/Plotting.py index bd104cd..4d00a8f 100644 --- a/Plotting.py +++ b/Plotting.py @@ -20,7 +20,7 @@ from Quadrature import Quadrature from Initial_Condition import InitialCondition from Basis_Function import Basis from projection_utils import calculate_exact_solution,\ - calculate_approximate_solution + calculate_approximate_solution, Mesh from encoding_utils import decode_ndarray @@ -124,7 +124,7 @@ def plot_shock_tube(num_grid_cells: int, troubled_cell_history: list, plt.title('Shock Tubes') -def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis, +def plot_details(fine_projection: ndarray, fine_mesh: Mesh, basis: Basis, coarse_projection: ndarray, multiwavelet_coeffs: ndarray, num_coarse_grid_cells: int) -> None: """Plots details of projection to coarser mesh. @@ -133,8 +133,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis, ---------- fine_projection, coarse_projection : ndarray Matrix of projection for each polynomial degree. - fine_mesh : ndarray - List of evaluation points for fine mesh. + fine_mesh : Mesh + Fine mesh for evaluation. basis: Basis object Basis used for calculation. multiwavelet_coeffs : ndarray @@ -165,8 +165,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis, projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0) plt.figure('coeff_details') - plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.') - plt.plot(fine_mesh, projected_wavelet_coeffs, 'y') + plt.plot(fine_mesh.non_ghost_cells, projected_fine-projected_coarse, 'm-.') + plt.plot(fine_mesh.non_ghost_cells, projected_wavelet_coeffs, 'y') plt.legend(['Fine-Coarse', 'Wavelet Coeff']) plt.xlabel('X') plt.ylabel('Detail Coefficients') @@ -349,6 +349,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str, approx_stats = {key: decode_ndarray(approx_stats[key]) for key in approx_stats.keys()} approx_stats.pop('polynomial_degree') + approx_stats['mesh'] = Mesh(**approx_stats['mesh']) # Plot exact/approximate results, errors, shock tubes, # and any detector-dependant plots @@ -371,7 +372,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str, def plot_results(projection: ndarray, troubled_cell_history: list, - time_history: list, mesh: ndarray, num_grid_cells: int, + time_history: list, mesh: Mesh, num_grid_cells: int, wave_speed: float, final_time: float, left_bound: float, right_bound: float, basis: Basis, quadrature: Quadrature, init_cond: InitialCondition, @@ -393,8 +394,8 @@ def plot_results(projection: ndarray, troubled_cell_history: list, List of detected troubled cells for each time step. time_history : list List of value of each time step. - mesh : ndarray - List of mesh valuation points. + mesh : Mesh + Mesh for calculation. num_grid_cells : int Number of cells in the mesh. Usually exponential of 2. wave_speed : float @@ -434,7 +435,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list, # Determine exact and approximate solution grid, exact = calculate_exact_solution( - mesh[2:-2], cell_len, wave_speed, + mesh, cell_len, wave_speed, final_time, interval_len, quadrature, init_cond) approx = calculate_approximate_solution( @@ -444,13 +445,16 @@ def plot_results(projection: ndarray, troubled_cell_history: list, # Plot multiwavelet solution (fine and coarse grid) if coarse_projection is not None: coarse_cell_len = 2*cell_len - coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len), - right_bound + (1.5*coarse_cell_len), - coarse_cell_len) + coarse_mesh = Mesh(num_grid_cells=num_grid_cells//2, + num_ghost_cells=1, left_bound=left_bound, + right_bound=right_bound) + # coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len), + # right_bound + (1.5*coarse_cell_len), + # coarse_cell_len) # Plot exact and approximate solutions for coarse mesh coarse_grid, coarse_exact = calculate_exact_solution( - coarse_mesh[1:-1], coarse_cell_len, wave_speed, + coarse_mesh, coarse_cell_len, wave_speed, final_time, interval_len, quadrature, init_cond) coarse_approx = calculate_approximate_solution( @@ -462,7 +466,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list, # Plot multiwavelet details num_coarse_grid_cells = num_grid_cells//2 - plot_details(projection[:, 1:-1], mesh[2:-2], basis, coarse_projection, + plot_details(projection[:, 1:-1], mesh, basis, coarse_projection, multiwavelet_coeffs, num_coarse_grid_cells) plot_solution_and_approx(grid, exact, approx, diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index 237d00f..807e5ab 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -13,6 +13,7 @@ import numpy as np import torch import ANN_Model +from projection_utils import Mesh class TroubledCellDetector(ABC): @@ -33,8 +34,6 @@ class TroubledCellDetector(ABC): Returns string of class name. get_cells(projection) Calculates troubled cells in a given projection. - plot_results(projection, troubled_cell_history, time_history) - Plots results and troubled cells of a projection. """ def __init__(self, config, init_cond, quadrature, basis, mesh, @@ -52,8 +51,8 @@ class TroubledCellDetector(ABC): Quadrature for evaluation. basis : Basis object Basis for calculation. - mesh : ndarray - List of mesh valuation points. + mesh : Mesh + Mesh for calculation. wave_speed : float, optional Speed of wave in rightward direction. Default: 1. num_grid_cells : int, optional @@ -109,10 +108,15 @@ class TroubledCellDetector(ABC): def create_data_dict(self, projection): return {'projection': projection, 'wave_speed': self._wave_speed, - 'num_grid_cells': self._num_grid_cells, 'mesh': self._mesh, - 'final_time': self._final_time, 'left_bound': - self._left_bound, 'right_bound': self._right_bound, - 'polynomial_degree': self._basis.polynomial_degree + 'num_grid_cells': self._num_grid_cells, + 'final_time': self._final_time, + 'left_bound': self._left_bound, + 'right_bound': self._right_bound, + 'polynomial_degree': self._basis.polynomial_degree, + 'mesh': {'num_grid_cells': self._num_grid_cells, + 'left_bound': self._left_bound, + 'right_bound': self._right_bound, + 'num_ghost_cells': 2} } diff --git a/projection_utils.py b/projection_utils.py index 8f8598c..41d6654 100644 --- a/projection_utils.py +++ b/projection_utils.py @@ -131,15 +131,15 @@ def calculate_approximate_solution( def calculate_exact_solution( - mesh: ndarray, cell_len: float, wave_speed: float, final_time: + mesh: Mesh, cell_len: float, wave_speed: float, final_time: float, interval_len: float, quadrature: Quadrature, init_cond: InitialCondition) -> Tuple[ndarray, ndarray]: """Calculate exact solution. Parameters ---------- - mesh : ndarray - List of mesh evaluation points. + mesh : Mesh + Mesh for evaluation. cell_len : float Length of a cell in mesh. wave_speed : float @@ -165,7 +165,7 @@ def calculate_exact_solution( exact = [] num_periods = np.floor(wave_speed * final_time / interval_len) - for cell_center in mesh: + for cell_center in mesh.non_ghost_cells: eval_points = cell_center+cell_len / 2 * quadrature.get_eval_points() eval_values = [] -- GitLab