Skip to content
Snippets Groups Projects
Commit 41b397b2 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Adapted number of ghost cells based on ANN stencil.

parent 6bf3a1ca
Branches
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ TODO: Enforce periodic boundary condition for projection with decorator -> Done
TODO: Enforce Boxplot bounds with decorator -> Done
TODO: Enforce Boxplot folds with decorator -> Done
TODO: Enforce boundary for initial condition in exact solution only -> Done
TODO: Adapt number of ghost cells based on ANN stencil
TODO: Adapt number of ghost cells based on ANN stencil -> Done
TODO: Ensure exact solution is calculated in Equation class
TODO: Extract objects from UpdateScheme
TODO: Enforce num_ghost_cells to be positive integer for DG (not training)
......
......@@ -27,15 +27,25 @@ def main() -> None:
params = snakemake.params['dg_params']
num_ghost_cells = params.pop('num_ghost_cells', 1)
if len(snakemake.input) > 0:
params['detector_config']['model_state'] = snakemake.input[0]
# Adapt number of ghost cells
if 'stencil_len' in params['detector_config']:
if num_ghost_cells < params['detector_config']['stencil_len']//2:
num_ghost_cells = params['detector_config']['stencil_len']//2
print(f'Number of ghost cells was increased to '
f'{num_ghost_cells} to accommodate the stencil length.')
print(params)
# Initialize mesh with two ghost cells on each side
mesh = Mesh(num_cells=params.pop('num_mesh_cells', 64),
left_bound=params.pop('left_bound', -1),
right_bound=params.pop('right_bound', 1))
right_bound=params.pop('right_bound', 1),
num_ghost_cells=num_ghost_cells)
# Initialize basis
basis = OrthonormalLegendre(params.pop('polynomial_degree', 2))
......
......@@ -40,7 +40,7 @@ class TrainingDataGenerator:
self._quadrature_list = [Gauss({'num_nodes': pol_deg+1})
for pol_deg in range(7)]
self._mesh_list = [Mesh(left_bound=-1, right_bound=1,
num_cells=2**exp, training_data_mode=True)
num_cells=2**exp, num_ghost_cells=0)
for exp in range(5, 12)]
def build_training_data(self, init_cond_list, num_samples, balance=0.5,
......
......@@ -45,7 +45,7 @@ class Mesh:
"""
def __init__(self, num_cells: int, left_bound: float, right_bound: float,
training_data_mode: bool = False) -> None:
num_ghost_cells: int = 1) -> None:
"""Initialize Mesh.
Parameters
......@@ -57,9 +57,9 @@ class Mesh:
Left boundary of the mesh interval.
right_bound : float
Right boundary of the mesh interval.
training_data_mode : bool, optional
Flag indicating whether the mesh is used for training data
generation. Default: False.
num_ghost_cells : int, optional
Number of ghost cells on each side of the mesh, respectively.
Default: False.
Raises
------
......@@ -68,10 +68,8 @@ class Mesh:
"""
self._num_cells = num_cells
self._mode = 'training' if training_data_mode else 'evaluation'
self._num_ghost_cells = 0
if not training_data_mode:
self._num_ghost_cells = 1
self._num_ghost_cells = num_ghost_cells
if self._num_ghost_cells != 0:
if not math.log(self._num_cells, 2).is_integer():
raise ValueError('The number of cells in the mesh has to be '
'an exponential of 2')
......@@ -79,11 +77,6 @@ class Mesh:
self._right_bound = right_bound
self.boundary = 'periodic'
@property
def mode(self) -> str:
"""Return mode ('training' or 'evaluation')."""
return self._mode
@property
def num_cells(self) -> int:
"""Return number of mesh cells."""
......@@ -128,7 +121,8 @@ class Mesh:
"""Return dictionary with data necessary to construct mesh."""
return {'num_cells': self._num_cells,
'left_bound': self._left_bound,
'right_bound': self._right_bound}
'right_bound': self._right_bound,
'num_ghost_cells': self._num_ghost_cells}
def random_stencil(self, stencil_len: int) -> Mesh:
"""Return random stencil.
......@@ -155,4 +149,4 @@ class Mesh:
# Return new mesh instance
return Mesh(left_bound=point - stencil_len/2 * mesh_spacing,
right_bound=point + stencil_len/2 * mesh_spacing,
num_cells=stencil_len, training_data_mode=True)
num_cells=stencil_len, num_ghost_cells=0)
......@@ -430,7 +430,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
coarse_mesh = Mesh(num_cells=mesh.num_cells//2,
left_bound=mesh.bounds[0],
right_bound=mesh.bounds[1],
training_data_mode=True)
num_ghost_cells=0)
# Plot exact and approximate solutions for coarse mesh
coarse_grid, coarse_exact = calculate_exact_solution(
......
......@@ -177,14 +177,6 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
List of indices for all detected troubled cells.
"""
# Reset ghost cells to adjust for stencil length
num_ghost_cells = self._stencil_len//2
projection = projection[:, self._mesh.num_ghost_cells:
-self._mesh.num_ghost_cells]
projection = np.concatenate((projection[:, -num_ghost_cells:],
projection,
projection[:, :num_ghost_cells]), axis=1)
# Calculate input data depending on stencil length
projection_window = projection[:, self._window_mask]
input_data = torch.from_numpy(self._basis.calculate_cell_average(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment