diff --git a/Snakefile b/Snakefile index 1cc26ea8130e89979bb94103130abc2205d8df4e..968ea41f54ec63b79390c826cb7e25b95a66ad77 100644 --- a/Snakefile +++ b/Snakefile @@ -25,7 +25,7 @@ TODO: Discuss how wavelet details should be plotted Urgent: TODO: Rename stiffness matrix to volume integral matrix -> Done TODO: Rename boundary matrix to flux matrix -> Done -TODO: Change num_ghost_cells to be 1 for calc and 0 for other +TODO: Change num_ghost_cells to be 1 for calc and 0 for other -> Done TODO: Move boundary condition to Mesh class TODO: Ensure exact solution is calculated in Equation class TODO: Extract objects from UpdateScheme diff --git a/scripts/approximate_solution.py b/scripts/approximate_solution.py index d1c6c28d48561a569c097a9eac01cd80101b0867..cef66cfde1177d0c51bb802ec276b97f5ba0e39a 100644 --- a/scripts/approximate_solution.py +++ b/scripts/approximate_solution.py @@ -35,8 +35,7 @@ def main() -> None: # Initialize mesh with two ghost cells on each side mesh = Mesh(num_cells=params.pop('num_mesh_cells', 64), left_bound=params.pop('left_bound', -1), - right_bound=params.pop('right_bound', 1), - num_ghost_cells=1) + right_bound=params.pop('right_bound', 1)) # Initialize basis basis = OrthonormalLegendre(params.pop('polynomial_degree', 2)) diff --git a/scripts/tcd/ANN_Data_Generator.py b/scripts/tcd/ANN_Data_Generator.py index 783f3665638778c9bd3ffd99e9fa572248437f4e..cbc8eed1aca622dba9bb5157ded0333e0ea5841a 100644 --- a/scripts/tcd/ANN_Data_Generator.py +++ b/scripts/tcd/ANN_Data_Generator.py @@ -40,7 +40,7 @@ class TrainingDataGenerator: self._quadrature_list = [Gauss({'num_nodes': pol_deg+1}) for pol_deg in range(7)] self._mesh_list = [Mesh(left_bound=-1, right_bound=1, - num_ghost_cells=0, num_cells=2**exp) + num_cells=2**exp, training_data_mode=True) for exp in range(5, 12)] def build_training_data(self, init_cond_list, num_samples, balance=0.5, diff --git a/scripts/tcd/Mesh.py b/scripts/tcd/Mesh.py index 4f0a526ea4a7fb029f273f044762786da4444cf5..5d5d1d61183ad1fb94f8fdbf2a80ed1cf319093c 100644 --- a/scripts/tcd/Mesh.py +++ b/scripts/tcd/Mesh.py @@ -25,6 +25,7 @@ class Mesh: exponential of 2. num_ghost_cells : int Number of ghost cells on both sides of the mesh, respectively. + Either 0 during training or 1 during evaluation. bounds : Tuple[float, float] Left and right boundary of the mesh interval. interval_len : float @@ -43,8 +44,7 @@ class Mesh: """ - def __init__(self, num_cells: int, num_ghost_cells: int, - left_bound: float, right_bound: float, + def __init__(self, num_cells: int, left_bound: float, right_bound: float, training_data_mode: bool = False) -> None: """Initialize Mesh. @@ -53,8 +53,6 @@ class Mesh: num_cells : int Number of cells in the mesh (ghost cells notwithstanding). Has to be an exponential of 2. - num_ghost_cells : int - Number of ghost cells on each side of the mesh. left_bound : float Left boundary of the mesh interval. right_bound : float @@ -71,11 +69,12 @@ class Mesh: """ self._num_cells = num_cells self._mode = 'training' if training_data_mode else 'evaluation' + self._num_ghost_cells = 0 if not training_data_mode: + self._num_ghost_cells = 1 if not math.log(self._num_cells, 2).is_integer(): raise ValueError('The number of cells in the mesh has to be ' 'an exponential of 2') - self._num_ghost_cells = num_ghost_cells self._left_bound = left_bound self._right_bound = right_bound @@ -128,8 +127,7 @@ class Mesh: """Return dictionary with data necessary to construct mesh.""" return {'num_cells': self._num_cells, 'left_bound': self._left_bound, - 'right_bound': self._right_bound, - 'num_ghost_cells': self._num_ghost_cells} + 'right_bound': self._right_bound} def random_stencil(self, stencil_len: int) -> Mesh: """Return random stencil. @@ -156,5 +154,4 @@ class Mesh: # Return new mesh instance return Mesh(left_bound=point - stencil_len/2 * mesh_spacing, right_bound=point + stencil_len/2 * mesh_spacing, - num_cells=stencil_len, num_ghost_cells=0, - training_data_mode=True) + num_cells=stencil_len, training_data_mode=True) diff --git a/scripts/tcd/Plotting.py b/scripts/tcd/Plotting.py index 42f8fdc64a92f80aa4825b9f0556f1b48d44bd86..5fdc6d67dbb756d603fd8cdee2d5e9552a3468ed 100644 --- a/scripts/tcd/Plotting.py +++ b/scripts/tcd/Plotting.py @@ -428,8 +428,9 @@ def plot_results(projection: ndarray, troubled_cell_history: list, # Plot multiwavelet solution (fine and coarse mesh) if coarse_projection is not None: coarse_mesh = Mesh(num_cells=mesh.num_cells//2, - num_ghost_cells=0, left_bound=mesh.bounds[0], - right_bound=mesh.bounds[1]) + left_bound=mesh.bounds[0], + right_bound=mesh.bounds[1], + training_data_mode=True) # Plot exact and approximate solutions for coarse mesh coarse_grid, coarse_exact = calculate_exact_solution(