diff --git a/DG_Approximation.py b/DG_Approximation.py index 21b67e12c6f922e7db94d59e7a4a933113f88d7f..49994b7149c80fc95996e90fa7f3c39c62615712 100644 --- a/DG_Approximation.py +++ b/DG_Approximation.py @@ -4,17 +4,29 @@ Discussion: TODO: Ask whether cell averages/reconstructions should be contained in basis -TODO: Contemplate whether basis variables should be public -TODO: Contemplate a Mesh class (mesh, cell_len, num_grid_cells, bounds, etc.) + -> Done (yes, hard-code simplification) +TODO: Contemplate whether basis variables should be public -> Done (yes) +TODO: Contemplate a Mesh class + (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) + -> Done (yes) +TODO: Contemplate to contain polynomial degree in basis -> Done (yes) Urgent: +TODO: Hard-code simplification of cell average/reconstruction in basis +TODO: Make basis variables public (if feasible) +TODO: Contain polynomial degree in basis +TODO: Introduce Mesh class + (mesh, cell_len, num_grid_cells, bounds, num_ghost_cells, etc.) +TODO: Check whether ghost cells are handled/set correctly +TODO: Find error in centering for ANN training +TODO: Investigate g-mesh(?) TODO: Extract do_initial_projection() from DGScheme -> Done TODO: Move inverse mass matrix to basis -> Done TODO: Extract calculate_cell_average() from TCD -> Done TODO: Improve calculate_cell_average() -> Done -TODO: Extract calculate_[...]_solution() from Plotting +TODO: Extract calculate_[...]_solution() from Plotting -> Done TODO: Extract plotting from TCD completely - (maybe give indicator which plots are required instead?) + (maybe give indicator which plots are required instead?) -> Done TODO: Contain all plotting in Plotting TODO: Remove use of DGScheme from ANN_Data_Generator TODO: Clean up docstrings @@ -59,6 +71,7 @@ import json import numpy as np from sympy import Symbol import math +import seaborn as sns import matplotlib from matplotlib import pyplot as plt @@ -68,10 +81,14 @@ import Limiter import Quadrature import Update_Scheme from Basis_Function import OrthonormalLegendre -from projection_utils import calculate_cell_average +from projection_utils import calculate_cell_average, \ + calculate_exact_solution, calculate_approximate_solution +from Plotting import plot_solution_and_approx, plot_semilog_error, \ + plot_error, plot_shock_tube, plot_details matplotlib.use('Agg') x = Symbol('x') +sns.set() def encode_ndarray(obj): @@ -270,9 +287,12 @@ class DGScheme: current_time += time_step + # Save detector-specific data in dictionary + approx_stats = self._detector.create_data_dict(projection) + # Save approximation results in dictionary - approx_stats = {'projection': projection, 'time_history': time_history, - 'troubled_cell_history': troubled_cell_history} + approx_stats['time_history'] = time_history + approx_stats['troubled_cell_history'] = troubled_cell_history # Encode all ndarrays to fit JSON format approx_stats = {key: encode_ndarray(approx_stats[key]) @@ -353,7 +373,7 @@ def do_initial_projection(initial_condition, basis, quadrature, basis: Vector object Basis used for calculation. quadrature: Quadrature object - Quadrature fused for evaluation. + Quadrature used for evaluation. num_grid_cells : int Number of cells in the mesh. Usually exponential of 2. left_bound : float @@ -402,7 +422,8 @@ def do_initial_projection(initial_condition, basis, quadrature, return np.transpose(np.array(output_matrix)) -def plot_approximation_results(detector, data_file, directory, plot_name): +def plot_approximation_results(data_file, directory, plot_name, quadrature, + init_cond, basis): """Plots given approximation results. Generates plots based on given data, sets plot directory if not @@ -416,6 +437,12 @@ def plot_approximation_results(detector, data_file, directory, plot_name): Path to directory in which plots will be saved. plot_name : str Name of plot. + basis: Vector object + Basis used for calculation. + quadrature: Quadrature object + Quadrature used for evaluation. + init_cond : InitialCondition object + Initial condition used for calculation. """ # Read approximation results @@ -428,7 +455,8 @@ def plot_approximation_results(detector, data_file, directory, plot_name): # Plot exact/approximate results, errors, shock tubes, # and any detector-dependant plots - detector.plot_results(**approx_stats) + plot_results(quadrature=quadrature, basis=basis, + init_cond=init_cond, **approx_stats) # Set paths for plot files if not existing already if not os.path.exists(directory): @@ -443,3 +471,144 @@ def plot_approximation_results(detector, data_file, directory, plot_name): plt.figure(identifier) plt.savefig(directory + '/' + identifier + '/' + plot_name + '.pdf') + + +def plot_results(projection, troubled_cell_history, time_history, mesh, + num_grid_cells, polynomial_degree, wave_speed, final_time, + left_bound, right_bound, basis, quadrature, init_cond, + colors=None, coarse_projection=None, + multiwavelet_coeffs=None): + """Plots results and troubled cells of a projection. + + Plots exact and approximate solution, errors, and troubled cells of a + projection given its evaluation history. + + If coarse grid and projection are given, solutions are displayed for + both coarse and fine grid. Additionally, coefficient details are plotted. + + Parameters + ---------- + projection : ndarray + Matrix of projection for each polynomial degree. + troubled_cell_history : list + List of detected troubled cells for each time step. + time_history : list + List of value of each time step. + mesh : ndarray + List of mesh valuation points. + num_grid_cells : int + Number of cells in the mesh. Usually exponential of 2. + polynomial_degree : int + Polynomial degree. + wave_speed : float + Speed of wave in rightward direction. + final_time : float + Final time for which approximation is calculated. + left_bound : float + Left boundary of interval. + right_bound : float + Right boundary of interval. + basis: Vector object + Basis used for calculation. + quadrature: Quadrature object + Quadrature used for evaluation. + init_cond : InitialCondition object + Initial condition used for calculation. + colors: dict + Dictionary of colors used for plots. + coarse_projection: ndarray, optional + Matrix of projection on coarse grid for each polynomial degree. + Default: None. + multiwavelet_coeffs: ndarray, optional + Matrix of wavelet coefficients. Default: None. + + """ + # Set colors + if colors is None: + colors = {} + colors = _check_colors(colors) + + # Calculate needed variables + interval_len = right_bound-left_bound + cell_len = interval_len/num_grid_cells + + # Plot troubled cells + plot_shock_tube(num_grid_cells, troubled_cell_history, time_history) + + # Determine exact and approximate solution + grid, exact = calculate_exact_solution( + mesh[2:-2], cell_len, wave_speed, + final_time, interval_len, quadrature, + init_cond) + approx = calculate_approximate_solution( + projection[:, 1:-1], quadrature.get_eval_points(), + polynomial_degree, basis.get_basis_vector()) + + # Plot multiwavelet solution (fine and coarse grid) + if coarse_projection is not None: + coarse_cell_len = 2*cell_len + coarse_mesh = np.arange(left_bound - (0.5*coarse_cell_len), + right_bound + (1.5*coarse_cell_len), + coarse_cell_len) + + # Plot exact and approximate solutions for coarse mesh + coarse_grid, coarse_exact = calculate_exact_solution( + coarse_mesh[1:-1], coarse_cell_len, wave_speed, + final_time, interval_len, quadrature, + init_cond) + coarse_approx = calculate_approximate_solution( + coarse_projection, quadrature.get_eval_points(), + polynomial_degree, basis.get_basis_vector()) + plot_solution_and_approx( + coarse_grid, coarse_exact, coarse_approx, colors['coarse_exact'], + colors['coarse_approx']) + + # Plot multiwavelet details + num_coarse_grid_cells = num_grid_cells//2 + plot_details(projection[:, 1:-1], mesh[2:-2], coarse_projection, + basis.get_basis_vector(), + basis.get_wavelet_vector(), multiwavelet_coeffs, + num_coarse_grid_cells, + polynomial_degree) + + plot_solution_and_approx(grid, exact, approx, + colors['fine_exact'], + colors['fine_approx']) + plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', + 'Approx (Fine)']) + # Plot regular solution (fine grid) + else: + plot_solution_and_approx(grid, exact, approx, colors['exact'], + colors['approx']) + plt.legend(['Exact', 'Approx']) + + # Calculate errors + pointwise_error = np.abs(exact-approx) + max_error = np.max(pointwise_error) + + # Plot errors + plot_semilog_error(grid, pointwise_error) + plot_error(grid, exact, approx) + + print('p =', polynomial_degree) + print('N =', num_grid_cells) + print('maximum error =', max_error) + + +def _check_colors(colors): + """Checks plot colors. + + Checks whether colors for plots were given and sets them if required. + + """ + # Set colors for general plots + colors['exact'] = colors.get('exact', 'k-') + colors['approx'] = colors.get('approx', 'y') + + # Set colors for multiwavelet plots + colors['fine_exact'] = colors.get('fine_exact', 'k-.') + colors['fine_approx'] = colors.get('fine_approx', 'b-.') + colors['coarse_exact'] = colors.get('coarse_exact', 'k-') + colors['coarse_approx'] = colors.get('coarse_approx', 'y') + + return colors diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index f2d8691e935c2251b1c1a504fe8872612299adff..4d2aa60bb5edc5dac7f58e03a581f2a2f60409e3 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -9,18 +9,10 @@ TODO: Give detailed description of wavelet detection """ import numpy as np -import matplotlib -from matplotlib import pyplot as plt -import seaborn as sns import torch import ANN_Model -from Plotting import plot_solution_and_approx, plot_semilog_error, \ - plot_error, plot_shock_tube, plot_details -from projection_utils import calculate_cell_average,\ - calculate_approximate_solution, calculate_exact_solution - -matplotlib.use('Agg') +from projection_utils import calculate_cell_average class TroubledCellDetector: @@ -89,21 +81,8 @@ class TroubledCellDetector: self._init_cond = init_cond self._quadrature = quadrature - # Set parameters from config if existing - self._colors = config.pop('colors', {}) - - self._check_colors() self._reset(config) - def _check_colors(self): - """Checks plot colors. - - Checks whether colors for plots were given and sets them if required. - - """ - self._colors['exact'] = self._colors.get('exact', 'k-') - self._colors['approx'] = self._colors.get('approx', 'y') - def _reset(self, config): """Resets instance variables. @@ -113,7 +92,7 @@ class TroubledCellDetector: Additional parameters for detector. """ - sns.set() + pass def get_name(self): """Returns string of class name.""" @@ -130,62 +109,13 @@ class TroubledCellDetector: """ pass - def plot_results(self, projection, troubled_cell_history, time_history): - """Plots results and troubled cells of a projection. - - Plots results and troubled cells of a projection given its evaluation - history. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - troubled_cell_history : list - List of detected troubled cells for each time step. - time_history : list - List of value of each time step. - - """ - plot_shock_tube(self._num_grid_cells, troubled_cell_history, - time_history) - max_error = self._plot_mesh(projection) - - print('p =', self._polynomial_degree) - print('N =', self._num_grid_cells) - print('maximum error =', max_error) - - def _plot_mesh(self, projection): - """Plots exact and approximate solution as well as errors. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - - Returns - ------- - max_error : float - Maximum error between exact and approximate solution. - - """ - grid, exact = calculate_exact_solution( - self._mesh[2:-2], self._cell_len, self._wave_speed, - self._final_time, self._interval_len, self._quadrature, - self._init_cond) - approx = calculate_approximate_solution( - projection[:, 1:-1], self._quadrature.get_eval_points(), - self._polynomial_degree, self._basis.get_basis_vector()) - - pointwise_error = np.abs(exact-approx) - max_error = np.max(pointwise_error) - - plot_solution_and_approx(grid, exact, approx, self._colors['exact'], - self._colors['approx']) - plt.legend(['Exact', 'Approx']) - plot_semilog_error(grid, pointwise_error) - plot_error(grid, exact, approx) - - return max_error + def create_data_dict(self, projection): + return {'projection': projection, 'wave_speed': self._wave_speed, + 'num_grid_cells': self._num_grid_cells, 'mesh': self._mesh, + 'final_time': self._final_time, 'left_bound': + self._left_bound, 'right_bound': self._right_bound, + 'polynomial_degree': self._polynomial_degree + } class NoDetection(TroubledCellDetector): @@ -306,16 +236,6 @@ class WaveletDetector(TroubledCellDetector): ??? """ - def _check_colors(self): - """Checks plot colors. - - Checks whether colors for plots were given and sets them if required. - - """ - self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.') - self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.') - self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-') - self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y') def _reset(self, config): """Resets instance variables. @@ -391,31 +311,6 @@ class WaveletDetector(TroubledCellDetector): """ return [] - def plot_results(self, projection, troubled_cell_history, time_history): - """Plots results and troubled cells of a projection. - - Plots results on coarse and fine grid, errors, troubled cells, - and coefficient details given the projections evaluation history. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - troubled_cell_history : list - List of detected troubled cells for each time step. - time_history : list - List of value of each time step. - - """ - multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection) - coarse_projection = self._calculate_coarse_projection(projection) - plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection, - self._basis.get_basis_vector(), - self._basis.get_wavelet_vector(), multiwavelet_coeffs, - self._num_coarse_grid_cells, - self._polynomial_degree) - super().plot_results(projection, troubled_cell_history, time_history) - def _calculate_coarse_projection(self, projection): """Calculates coarse projection. @@ -447,71 +342,17 @@ class WaveletDetector(TroubledCellDetector): return coarse_projection - def _plot_mesh(self, projection): - """Plots exact and approximate solution as well as errors. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - - Returns - ------- - max_error : float - Maximum error between exact and approximate solution. - - """ - - grid, exact = calculate_exact_solution( - self._mesh[2:-2], self._cell_len, self._wave_speed, - self._final_time, self._interval_len, self._quadrature, - self._init_cond) - approx = calculate_approximate_solution( - projection[:, 1:-1], self._quadrature.get_eval_points(), - self._polynomial_degree, self._basis.get_basis_vector()) - - pointwise_error = np.abs(exact-approx) - max_error = np.max(pointwise_error) - - self._plot_coarse_mesh(projection) - plot_solution_and_approx(grid, exact, approx, - self._colors['fine_exact'], - self._colors['fine_approx']) - plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', - 'Approx (Fine)']) - plot_semilog_error(grid, pointwise_error) - plot_error(grid, exact, approx) - - return max_error - - def _plot_coarse_mesh(self, projection): - """Plots exact and approximate solution as well as errors for a coarse - projection. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - - """ - coarse_cell_len = 2*self._cell_len - coarse_mesh = np.arange(self._left_bound - (0.5*coarse_cell_len), - self._right_bound + (1.5*coarse_cell_len), - coarse_cell_len) + def create_data_dict(self, projection): + # Create general directory + data_dict = super().create_data_dict(projection) + # Save multiwavelet-specific data in dictionary + multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection) coarse_projection = self._calculate_coarse_projection(projection) + data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs + data_dict['coarse_projection'] = coarse_projection - # Plot exact and approximate solutions for coarse mesh - grid, exact = calculate_exact_solution( - coarse_mesh[1:-1], coarse_cell_len, self._wave_speed, - self._final_time, self._interval_len, self._quadrature, - self._init_cond) - approx = calculate_approximate_solution( - coarse_projection, self._quadrature.get_eval_points(), - self._polynomial_degree, self._basis.get_basis_vector()) - plot_solution_and_approx( - grid, exact, approx, self._colors['coarse_exact'], - self._colors['coarse_approx']) + return data_dict class Boxplot(WaveletDetector): diff --git a/workflows/approximation.smk b/workflows/approximation.smk index 0a6c03586f76c8922324c1292fc8fbc4edfde1c3..f5cc877d4c56d7b9075d18de37f633610aaeff6e 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -94,30 +94,11 @@ rule plot_approximation_results: 'quadrature_config', {})) basis = OrthonormalLegendre(detector_dict.pop( 'polynomial_degree', 2)) - cell_len = (right_bound - left_bound)\ - / params.dg_params.pop('num_grid_cells', 64) - mesh = np.arange(left_bound - (3/2*cell_len), - right_bound + (5/2*cell_len), cell_len) - - detector_dict.pop('cfl_number', None) - detector_dict.pop('verbose', None) - detector_dict.pop('history_threshold', None) - detector_dict.pop('detector', None) - detector_dict.pop('limiter', None) - detector_dict.pop('limiter_config', None) - detector_dict.pop('update_scheme', None) - - detector_dict['config'] = detector_dict.pop( - 'detector_config', {}) - - detector = getattr(Troubled_Cell_Detector, - params.dg_params['detector'])(left_bound=left_bound, - right_bound=right_bound, init_cond=init_cond, mesh=mesh, - quadrature=quadrature, basis=basis, **detector_dict) - - plot_approximation_results(detector=detector, - directory=params.plot_dir, plot_name=wildcards.scheme, - data_file=params.plot_dir+'/'+wildcards.scheme) + + plot_approximation_results(directory=params.plot_dir, + plot_name=wildcards.scheme, + data_file=params.plot_dir+'/'+wildcards.scheme, basis=basis, + quadrature=quadrature, init_cond=init_cond) toc = time.perf_counter() print(f'Time: {toc - tic:0.4f}s') \ No newline at end of file