From d9f1d316343d99687ad0e8dcf120ceb107e1bd1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Fri, 24 Feb 2023 17:39:37 +0100 Subject: [PATCH] Changed 'num_ghost_cells' to be dependent on mesh mode. --- Snakefile | 2 +- scripts/approximate_solution.py | 3 +-- scripts/tcd/ANN_Data_Generator.py | 2 +- scripts/tcd/Mesh.py | 15 ++++++--------- scripts/tcd/Plotting.py | 5 +++-- 5 files changed, 12 insertions(+), 15 deletions(-) diff --git a/Snakefile b/Snakefile index 1cc26ea..968ea41 100644 --- a/Snakefile +++ b/Snakefile @@ -25,7 +25,7 @@ TODO: Discuss how wavelet details should be plotted Urgent: TODO: Rename stiffness matrix to volume integral matrix -> Done TODO: Rename boundary matrix to flux matrix -> Done -TODO: Change num_ghost_cells to be 1 for calc and 0 for other +TODO: Change num_ghost_cells to be 1 for calc and 0 for other -> Done TODO: Move boundary condition to Mesh class TODO: Ensure exact solution is calculated in Equation class TODO: Extract objects from UpdateScheme diff --git a/scripts/approximate_solution.py b/scripts/approximate_solution.py index d1c6c28..cef66cf 100644 --- a/scripts/approximate_solution.py +++ b/scripts/approximate_solution.py @@ -35,8 +35,7 @@ def main() -> None: # Initialize mesh with two ghost cells on each side mesh = Mesh(num_cells=params.pop('num_mesh_cells', 64), left_bound=params.pop('left_bound', -1), - right_bound=params.pop('right_bound', 1), - num_ghost_cells=1) + right_bound=params.pop('right_bound', 1)) # Initialize basis basis = OrthonormalLegendre(params.pop('polynomial_degree', 2)) diff --git a/scripts/tcd/ANN_Data_Generator.py b/scripts/tcd/ANN_Data_Generator.py index 783f366..cbc8eed 100644 --- a/scripts/tcd/ANN_Data_Generator.py +++ b/scripts/tcd/ANN_Data_Generator.py @@ -40,7 +40,7 @@ class TrainingDataGenerator: self._quadrature_list = [Gauss({'num_nodes': pol_deg+1}) for pol_deg in range(7)] self._mesh_list = [Mesh(left_bound=-1, right_bound=1, - num_ghost_cells=0, num_cells=2**exp) + num_cells=2**exp, training_data_mode=True) for exp in range(5, 12)] def build_training_data(self, init_cond_list, num_samples, balance=0.5, diff --git a/scripts/tcd/Mesh.py b/scripts/tcd/Mesh.py index 4f0a526..5d5d1d6 100644 --- a/scripts/tcd/Mesh.py +++ b/scripts/tcd/Mesh.py @@ -25,6 +25,7 @@ class Mesh: exponential of 2. num_ghost_cells : int Number of ghost cells on both sides of the mesh, respectively. + Either 0 during training or 1 during evaluation. bounds : Tuple[float, float] Left and right boundary of the mesh interval. interval_len : float @@ -43,8 +44,7 @@ class Mesh: """ - def __init__(self, num_cells: int, num_ghost_cells: int, - left_bound: float, right_bound: float, + def __init__(self, num_cells: int, left_bound: float, right_bound: float, training_data_mode: bool = False) -> None: """Initialize Mesh. @@ -53,8 +53,6 @@ class Mesh: num_cells : int Number of cells in the mesh (ghost cells notwithstanding). Has to be an exponential of 2. - num_ghost_cells : int - Number of ghost cells on each side of the mesh. left_bound : float Left boundary of the mesh interval. right_bound : float @@ -71,11 +69,12 @@ class Mesh: """ self._num_cells = num_cells self._mode = 'training' if training_data_mode else 'evaluation' + self._num_ghost_cells = 0 if not training_data_mode: + self._num_ghost_cells = 1 if not math.log(self._num_cells, 2).is_integer(): raise ValueError('The number of cells in the mesh has to be ' 'an exponential of 2') - self._num_ghost_cells = num_ghost_cells self._left_bound = left_bound self._right_bound = right_bound @@ -128,8 +127,7 @@ class Mesh: """Return dictionary with data necessary to construct mesh.""" return {'num_cells': self._num_cells, 'left_bound': self._left_bound, - 'right_bound': self._right_bound, - 'num_ghost_cells': self._num_ghost_cells} + 'right_bound': self._right_bound} def random_stencil(self, stencil_len: int) -> Mesh: """Return random stencil. @@ -156,5 +154,4 @@ class Mesh: # Return new mesh instance return Mesh(left_bound=point - stencil_len/2 * mesh_spacing, right_bound=point + stencil_len/2 * mesh_spacing, - num_cells=stencil_len, num_ghost_cells=0, - training_data_mode=True) + num_cells=stencil_len, training_data_mode=True) diff --git a/scripts/tcd/Plotting.py b/scripts/tcd/Plotting.py index 42f8fdc..5fdc6d6 100644 --- a/scripts/tcd/Plotting.py +++ b/scripts/tcd/Plotting.py @@ -428,8 +428,9 @@ def plot_results(projection: ndarray, troubled_cell_history: list, # Plot multiwavelet solution (fine and coarse mesh) if coarse_projection is not None: coarse_mesh = Mesh(num_cells=mesh.num_cells//2, - num_ghost_cells=0, left_bound=mesh.bounds[0], - right_bound=mesh.bounds[1]) + left_bound=mesh.bounds[0], + right_bound=mesh.bounds[1], + training_data_mode=True) # Plot exact and approximate solutions for coarse mesh coarse_grid, coarse_exact = calculate_exact_solution( -- GitLab