Skip to content
Snippets Groups Projects
Commit 59096e0f authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Extracted object initialization from DGScheme.

parent 9a5d3610
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,14 @@ ...@@ -6,7 +6,14 @@
import sys import sys
import time import time
from tcd import Initial_Condition
from tcd import Limiter
from tcd import Quadrature
from tcd import Troubled_Cell_Detector
from tcd import Update_Scheme
from tcd.Basis_Function import OrthonormalLegendre
from tcd.DG_Approximation import DGScheme from tcd.DG_Approximation import DGScheme
from tcd.Mesh import Mesh
def main() -> None: def main() -> None:
...@@ -17,12 +24,69 @@ def main() -> None: ...@@ -17,12 +24,69 @@ def main() -> None:
tic = time.perf_counter() tic = time.perf_counter()
params = snakemake.params['dg_params']
if len(snakemake.input) > 0: if len(snakemake.input) > 0:
snakemake.params['dg_params']['detector_config']['model_state'] = \ params['detector_config']['model_state'] = snakemake.input[0]
snakemake.input[0]
print(params)
# Initialize mesh with two ghost cells on each side
mesh = Mesh(num_cells=params.pop('num_mesh_cells', 64),
left_bound=params.pop('left_bound', -1),
right_bound=params.pop('right_bound', 1),
num_ghost_cells=2)
# Initialize basis
basis = OrthonormalLegendre(params.pop('polynomial_degree', 2))
# Initialize limiter after checking its legitimacy
limiter_name = params.pop('limiter', 'ModifiedMinMod')
limiter_config = params.pop('limiter_config', {})
limiter_config['cell_len'] = mesh.cell_len
if not hasattr(Limiter, limiter_name):
raise ValueError('Invalid limiter: "%s"' % limiter_name)
limiter = getattr(Limiter, limiter_name)(config=limiter_config)
# Initialize troubled cell detector after checking its legitimacy
detector_name = params.pop('detector')
detector_config = params.pop('detector_config', {})
if not hasattr(Troubled_Cell_Detector, detector_name):
raise ValueError('Invalid detector: "%s"' % detector_name)
detector = getattr(Troubled_Cell_Detector, detector_name)(
config=detector_config,
mesh=mesh, basis=basis)
# Initialize update scheme after checking its legitimacy
scheme_name = params.pop('update_scheme', 'SSPRK3')
wave_speed = params.pop('wave_speed', 1)
if not hasattr(Update_Scheme, scheme_name):
raise ValueError('Invalid update scheme: "%s"' % scheme_name)
update_scheme = getattr(Update_Scheme, scheme_name)(
polynomial_degree=basis.polynomial_degree,
num_cells=mesh.num_cells, detector=detector,
limiter=limiter, wave_speed=wave_speed)
# Initialize quadrature after checking its legitimacy
quadrature_name = params.pop('quadrature', 'Gauss')
quadrature_config = params.pop('quadrature_config', {})
if not hasattr(Quadrature, quadrature_name):
raise ValueError('Invalid quadrature: "%s"' % quadrature_name)
quadrature = getattr(Quadrature, quadrature_name)(
config=quadrature_config)
# Initialize initial condition after checking its legitimacy
init_name = params.pop('init_cond', 'Sine')
init_config = params.pop('init_config', {})
if not hasattr(Initial_Condition, init_name):
raise ValueError('Invalid initial condition: "%s"'
% init_name)
init_cond = getattr(Initial_Condition, init_name)(config=init_config)
print(snakemake.params['dg_params']) dg_scheme = DGScheme(detector=detector, quadrature=quadrature,
dg_scheme = DGScheme(**snakemake.params['dg_params']) init_cond=init_cond, update_scheme=update_scheme,
mesh=mesh, basis=basis, wave_speed=wave_speed,
**params)
dg_scheme.approximate( dg_scheme.approximate(
data_file=snakemake.params['plot_dir'] + '/' + data_file=snakemake.params['plot_dir'] + '/' +
......
...@@ -9,14 +9,13 @@ from sympy import Symbol ...@@ -9,14 +9,13 @@ from sympy import Symbol
import math import math
import seaborn as sns import seaborn as sns
from . import Troubled_Cell_Detector from .Basis_Function import Basis
from . import Initial_Condition
from . import Limiter
from . import Quadrature
from . import Update_Scheme
from .Basis_Function import OrthonormalLegendre
from .encoding_utils import encode_ndarray from .encoding_utils import encode_ndarray
from .Initial_Condition import InitialCondition
from .Mesh import Mesh from .Mesh import Mesh
from .Quadrature import Quadrature
from .Troubled_Cell_Detector import TroubledCellDetector
from .Update_Scheme import UpdateScheme
x = Symbol('x') x = Symbol('x')
sns.set() sns.set()
...@@ -28,13 +27,6 @@ class DGScheme: ...@@ -28,13 +27,6 @@ class DGScheme:
Approximates linear advection equation using Discontinuous Galerkin Method Approximates linear advection equation using Discontinuous Galerkin Method
with troubled-cell-based limiting. with troubled-cell-based limiting.
Attributes
----------
basis : Basis object
Basis for calculation.
mesh : Mesh
Mesh for calculation.
Methods Methods
------- -------
approximate() approximate()
...@@ -45,115 +37,62 @@ class DGScheme: ...@@ -45,115 +37,62 @@ class DGScheme:
Builds training data set. Builds training data set.
""" """
def __init__(self, detector, **kwargs): def __init__(self, detector: TroubledCellDetector,
quadrature: Quadrature, init_cond: InitialCondition,
update_scheme: UpdateScheme, mesh: Mesh, basis: Basis,
wave_speed, **kwargs):
"""Initializes DGScheme. """Initializes DGScheme.
Parameters Parameters
---------- ----------
detector : str detector : TroubledCellDetector object
Name of troubled cell detector class. Troubled cell detector.
quadrature : Quadrature object
Quadrature for evaluation.
init_cond : InitialCondition object
Initial condition for evaluation.
update_scheme : UpdateScheme object
Update scheme for evaluation.
mesh : Mesh object
Mesh for calculation.
basis : Basis object
Basis for calculation.
wave_speed : float, optional
Speed of wave in rightward direction.
Other Parameters Other Parameters
---------------- ----------------
wave_speed : float, optional
Speed of wave in rightward direction. Default: 1.
polynomial_degree : int, optional
Polynomial degree. Default: 2.
cfl_number : float, optional cfl_number : float, optional
CFL number to ensure stability. Default: 0.2. CFL number to ensure stability. Default: 0.2.
num_mesh_cells : int, optional
Number of cells in the mesh. Usually exponential of 2. Default: 64.
final_time : float, optional final_time : float, optional
Final time for which approximation is calculated. Default: 1. Final time for which approximation is calculated. Default: 1.
left_bound : float, optional
Left boundary of interval. Default: -1.
right_bound : float, optional
Right boundary of interval. Default: 1.
verbose : bool, optional verbose : bool, optional
Flag whether commentary in console is wanted. Default: False. Flag whether commentary in console is wanted. Default: False.
history_threshold : float, optional history_threshold : float, optional
Threshold when history will be recorded. Threshold when history will be recorded.
Default: math.ceil(0.2/cfl_number). Default: math.ceil(0.2/cfl_number).
detector_config : dict, optional
Additional parameters for detector object. Default: {}.
init_cond : str, optional
Name of initial condition for evaluation. Default: 'Sine'
init_config : dict, optional
Additional parameters for initial condition object. Default: {}.
limiter : str, optional
Name of limiter for evaluation. Default: 'ModifiedMinMod'.
limiter_config : dict, optional
Additional parameters for limiter. object. Default: {}:
quadrature : str, optional
Name of quadrature for evaluation. Default: 'Gauss'.
quadrature_config : dict, optional
Additional parameters for quadrature object. Default: {}.
update_scheme : str, optional
Name of update scheme for evaluation. Default: 'SSPRK3'.
""" """
self._detector = detector
self._quadrature = quadrature
self._init_cond = init_cond
self._update_scheme = update_scheme
self._mesh = mesh
self._basis = basis
self._wave_speed = wave_speed
# Unpack keyword arguments # Unpack keyword arguments
self._wave_speed = kwargs.pop('wave_speed', 1)
self._cfl_number = kwargs.pop('cfl_number', 0.2) self._cfl_number = kwargs.pop('cfl_number', 0.2)
self._final_time = kwargs.pop('final_time', 1) self._final_time = kwargs.pop('final_time', 1)
self._verbose = kwargs.pop('verbose', False) self._verbose = kwargs.pop('verbose', False)
self._history_threshold = kwargs.pop('history_threshold', self._history_threshold = kwargs.pop('history_threshold',
math.ceil(0.2/self._cfl_number)) math.ceil(0.2/self._cfl_number))
self._detector = detector
self._detector_config = kwargs.pop('detector_config', {})
self._init_cond = kwargs.pop('init_cond', 'Sine')
self._init_config = kwargs.pop('init_config', {})
self._limiter = kwargs.pop('limiter', 'ModifiedMinMod')
self._limiter_config = kwargs.pop('limiter_config', {})
self._quadrature = kwargs.pop('quadrature', 'Gauss')
self._quadrature_config = kwargs.pop('quadrature_config', {})
self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3')
self._basis = OrthonormalLegendre(kwargs.pop('polynomial_degree', 2))
# Initialize mesh with two ghost cells on each side
self._mesh = Mesh(num_cells=kwargs.pop('num_mesh_cells', 64),
left_bound=kwargs.pop('left_bound', -1),
right_bound=kwargs.pop('right_bound', 1),
num_ghost_cells=2)
# print(len(self._mesh.cells))
# print(type(self._mesh.cells))
# Throw an error if there are extra keyword arguments # Throw an error if there are extra keyword arguments
if len(kwargs) > 0: if len(kwargs) > 0:
extra = ', '.join('"%s"' % k for k in list(kwargs.keys())) extra = ', '.join('"%s"' % k for k in list(kwargs.keys()))
raise ValueError('Unrecognized arguments: %s' % extra) raise ValueError('Unrecognized arguments: %s' % extra)
# Make sure all classes actually exist
if not hasattr(Troubled_Cell_Detector, self._detector):
raise ValueError('Invalid detector: "%s"' % self._detector)
if not hasattr(Initial_Condition, self._init_cond):
raise ValueError('Invalid initial condition: "%s"'
% self._init_cond)
if not hasattr(Limiter, self._limiter):
raise ValueError('Invalid limiter: "%s"' % self._limiter)
if not hasattr(Quadrature, self._quadrature):
raise ValueError('Invalid quadrature: "%s"' % self._quadrature)
if not hasattr(Update_Scheme, self._update_scheme):
raise ValueError('Invalid update scheme: "%s"'
% self._update_scheme)
self._reset()
# Replace the string names with the actual class instances
# (and add the instance variables for the quadrature)
self._init_cond = getattr(Initial_Condition, self._init_cond)(
config=self._init_config)
self._limiter = getattr(Limiter, self._limiter)(
config=self._limiter_config)
self._quadrature = getattr(Quadrature, self._quadrature)(
config=self._quadrature_config)
self._detector = getattr(Troubled_Cell_Detector, self._detector)(
config=self._detector_config, mesh=self._mesh, basis=self._basis)
self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
polynomial_degree=self._basis.polynomial_degree,
num_cells=self._mesh.num_cells, detector=self._detector,
limiter=self._limiter, wave_speed=self._wave_speed)
def approximate(self, data_file): def approximate(self, data_file):
"""Approximates projection. """Approximates projection.
...@@ -216,11 +155,6 @@ class DGScheme: ...@@ -216,11 +155,6 @@ class DGScheme:
as json_file: as json_file:
json_file.write(json.dumps(approx_stats)) json_file.write(json.dumps(approx_stats))
def _reset(self):
"""Resets instance variables."""
# Set additional necessary config parameters
self._limiter_config['cell_len'] = self._mesh.cell_len
def do_initial_projection(init_cond, mesh, basis, quadrature, def do_initial_projection(init_cond, mesh, basis, quadrature,
x_shift=0): x_shift=0):
...@@ -233,7 +167,7 @@ def do_initial_projection(init_cond, mesh, basis, quadrature, ...@@ -233,7 +167,7 @@ def do_initial_projection(init_cond, mesh, basis, quadrature,
---------- ----------
init_cond : InitialCondition object init_cond : InitialCondition object
Initial condition used for calculation. Initial condition used for calculation.
mesh : Mesh mesh : Mesh object
Mesh for calculation. Mesh for calculation.
basis: Basis object basis: Basis object
Basis used for calculation. Basis used for calculation.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment