diff --git a/scripts/approximate_solution.py b/scripts/approximate_solution.py index 5b4086d30c92b1a6e29e1b91bde64708085a1143..32fc785022d270a3b2b61ea9d633dc5c53bc7037 100644 --- a/scripts/approximate_solution.py +++ b/scripts/approximate_solution.py @@ -6,7 +6,14 @@ import sys import time +from tcd import Initial_Condition +from tcd import Limiter +from tcd import Quadrature +from tcd import Troubled_Cell_Detector +from tcd import Update_Scheme +from tcd.Basis_Function import OrthonormalLegendre from tcd.DG_Approximation import DGScheme +from tcd.Mesh import Mesh def main() -> None: @@ -17,12 +24,69 @@ def main() -> None: tic = time.perf_counter() + params = snakemake.params['dg_params'] + if len(snakemake.input) > 0: - snakemake.params['dg_params']['detector_config']['model_state'] = \ - snakemake.input[0] + params['detector_config']['model_state'] = snakemake.input[0] + + print(params) + + # Initialize mesh with two ghost cells on each side + mesh = Mesh(num_cells=params.pop('num_mesh_cells', 64), + left_bound=params.pop('left_bound', -1), + right_bound=params.pop('right_bound', 1), + num_ghost_cells=2) + + # Initialize basis + basis = OrthonormalLegendre(params.pop('polynomial_degree', 2)) + + # Initialize limiter after checking its legitimacy + limiter_name = params.pop('limiter', 'ModifiedMinMod') + limiter_config = params.pop('limiter_config', {}) + limiter_config['cell_len'] = mesh.cell_len + if not hasattr(Limiter, limiter_name): + raise ValueError('Invalid limiter: "%s"' % limiter_name) + limiter = getattr(Limiter, limiter_name)(config=limiter_config) + + # Initialize troubled cell detector after checking its legitimacy + detector_name = params.pop('detector') + detector_config = params.pop('detector_config', {}) + if not hasattr(Troubled_Cell_Detector, detector_name): + raise ValueError('Invalid detector: "%s"' % detector_name) + detector = getattr(Troubled_Cell_Detector, detector_name)( + config=detector_config, + mesh=mesh, basis=basis) + + # Initialize update scheme after checking its legitimacy + scheme_name = params.pop('update_scheme', 'SSPRK3') + wave_speed = params.pop('wave_speed', 1) + if not hasattr(Update_Scheme, scheme_name): + raise ValueError('Invalid update scheme: "%s"' % scheme_name) + update_scheme = getattr(Update_Scheme, scheme_name)( + polynomial_degree=basis.polynomial_degree, + num_cells=mesh.num_cells, detector=detector, + limiter=limiter, wave_speed=wave_speed) + + # Initialize quadrature after checking its legitimacy + quadrature_name = params.pop('quadrature', 'Gauss') + quadrature_config = params.pop('quadrature_config', {}) + if not hasattr(Quadrature, quadrature_name): + raise ValueError('Invalid quadrature: "%s"' % quadrature_name) + quadrature = getattr(Quadrature, quadrature_name)( + config=quadrature_config) + + # Initialize initial condition after checking its legitimacy + init_name = params.pop('init_cond', 'Sine') + init_config = params.pop('init_config', {}) + if not hasattr(Initial_Condition, init_name): + raise ValueError('Invalid initial condition: "%s"' + % init_name) + init_cond = getattr(Initial_Condition, init_name)(config=init_config) - print(snakemake.params['dg_params']) - dg_scheme = DGScheme(**snakemake.params['dg_params']) + dg_scheme = DGScheme(detector=detector, quadrature=quadrature, + init_cond=init_cond, update_scheme=update_scheme, + mesh=mesh, basis=basis, wave_speed=wave_speed, + **params) dg_scheme.approximate( data_file=snakemake.params['plot_dir'] + '/' + diff --git a/scripts/tcd/DG_Approximation.py b/scripts/tcd/DG_Approximation.py index 0b07bb2771fd22c24dd1ce3dc709d4f5f885143d..a374ab4bc961c25340f08eb030bade2daafd4b51 100644 --- a/scripts/tcd/DG_Approximation.py +++ b/scripts/tcd/DG_Approximation.py @@ -9,14 +9,13 @@ from sympy import Symbol import math import seaborn as sns -from . import Troubled_Cell_Detector -from . import Initial_Condition -from . import Limiter -from . import Quadrature -from . import Update_Scheme -from .Basis_Function import OrthonormalLegendre +from .Basis_Function import Basis from .encoding_utils import encode_ndarray +from .Initial_Condition import InitialCondition from .Mesh import Mesh +from .Quadrature import Quadrature +from .Troubled_Cell_Detector import TroubledCellDetector +from .Update_Scheme import UpdateScheme x = Symbol('x') sns.set() @@ -28,13 +27,6 @@ class DGScheme: Approximates linear advection equation using Discontinuous Galerkin Method with troubled-cell-based limiting. - Attributes - ---------- - basis : Basis object - Basis for calculation. - mesh : Mesh - Mesh for calculation. - Methods ------- approximate() @@ -45,115 +37,62 @@ class DGScheme: Builds training data set. """ - def __init__(self, detector, **kwargs): + def __init__(self, detector: TroubledCellDetector, + quadrature: Quadrature, init_cond: InitialCondition, + update_scheme: UpdateScheme, mesh: Mesh, basis: Basis, + wave_speed, **kwargs): """Initializes DGScheme. Parameters ---------- - detector : str - Name of troubled cell detector class. + detector : TroubledCellDetector object + Troubled cell detector. + quadrature : Quadrature object + Quadrature for evaluation. + init_cond : InitialCondition object + Initial condition for evaluation. + update_scheme : UpdateScheme object + Update scheme for evaluation. + mesh : Mesh object + Mesh for calculation. + basis : Basis object + Basis for calculation. + wave_speed : float, optional + Speed of wave in rightward direction. Other Parameters ---------------- - wave_speed : float, optional - Speed of wave in rightward direction. Default: 1. - polynomial_degree : int, optional - Polynomial degree. Default: 2. cfl_number : float, optional CFL number to ensure stability. Default: 0.2. - num_mesh_cells : int, optional - Number of cells in the mesh. Usually exponential of 2. Default: 64. final_time : float, optional Final time for which approximation is calculated. Default: 1. - left_bound : float, optional - Left boundary of interval. Default: -1. - right_bound : float, optional - Right boundary of interval. Default: 1. verbose : bool, optional Flag whether commentary in console is wanted. Default: False. history_threshold : float, optional Threshold when history will be recorded. Default: math.ceil(0.2/cfl_number). - detector_config : dict, optional - Additional parameters for detector object. Default: {}. - init_cond : str, optional - Name of initial condition for evaluation. Default: 'Sine' - init_config : dict, optional - Additional parameters for initial condition object. Default: {}. - limiter : str, optional - Name of limiter for evaluation. Default: 'ModifiedMinMod'. - limiter_config : dict, optional - Additional parameters for limiter. object. Default: {}: - quadrature : str, optional - Name of quadrature for evaluation. Default: 'Gauss'. - quadrature_config : dict, optional - Additional parameters for quadrature object. Default: {}. - update_scheme : str, optional - Name of update scheme for evaluation. Default: 'SSPRK3'. """ + self._detector = detector + self._quadrature = quadrature + self._init_cond = init_cond + self._update_scheme = update_scheme + self._mesh = mesh + self._basis = basis + self._wave_speed = wave_speed + # Unpack keyword arguments - self._wave_speed = kwargs.pop('wave_speed', 1) self._cfl_number = kwargs.pop('cfl_number', 0.2) self._final_time = kwargs.pop('final_time', 1) self._verbose = kwargs.pop('verbose', False) self._history_threshold = kwargs.pop('history_threshold', math.ceil(0.2/self._cfl_number)) - self._detector = detector - self._detector_config = kwargs.pop('detector_config', {}) - self._init_cond = kwargs.pop('init_cond', 'Sine') - self._init_config = kwargs.pop('init_config', {}) - self._limiter = kwargs.pop('limiter', 'ModifiedMinMod') - self._limiter_config = kwargs.pop('limiter_config', {}) - self._quadrature = kwargs.pop('quadrature', 'Gauss') - self._quadrature_config = kwargs.pop('quadrature_config', {}) - self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3') - self._basis = OrthonormalLegendre(kwargs.pop('polynomial_degree', 2)) - - # Initialize mesh with two ghost cells on each side - self._mesh = Mesh(num_cells=kwargs.pop('num_mesh_cells', 64), - left_bound=kwargs.pop('left_bound', -1), - right_bound=kwargs.pop('right_bound', 1), - num_ghost_cells=2) - # print(len(self._mesh.cells)) - # print(type(self._mesh.cells)) # Throw an error if there are extra keyword arguments if len(kwargs) > 0: extra = ', '.join('"%s"' % k for k in list(kwargs.keys())) raise ValueError('Unrecognized arguments: %s' % extra) - # Make sure all classes actually exist - if not hasattr(Troubled_Cell_Detector, self._detector): - raise ValueError('Invalid detector: "%s"' % self._detector) - if not hasattr(Initial_Condition, self._init_cond): - raise ValueError('Invalid initial condition: "%s"' - % self._init_cond) - if not hasattr(Limiter, self._limiter): - raise ValueError('Invalid limiter: "%s"' % self._limiter) - if not hasattr(Quadrature, self._quadrature): - raise ValueError('Invalid quadrature: "%s"' % self._quadrature) - if not hasattr(Update_Scheme, self._update_scheme): - raise ValueError('Invalid update scheme: "%s"' - % self._update_scheme) - - self._reset() - - # Replace the string names with the actual class instances - # (and add the instance variables for the quadrature) - self._init_cond = getattr(Initial_Condition, self._init_cond)( - config=self._init_config) - self._limiter = getattr(Limiter, self._limiter)( - config=self._limiter_config) - self._quadrature = getattr(Quadrature, self._quadrature)( - config=self._quadrature_config) - self._detector = getattr(Troubled_Cell_Detector, self._detector)( - config=self._detector_config, mesh=self._mesh, basis=self._basis) - self._update_scheme = getattr(Update_Scheme, self._update_scheme)( - polynomial_degree=self._basis.polynomial_degree, - num_cells=self._mesh.num_cells, detector=self._detector, - limiter=self._limiter, wave_speed=self._wave_speed) - def approximate(self, data_file): """Approximates projection. @@ -216,11 +155,6 @@ class DGScheme: as json_file: json_file.write(json.dumps(approx_stats)) - def _reset(self): - """Resets instance variables.""" - # Set additional necessary config parameters - self._limiter_config['cell_len'] = self._mesh.cell_len - def do_initial_projection(init_cond, mesh, basis, quadrature, x_shift=0): @@ -233,7 +167,7 @@ def do_initial_projection(init_cond, mesh, basis, quadrature, ---------- init_cond : InitialCondition object Initial condition used for calculation. - mesh : Mesh + mesh : Mesh object Mesh for calculation. basis: Basis object Basis used for calculation.