Skip to content
Snippets Groups Projects
Commit 488a1366 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Replaced mesh with Mesh class.

parent b19ef51c
No related branches found
No related tags found
No related merge requests found
......@@ -7,16 +7,32 @@ TODO: Contemplate saving 5-CV split and evaluating models separately
TODO: Contemplate separating cell average and reconstruction calculations
completely
TODO: Contemplate removing Methods section from class docstring
TODO: Ask whether there is a difference between grid and mesh -> Done
(same, keep mesh)
TODO: Contemplate containing the quadrature application for plots in Mesh
TODO: Contemplate containing coarse mesh generation in Mesh
Urgent:
TODO: Introduce Mesh class
(mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.)
(mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) -> Done
TODO: Add property attribute for non-ghost cells in Mesh -> Done
TODO: Replace mesh with Mesh class -> Done
TODO: Put basis initialization for plots in function
TODO: Contain cell length in mesh
TODO: Contain bounds in mesh
TODO: Contain number of grid cells in mesh
TODO: Contain interval length in mesh
TODO: Create data dict for mesh separately
TODO: Check whether ghost cells are handled/set correctly
TODO: Ensure uniform use of mesh and grid
TODO: Check whether eval_point in initial projection is set correctly -> Done
TODO: Replace getter with property attributes for quadrature
TODO: Remove use of DGScheme from ANN_Data_Generator
TODO: Find error in centering for ANN training
TODO: Adapt TCD from Soraya
(Dropbox->...->TEST_troubled-cell-detector->Troubled_Cell_Detector)
TODO: Add TC condition to only flag cell if left-adjacent one is flagged as
well
TODO: Add verbose output
TODO: Improve file naming (e.g. use '.' instead of '__')
TODO: Combine ANN workflows
......@@ -67,6 +83,7 @@ import Quadrature
import Update_Scheme
from Basis_Function import OrthonormalLegendre
from encoding_utils import encode_ndarray
from projection_utils import Mesh
x = Symbol('x')
sns.set()
......@@ -86,8 +103,8 @@ class DGScheme:
Length of a cell in mesh.
basis : Basis object
Basis for calculation.
mesh : ndarray
List of mesh valuation points.
mesh : Mesh
Mesh for calculation.
inv_mass : ndarray
Inverse mass matrix.
......@@ -281,10 +298,12 @@ class DGScheme:
# Set additional necessary config parameters
self._limiter_config['cell_len'] = self._cell_len
# Set mesh with one ghost point on each side
self._mesh = np.arange(self._left_bound - (3/2*self._cell_len),
self._right_bound + (5/2*self._cell_len),
self._cell_len) # +3/2
# Initialize mesh with two ghost cells on each side
self._mesh = Mesh(num_grid_cells=self._num_grid_cells,
num_ghost_cells=2, left_bound=self._left_bound,
right_bound=self._right_bound)
print(len(self._mesh.cells))
print(type(self._mesh.cells))
def build_training_data(self, adjustment, stencil_length,
add_reconstructions, initial_condition=None):
......
......@@ -20,7 +20,7 @@ from Quadrature import Quadrature
from Initial_Condition import InitialCondition
from Basis_Function import Basis
from projection_utils import calculate_exact_solution,\
calculate_approximate_solution
calculate_approximate_solution, Mesh
from encoding_utils import decode_ndarray
......@@ -124,7 +124,7 @@ def plot_shock_tube(num_grid_cells: int, troubled_cell_history: list,
plt.title('Shock Tubes')
def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
def plot_details(fine_projection: ndarray, fine_mesh: Mesh, basis: Basis,
coarse_projection: ndarray, multiwavelet_coeffs: ndarray,
num_coarse_grid_cells: int) -> None:
"""Plots details of projection to coarser mesh.
......@@ -133,8 +133,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
----------
fine_projection, coarse_projection : ndarray
Matrix of projection for each polynomial degree.
fine_mesh : ndarray
List of evaluation points for fine mesh.
fine_mesh : Mesh
Fine mesh for evaluation.
basis: Basis object
Basis used for calculation.
multiwavelet_coeffs : ndarray
......@@ -165,8 +165,8 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, basis: Basis,
projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)
plt.figure('coeff_details')
plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
plt.plot(fine_mesh.non_ghost_cells, projected_fine-projected_coarse, 'm-.')
plt.plot(fine_mesh.non_ghost_cells, projected_wavelet_coeffs, 'y')
plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
plt.xlabel('X')
plt.ylabel('Detail Coefficients')
......@@ -349,6 +349,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
approx_stats = {key: decode_ndarray(approx_stats[key])
for key in approx_stats.keys()}
approx_stats.pop('polynomial_degree')
approx_stats['mesh'] = Mesh(**approx_stats['mesh'])
# Plot exact/approximate results, errors, shock tubes,
# and any detector-dependant plots
......@@ -371,7 +372,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str,
def plot_results(projection: ndarray, troubled_cell_history: list,
time_history: list, mesh: ndarray, num_grid_cells: int,
time_history: list, mesh: Mesh, num_grid_cells: int,
wave_speed: float, final_time: float,
left_bound: float, right_bound: float, basis: Basis,
quadrature: Quadrature, init_cond: InitialCondition,
......@@ -393,8 +394,8 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
List of detected troubled cells for each time step.
time_history : list
List of value of each time step.
mesh : ndarray
List of mesh valuation points.
mesh : Mesh
Mesh for calculation.
num_grid_cells : int
Number of cells in the mesh. Usually exponential of 2.
wave_speed : float
......@@ -434,7 +435,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
# Determine exact and approximate solution
grid, exact = calculate_exact_solution(
mesh[2:-2], cell_len, wave_speed,
mesh, cell_len, wave_speed,
final_time, interval_len, quadrature,
init_cond)
approx = calculate_approximate_solution(
......@@ -444,13 +445,16 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
# Plot multiwavelet solution (fine and coarse grid)
if coarse_projection is not None:
coarse_cell_len = 2*cell_len
coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len),
right_bound + (1.5*coarse_cell_len),
coarse_cell_len)
coarse_mesh = Mesh(num_grid_cells=num_grid_cells//2,
num_ghost_cells=1, left_bound=left_bound,
right_bound=right_bound)
# coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len),
# right_bound + (1.5*coarse_cell_len),
# coarse_cell_len)
# Plot exact and approximate solutions for coarse mesh
coarse_grid, coarse_exact = calculate_exact_solution(
coarse_mesh[1:-1], coarse_cell_len, wave_speed,
coarse_mesh, coarse_cell_len, wave_speed,
final_time, interval_len, quadrature,
init_cond)
coarse_approx = calculate_approximate_solution(
......@@ -462,7 +466,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list,
# Plot multiwavelet details
num_coarse_grid_cells = num_grid_cells//2
plot_details(projection[:, 1:-1], mesh[2:-2], basis, coarse_projection,
plot_details(projection[:, 1:-1], mesh, basis, coarse_projection,
multiwavelet_coeffs, num_coarse_grid_cells)
plot_solution_and_approx(grid, exact, approx,
......
......@@ -13,6 +13,7 @@ import numpy as np
import torch
import ANN_Model
from projection_utils import Mesh
class TroubledCellDetector(ABC):
......@@ -33,8 +34,6 @@ class TroubledCellDetector(ABC):
Returns string of class name.
get_cells(projection)
Calculates troubled cells in a given projection.
plot_results(projection, troubled_cell_history, time_history)
Plots results and troubled cells of a projection.
"""
def __init__(self, config, init_cond, quadrature, basis, mesh,
......@@ -52,8 +51,8 @@ class TroubledCellDetector(ABC):
Quadrature for evaluation.
basis : Basis object
Basis for calculation.
mesh : ndarray
List of mesh valuation points.
mesh : Mesh
Mesh for calculation.
wave_speed : float, optional
Speed of wave in rightward direction. Default: 1.
num_grid_cells : int, optional
......@@ -109,10 +108,15 @@ class TroubledCellDetector(ABC):
def create_data_dict(self, projection):
return {'projection': projection, 'wave_speed': self._wave_speed,
'num_grid_cells': self._num_grid_cells, 'mesh': self._mesh,
'final_time': self._final_time, 'left_bound':
self._left_bound, 'right_bound': self._right_bound,
'polynomial_degree': self._basis.polynomial_degree
'num_grid_cells': self._num_grid_cells,
'final_time': self._final_time,
'left_bound': self._left_bound,
'right_bound': self._right_bound,
'polynomial_degree': self._basis.polynomial_degree,
'mesh': {'num_grid_cells': self._num_grid_cells,
'left_bound': self._left_bound,
'right_bound': self._right_bound,
'num_ghost_cells': 2}
}
......
......@@ -131,15 +131,15 @@ def calculate_approximate_solution(
def calculate_exact_solution(
mesh: ndarray, cell_len: float, wave_speed: float, final_time:
mesh: Mesh, cell_len: float, wave_speed: float, final_time:
float, interval_len: float, quadrature: Quadrature, init_cond:
InitialCondition) -> Tuple[ndarray, ndarray]:
"""Calculate exact solution.
Parameters
----------
mesh : ndarray
List of mesh evaluation points.
mesh : Mesh
Mesh for evaluation.
cell_len : float
Length of a cell in mesh.
wave_speed : float
......@@ -165,7 +165,7 @@ def calculate_exact_solution(
exact = []
num_periods = np.floor(wave_speed * final_time / interval_len)
for cell_center in mesh:
for cell_center in mesh.non_ghost_cells:
eval_points = cell_center+cell_len / 2 * quadrature.get_eval_points()
eval_values = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment