Skip to content
Snippets Groups Projects
Commit 73280c43 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Extract plotting from Troubled_Cell_Detector.

parent 7c2c0818
Branches
No related tags found
No related merge requests found
......@@ -4,17 +4,29 @@
Discussion:
TODO: Ask whether cell averages/reconstructions should be contained in basis
TODO: Contemplate whether basis variables should be public
TODO: Contemplate a Mesh class (mesh, cell_len, num_grid_cells, bounds, etc.)
-> Done (yes, hard-code simplification)
TODO: Contemplate whether basis variables should be public -> Done (yes)
TODO: Contemplate a Mesh class
(mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.)
-> Done (yes)
TODO: Contemplate to contain polynomial degree in basis -> Done (yes)
Urgent:
TODO: Hard-code simplification of cell average/reconstruction in basis
TODO: Make basis variables public (if feasible)
TODO: Contain polynomial degree in basis
TODO: Introduce Mesh class
(mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.)
TODO: Check whether ghost cells are handled/set correctly
TODO: Find error in centering for ANN training
TODO: Investigate g-mesh(?)
TODO: Extract do_initial_projection() from DGScheme -> Done
TODO: Move inverse mass matrix to basis -> Done
TODO: Extract calculate_cell_average() from TCD -> Done
TODO: Improve calculate_cell_average() -> Done
TODO: Extract calculate_[...]_solution() from Plotting
TODO: Extract calculate_[...]_solution() from Plotting -> Done
TODO: Extract plotting from TCD completely
(maybe give indicator which plots are required instead?)
(maybe give indicator which plots are required instead?) -> Done
TODO: Contain all plotting in Plotting
TODO: Remove use of DGScheme from ANN_Data_Generator
TODO: Clean up docstrings
......@@ -59,6 +71,7 @@ import json
import numpy as np
from sympy import Symbol
import math
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt
......@@ -68,10 +81,14 @@ import Limiter
import Quadrature
import Update_Scheme
from Basis_Function import OrthonormalLegendre
from projection_utils import calculate_cell_average
from projection_utils import calculate_cell_average, \
calculate_exact_solution, calculate_approximate_solution
from Plotting import plot_solution_and_approx, plot_semilog_error, \
plot_error, plot_shock_tube, plot_details
matplotlib.use('Agg')
x = Symbol('x')
sns.set()
def encode_ndarray(obj):
......@@ -270,9 +287,12 @@ class DGScheme:
current_time += time_step
# Save detector-specific data in dictionary
approx_stats = self._detector.create_data_dict(projection)
# Save approximation results in dictionary
approx_stats = {'projection': projection, 'time_history': time_history,
'troubled_cell_history': troubled_cell_history}
approx_stats['time_history'] = time_history
approx_stats['troubled_cell_history'] = troubled_cell_history
# Encode all ndarrays to fit JSON format
approx_stats = {key: encode_ndarray(approx_stats[key])
......@@ -353,7 +373,7 @@ def do_initial_projection(initial_condition, basis, quadrature,
basis: Vector object
Basis used for calculation.
quadrature: Quadrature object
Quadrature fused for evaluation.
Quadrature used for evaluation.
num_grid_cells : int
Number of cells in the mesh. Usually exponential of 2.
left_bound : float
......@@ -402,7 +422,8 @@ def do_initial_projection(initial_condition, basis, quadrature,
return np.transpose(np.array(output_matrix))
def plot_approximation_results(detector, data_file, directory, plot_name):
def plot_approximation_results(data_file, directory, plot_name, quadrature,
init_cond, basis):
"""Plots given approximation results.
Generates plots based on given data, sets plot directory if not
......@@ -416,6 +437,12 @@ def plot_approximation_results(detector, data_file, directory, plot_name):
Path to directory in which plots will be saved.
plot_name : str
Name of plot.
basis: Vector object
Basis used for calculation.
quadrature: Quadrature object
Quadrature used for evaluation.
init_cond : InitialCondition object
Initial condition used for calculation.
"""
# Read approximation results
......@@ -428,7 +455,8 @@ def plot_approximation_results(detector, data_file, directory, plot_name):
# Plot exact/approximate results, errors, shock tubes,
# and any detector-dependant plots
detector.plot_results(**approx_stats)
plot_results(quadrature=quadrature, basis=basis,
init_cond=init_cond, **approx_stats)
# Set paths for plot files if not existing already
if not os.path.exists(directory):
......@@ -443,3 +471,144 @@ def plot_approximation_results(detector, data_file, directory, plot_name):
plt.figure(identifier)
plt.savefig(directory + '/' + identifier + '/' +
plot_name + '.pdf')
def plot_results(projection, troubled_cell_history, time_history, mesh,
num_grid_cells, polynomial_degree, wave_speed, final_time,
left_bound, right_bound, basis, quadrature, init_cond,
colors=None, coarse_projection=None,
multiwavelet_coeffs=None):
"""Plots results and troubled cells of a projection.
Plots exact and approximate solution, errors, and troubled cells of a
projection given its evaluation history.
If coarse grid and projection are given, solutions are displayed for
both coarse and fine grid. Additionally, coefficient details are plotted.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
troubled_cell_history : list
List of detected troubled cells for each time step.
time_history : list
List of value of each time step.
mesh : ndarray
List of mesh valuation points.
num_grid_cells : int
Number of cells in the mesh. Usually exponential of 2.
polynomial_degree : int
Polynomial degree.
wave_speed : float
Speed of wave in rightward direction.
final_time : float
Final time for which approximation is calculated.
left_bound : float
Left boundary of interval.
right_bound : float
Right boundary of interval.
basis: Vector object
Basis used for calculation.
quadrature: Quadrature object
Quadrature used for evaluation.
init_cond : InitialCondition object
Initial condition used for calculation.
colors: dict
Dictionary of colors used for plots.
coarse_projection: ndarray, optional
Matrix of projection on coarse grid for each polynomial degree.
Default: None.
multiwavelet_coeffs: ndarray, optional
Matrix of wavelet coefficients. Default: None.
"""
# Set colors
if colors is None:
colors = {}
colors = _check_colors(colors)
# Calculate needed variables
interval_len = right_bound-left_bound
cell_len = interval_len/num_grid_cells
# Plot troubled cells
plot_shock_tube(num_grid_cells, troubled_cell_history, time_history)
# Determine exact and approximate solution
grid, exact = calculate_exact_solution(
mesh[2:-2], cell_len, wave_speed,
final_time, interval_len, quadrature,
init_cond)
approx = calculate_approximate_solution(
projection[:, 1:-1], quadrature.get_eval_points(),
polynomial_degree, basis.get_basis_vector())
# Plot multiwavelet solution (fine and coarse grid)
if coarse_projection is not None:
coarse_cell_len = 2*cell_len
coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len),
right_bound + (1.5*coarse_cell_len),
coarse_cell_len)
# Plot exact and approximate solutions for coarse mesh
coarse_grid, coarse_exact = calculate_exact_solution(
coarse_mesh[1:-1], coarse_cell_len, wave_speed,
final_time, interval_len, quadrature,
init_cond)
coarse_approx = calculate_approximate_solution(
coarse_projection, quadrature.get_eval_points(),
polynomial_degree, basis.get_basis_vector())
plot_solution_and_approx(
coarse_grid, coarse_exact, coarse_approx, colors['coarse_exact'],
colors['coarse_approx'])
# Plot multiwavelet details
num_coarse_grid_cells = num_grid_cells//2
plot_details(projection[:, 1:-1], mesh[2:-2], coarse_projection,
basis.get_basis_vector(),
basis.get_wavelet_vector(), multiwavelet_coeffs,
num_coarse_grid_cells,
polynomial_degree)
plot_solution_and_approx(grid, exact, approx,
colors['fine_exact'],
colors['fine_approx'])
plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)',
'Approx (Fine)'])
# Plot regular solution (fine grid)
else:
plot_solution_and_approx(grid, exact, approx, colors['exact'],
colors['approx'])
plt.legend(['Exact', 'Approx'])
# Calculate errors
pointwise_error = np.abs(exact-approx)
max_error = np.max(pointwise_error)
# Plot errors
plot_semilog_error(grid, pointwise_error)
plot_error(grid, exact, approx)
print('p =', polynomial_degree)
print('N =', num_grid_cells)
print('maximum error =', max_error)
def _check_colors(colors):
"""Checks plot colors.
Checks whether colors for plots were given and sets them if required.
"""
# Set colors for general plots
colors['exact'] = colors.get('exact', 'k-')
colors['approx'] = colors.get('approx', 'y')
# Set colors for multiwavelet plots
colors['fine_exact'] = colors.get('fine_exact', 'k-.')
colors['fine_approx'] = colors.get('fine_approx', 'b-.')
colors['coarse_exact'] = colors.get('coarse_exact', 'k-')
colors['coarse_approx'] = colors.get('coarse_approx', 'y')
return colors
......@@ -9,18 +9,10 @@ TODO: Give detailed description of wavelet detection
"""
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import ANN_Model
from Plotting import plot_solution_and_approx, plot_semilog_error, \
plot_error, plot_shock_tube, plot_details
from projection_utils import calculate_cell_average,\
calculate_approximate_solution, calculate_exact_solution
matplotlib.use('Agg')
from projection_utils import calculate_cell_average
class TroubledCellDetector:
......@@ -89,21 +81,8 @@ class TroubledCellDetector:
self._init_cond = init_cond
self._quadrature = quadrature
# Set parameters from config if existing
self._colors = config.pop('colors', {})
self._check_colors()
self._reset(config)
def _check_colors(self):
"""Checks plot colors.
Checks whether colors for plots were given and sets them if required.
"""
self._colors['exact'] = self._colors.get('exact', 'k-')
self._colors['approx'] = self._colors.get('approx', 'y')
def _reset(self, config):
"""Resets instance variables.
......@@ -113,7 +92,7 @@ class TroubledCellDetector:
Additional parameters for detector.
"""
sns.set()
pass
def get_name(self):
"""Returns string of class name."""
......@@ -130,62 +109,13 @@ class TroubledCellDetector:
"""
pass
def plot_results(self, projection, troubled_cell_history, time_history):
"""Plots results and troubled cells of a projection.
Plots results and troubled cells of a projection given its evaluation
history.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
troubled_cell_history : list
List of detected troubled cells for each time step.
time_history : list
List of value of each time step.
"""
plot_shock_tube(self._num_grid_cells, troubled_cell_history,
time_history)
max_error = self._plot_mesh(projection)
print('p =', self._polynomial_degree)
print('N =', self._num_grid_cells)
print('maximum error =', max_error)
def _plot_mesh(self, projection):
"""Plots exact and approximate solution as well as errors.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
Returns
-------
max_error : float
Maximum error between exact and approximate solution.
"""
grid, exact = calculate_exact_solution(
self._mesh[2:-2], self._cell_len, self._wave_speed,
self._final_time, self._interval_len, self._quadrature,
self._init_cond)
approx = calculate_approximate_solution(
projection[:, 1:-1], self._quadrature.get_eval_points(),
self._polynomial_degree, self._basis.get_basis_vector())
pointwise_error = np.abs(exact-approx)
max_error = np.max(pointwise_error)
plot_solution_and_approx(grid, exact, approx, self._colors['exact'],
self._colors['approx'])
plt.legend(['Exact', 'Approx'])
plot_semilog_error(grid, pointwise_error)
plot_error(grid, exact, approx)
return max_error
def create_data_dict(self, projection):
return {'projection': projection, 'wave_speed': self._wave_speed,
'num_grid_cells': self._num_grid_cells, 'mesh': self._mesh,
'final_time': self._final_time, 'left_bound':
self._left_bound, 'right_bound': self._right_bound,
'polynomial_degree': self._polynomial_degree
}
class NoDetection(TroubledCellDetector):
......@@ -306,16 +236,6 @@ class WaveletDetector(TroubledCellDetector):
???
"""
def _check_colors(self):
"""Checks plot colors.
Checks whether colors for plots were given and sets them if required.
"""
self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.')
self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.')
self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-')
self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y')
def _reset(self, config):
"""Resets instance variables.
......@@ -391,31 +311,6 @@ class WaveletDetector(TroubledCellDetector):
"""
return []
def plot_results(self, projection, troubled_cell_history, time_history):
"""Plots results and troubled cells of a projection.
Plots results on coarse and fine grid, errors, troubled cells,
and coefficient details given the projections evaluation history.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
troubled_cell_history : list
List of detected troubled cells for each time step.
time_history : list
List of value of each time step.
"""
multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
coarse_projection = self._calculate_coarse_projection(projection)
plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection,
self._basis.get_basis_vector(),
self._basis.get_wavelet_vector(), multiwavelet_coeffs,
self._num_coarse_grid_cells,
self._polynomial_degree)
super().plot_results(projection, troubled_cell_history, time_history)
def _calculate_coarse_projection(self, projection):
"""Calculates coarse projection.
......@@ -447,71 +342,17 @@ class WaveletDetector(TroubledCellDetector):
return coarse_projection
def _plot_mesh(self, projection):
"""Plots exact and approximate solution as well as errors.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
Returns
-------
max_error : float
Maximum error between exact and approximate solution.
"""
grid, exact = calculate_exact_solution(
self._mesh[2:-2], self._cell_len, self._wave_speed,
self._final_time, self._interval_len, self._quadrature,
self._init_cond)
approx = calculate_approximate_solution(
projection[:, 1:-1], self._quadrature.get_eval_points(),
self._polynomial_degree, self._basis.get_basis_vector())
pointwise_error = np.abs(exact-approx)
max_error = np.max(pointwise_error)
self._plot_coarse_mesh(projection)
plot_solution_and_approx(grid, exact, approx,
self._colors['fine_exact'],
self._colors['fine_approx'])
plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)',
'Approx (Fine)'])
plot_semilog_error(grid, pointwise_error)
plot_error(grid, exact, approx)
return max_error
def _plot_coarse_mesh(self, projection):
"""Plots exact and approximate solution as well as errors for a coarse
projection.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
"""
coarse_cell_len = 2*self._cell_len
coarse_mesh = np.arange(self._left_bound - (0.5*coarse_cell_len),
self._right_bound + (1.5*coarse_cell_len),
coarse_cell_len)
def create_data_dict(self, projection):
# Create general directory
data_dict = super().create_data_dict(projection)
# Save multiwavelet-specific data in dictionary
multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
coarse_projection = self._calculate_coarse_projection(projection)
data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs
data_dict['coarse_projection'] = coarse_projection
# Plot exact and approximate solutions for coarse mesh
grid, exact = calculate_exact_solution(
coarse_mesh[1:-1], coarse_cell_len, self._wave_speed,
self._final_time, self._interval_len, self._quadrature,
self._init_cond)
approx = calculate_approximate_solution(
coarse_projection, self._quadrature.get_eval_points(),
self._polynomial_degree, self._basis.get_basis_vector())
plot_solution_and_approx(
grid, exact, approx, self._colors['coarse_exact'],
self._colors['coarse_approx'])
return data_dict
class Boxplot(WaveletDetector):
......
......@@ -94,30 +94,11 @@ rule plot_approximation_results:
'quadrature_config', {}))
basis = OrthonormalLegendre(detector_dict.pop(
'polynomial_degree', 2))
cell_len = (right_bound - left_bound)\
/ params.dg_params.pop('num_grid_cells', 64)
mesh = np.arange(left_bound - (3/2*cell_len),
right_bound + (5/2*cell_len), cell_len)
detector_dict.pop('cfl_number', None)
detector_dict.pop('verbose', None)
detector_dict.pop('history_threshold', None)
detector_dict.pop('detector', None)
detector_dict.pop('limiter', None)
detector_dict.pop('limiter_config', None)
detector_dict.pop('update_scheme', None)
detector_dict['config'] = detector_dict.pop(
'detector_config', {})
detector = getattr(Troubled_Cell_Detector,
params.dg_params['detector'])(left_bound=left_bound,
right_bound=right_bound, init_cond=init_cond, mesh=mesh,
quadrature=quadrature, basis=basis, **detector_dict)
plot_approximation_results(detector=detector,
directory=params.plot_dir, plot_name=wildcards.scheme,
data_file=params.plot_dir+'/'+wildcards.scheme)
plot_approximation_results(directory=params.plot_dir,
plot_name=wildcards.scheme,
data_file=params.plot_dir+'/'+wildcards.scheme, basis=basis,
quadrature=quadrature, init_cond=init_cond)
toc = time.perf_counter()
print(f'Time: {toc - tic:0.4f}s')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment