diff --git a/Snakefile b/Snakefile index bf8805b9d581022280c9d07d3fab2b13d22ce382..c1fbb0e047d1f7f6b4ab376ed8df9cd1b9e217fc 100644 --- a/Snakefile +++ b/Snakefile @@ -29,7 +29,7 @@ TODO: Enforce Boxplot bounds with decorator -> Done TODO: Enforce Boxplot folds with decorator -> Done TODO: Enforce boundary for initial condition in exact solution only -> Done TODO: Adapt number of ghost cells based on ANN stencil -> Done -TODO: Ensure exact solution is calculated in Equation class +TODO: Ensure exact solution is calculated in Equation class -> Done TODO: Extract objects from UpdateScheme TODO: Enforce num_ghost_cells to be positive integer for DG (not training) TODO: Add Burger class @@ -56,6 +56,8 @@ TODO: Add verbose output TODO: Add tests for all functions Not feasible yet or doc-related: +TODO: Clean up result plotting (remove object properties of Equation) +TODO: Add functions to create each object from dict TODO: Move plot_approximation_results() into plotting script TODO: Move plot_results() into plotting script TODO: Move plot_evaluation_results() into plotting script diff --git a/scripts/tcd/DG_Approximation.py b/scripts/tcd/DG_Approximation.py index 5f38c2649e03f8a3cb3f920c848ccdc6eadee4ba..d3e044b4d84e9e912ad2d77dde7106e167584915 100644 --- a/scripts/tcd/DG_Approximation.py +++ b/scripts/tcd/DG_Approximation.py @@ -99,14 +99,11 @@ class DGScheme: current_time += time_step - # Save detector-specific data in dictionary - approx_stats = self._detector.create_data_dict(projection) - - # Save approximation results in dictionary - approx_stats['wave_speed'] = self._equation.wave_speed - approx_stats['final_time'] = self._equation.final_time - approx_stats['time_history'] = time_history - approx_stats['troubled_cell_history'] = troubled_cell_history + # Save approximation-specific data in dictionary + approx_stats = {**self._detector.create_data_dict(projection), + **self._equation.create_data_dict(), + 'time_history': time_history, + 'troubled_cell_history': troubled_cell_history} # Encode all ndarrays to fit JSON format approx_stats = {key: encode_ndarray(approx_stats[key]) diff --git a/scripts/tcd/Equation.py b/scripts/tcd/Equation.py index be8de94e8fe269a92dd847b94d8b2b60da4e8ca7..995aba750f5ab282252d85a8bdb1abc7c4df19f8 100644 --- a/scripts/tcd/Equation.py +++ b/scripts/tcd/Equation.py @@ -64,6 +64,21 @@ class Equation(ABC): self._reset() + @property + def basis(self) -> Basis: + """Return basis.""" + return self._basis + + @property + def mesh(self) -> Mesh: + """Return basis.""" + return self._mesh + + @property + def quadrature(self) -> Quadrature: + """Return basis.""" + return self._quadrature + @property def final_time(self) -> float: """Return final time.""" @@ -97,6 +112,15 @@ class Equation(ABC): """Initialize projection.""" pass + def create_data_dict(self): + """Return dictionary with data necessary to construct equation.""" + return {'basis': self._basis.create_data_dict(), + 'mesh': self._mesh.create_data_dict(), + 'final_time': self._final_time, + 'wave_speed': self._wave_speed, + 'cfl_number': self._cfl_number + } + @abstractmethod def update_time_step(self, current_time: float, time_step: float) -> Tuple[float, float]: @@ -259,11 +283,11 @@ class LinearAdvection(Equation): points = np.array([point-self.wave_speed * self.final_time+num_periods * mesh.interval_len for point in grid]) - left_bound, right_bound = self._mesh.bounds + left_bound, right_bound = mesh.bounds while np.any(points < left_bound): - points[points < left_bound] += self._mesh.interval_len + points[points < left_bound] += mesh.interval_len while np.any(points) > right_bound: - points[points > right_bound] -= self._mesh.interval_len + points[points > right_bound] -= mesh.interval_len exact = np.array([self._init_cond.calculate(mesh=mesh, x=point) for point in points]) diff --git a/scripts/tcd/Plotting.py b/scripts/tcd/Plotting.py index 07b708ca5dab2504e48855386daefb0d28276561..84891296b2d7eabf4a35f4150647e7b86f8a77c0 100644 --- a/scripts/tcd/Plotting.py +++ b/scripts/tcd/Plotting.py @@ -3,7 +3,6 @@ @author: Laura C. Kühle """ - import os import time import json @@ -17,8 +16,8 @@ from sympy import Symbol from .Quadrature import Quadrature from .Initial_Condition import InitialCondition from .Basis_Function import Basis, OrthonormalLegendre -from .projection_utils import calculate_exact_solution,\ - calculate_approximate_solution +from .projection_utils import calculate_approximate_solution +from .Equation import Equation, LinearAdvection from .Mesh import Mesh from .encoding_utils import decode_ndarray @@ -341,14 +340,22 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str, # Decode all ndarrays by converting lists approx_stats = {key: decode_ndarray(approx_stats[key]) for key in approx_stats.keys()} - approx_stats['basis'] = OrthonormalLegendre(**approx_stats['basis']) - approx_stats['mesh'] = Mesh(**approx_stats['mesh']) - print([key for key in approx_stats.keys()]) + basis = OrthonormalLegendre(**approx_stats['basis']) + mesh = Mesh(**approx_stats['mesh']) + print(list(approx_stats.keys())) + + approx_stats['equation'] = LinearAdvection( + quadrature=quadrature, init_cond=init_cond, basis=basis, mesh=mesh, + final_time=approx_stats['final_time'], + wave_speed=approx_stats['wave_speed'], + cfl_number=approx_stats['cfl_number']) + + for key in ['basis', 'mesh', 'final_time', 'wave_speed', 'cfl_number']: + approx_stats.pop(key, None) # Plot exact/approximate results, errors, shock tubes, # and any detector-dependant plots - plot_results(quadrature=quadrature, init_cond=init_cond, - colors=colors, **approx_stats) + plot_results(colors=colors, **approx_stats) # Set paths for plot files if not existing already if not os.path.exists(directory): @@ -366,9 +373,7 @@ def plot_approximation_results(data_file: str, directory: str, plot_name: str, def plot_results(projection: ndarray, troubled_cell_history: list, - time_history: list, mesh: Mesh, wave_speed: float, - final_time: float, basis: Basis, - quadrature: Quadrature, init_cond: InitialCondition, + time_history: list, equation: Equation, colors: dict = None, coarse_projection: ndarray = None, multiwavelet_coeffs: ndarray = None) -> None: """Plots results and troubled cells of a projection. @@ -387,18 +392,8 @@ def plot_results(projection: ndarray, troubled_cell_history: list, List of detected troubled cells for each time step. time_history : list List of value of each time step. - mesh : Mesh - Mesh for calculation. - wave_speed : float - Speed of wave in rightward direction. - final_time : float - Final time for which approximation is calculated. - basis: Vector object - Basis used for calculation. - quadrature: Quadrature object - Quadrature used for evaluation. - init_cond : InitialCondition object - Initial condition used for calculation. + equation: Equation object + Equation used for calculation. colors: dict Dictionary of colors used for plots. coarse_projection: ndarray, optional @@ -413,12 +408,15 @@ def plot_results(projection: ndarray, troubled_cell_history: list, colors = {} colors = _check_colors(colors) + mesh = equation.mesh + basis = equation.basis + quadrature = equation.quadrature + # Plot troubled cells plot_shock_tube(mesh.num_cells, troubled_cell_history, time_history) # Determine exact and approximate solution - grid, exact = calculate_exact_solution( - mesh, wave_speed, final_time, quadrature, init_cond) + grid, exact = equation.solve_exactly(mesh) projection = projection[:, mesh.num_ghost_cells: -mesh.num_ghost_cells] approx = calculate_approximate_solution( @@ -433,8 +431,7 @@ def plot_results(projection: ndarray, troubled_cell_history: list, num_ghost_cells=0) # Plot exact and approximate solutions for coarse mesh - coarse_grid, coarse_exact = calculate_exact_solution( - coarse_mesh, wave_speed, final_time, quadrature, init_cond) + coarse_grid, coarse_exact = equation.solve_exactly(coarse_mesh) coarse_approx = calculate_approximate_solution( coarse_projection, quadrature.nodes, basis.polynomial_degree, basis.basis) diff --git a/scripts/tcd/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py index 9b65ce6f3fc8a31f8b238e31d6bebb85b08b7cff..c4fdd9a29715b818f954f11b1eace6498cadf0b7 100644 --- a/scripts/tcd/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -77,10 +77,7 @@ class TroubledCellDetector(ABC): def create_data_dict(self, projection): """Return dictionary with data necessary to plot troubled cells.""" - return {'projection': projection, - 'basis': self._basis.create_data_dict(), - 'mesh': self._mesh.create_data_dict() - } + return {'projection': projection} class NoDetection(TroubledCellDetector): diff --git a/scripts/tcd/projection_utils.py b/scripts/tcd/projection_utils.py index b7a2bed4462a36bb494110abdefcf10345cdfaf1..899ae8dbb34e202fa2411d95b8c91761d8d011c7 100644 --- a/scripts/tcd/projection_utils.py +++ b/scripts/tcd/projection_utils.py @@ -3,16 +3,10 @@ @author: Laura C. Kühle """ - -from typing import Tuple import numpy as np from numpy import ndarray from sympy import Symbol, lambdify -from .Mesh import Mesh -from .Quadrature import Quadrature -from .Initial_Condition import InitialCondition - x = Symbol('x') @@ -52,57 +46,6 @@ def calculate_approximate_solution( return np.reshape(approx, (1, approx.size)) -def calculate_exact_solution( - mesh: Mesh, wave_speed: float, final_time: - float, quadrature: Quadrature, - init_cond: InitialCondition) -> Tuple[ndarray, ndarray]: - """Calculate exact solution. - - Parameters - ---------- - mesh : Mesh - Mesh for evaluation. - wave_speed : float - Speed of wave in rightward direction. - final_time : float - Final time for which approximation is calculated. - quadrature : Quadrature object - Quadrature for evaluation. - init_cond : InitialCondition object - Initial condition for evaluation. - - Returns - ------- - grid : ndarray - Array containing evaluation grid for a function. - exact : ndarray - Array containing exact evaluation of a function. - - """ - num_periods = np.floor(wave_speed * final_time / mesh.interval_len) - - grid = np.repeat(mesh.non_ghost_cells, quadrature.num_nodes) + \ - mesh.cell_len/2 * np.tile(quadrature.nodes, mesh.num_cells) - - # Project points into correct periodic interval - points = np.array([point-wave_speed * - final_time+num_periods * mesh.interval_len - for point in grid]) - left_bound, right_bound = mesh.bounds - while np.any(points < left_bound): - points[points < left_bound] += mesh.interval_len - while np.any(points) > right_bound: - points[points > right_bound] -= mesh.interval_len - - exact = np.array([init_cond.calculate(mesh=mesh, x=point) for - point in points]) - - grid = np.reshape(grid, (1, grid.size)) - exact = np.reshape(exact, (1, exact.size)) - - return grid, exact - - def do_initial_projection(init_cond, mesh, basis, quadrature, x_shift=0): """Calculates initial projection.