From 4dce056508647ec675db1995b39017260b8f3f6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Fri, 23 Sep 2022 13:51:31 +0200 Subject: [PATCH] Added option to set discontinuity to cell boundary for ANN training. --- ANN_Data_Generator.py | 4 ++++ Initial_Condition.py | 27 +++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index c5f0789..8228b95 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -221,6 +221,10 @@ class TrainingDataGenerator: # Induce adjustment to capture troubled cells adjustment = 0 if initial_condition.is_smooth() \ else mesh.non_ghost_cells[stencil_length//2] + if initial_condition.discontinuity_position == 'left': + adjustment -= mesh.cell_len/2 + elif initial_condition.discontinuity_position == 'right': + adjustment += mesh.cell_len/2 initial_condition.induce_adjustment(-mesh.cell_len/3) # Calculate basis coefficients for stencil diff --git a/Initial_Condition.py b/Initial_Condition.py index cb03f31..4c66fb0 100644 --- a/Initial_Condition.py +++ b/Initial_Condition.py @@ -10,6 +10,13 @@ import numpy as np class InitialCondition(ABC): """Abstract class for initial condition function. + Attributes + ---------- + discontinuity_position : str + Position where discontinuity is located relative to its cell during + training data generation. Can be either at the left boundary ('left'), + cell center ('middle') or right boundary ('right'). + Methods ------- get_name() @@ -43,8 +50,23 @@ class InitialCondition(ABC): config : dict Additional parameters for initial condition. + Raises + ------ + ValueError + If the discontinuity position is not in ['left', 'middle', + 'right']. + """ - pass + self._discontinuity_position = config.pop('discontinuity_position', + 'middle') + if self._discontinuity_position not in ['left', 'middle', 'right']: + raise ValueError(f'The discontinuity position has to be ' + f'either "left", "middle", or "right".') + + @property + def discontinuity_position(self) -> str: + """Return discontinuity position.""" + return self._discontinuity_position def get_name(self): """Returns string of class name.""" @@ -676,7 +698,8 @@ class HeavisideOneSided(InitialCondition): left_factor = config.pop('left_factor', np.random.choice([-100, 100])) right_factor = config.pop('right_factor', np.random.choice([-100, 100])) - config = {'left_factor': left_factor, 'right_factor': right_factor} + config = {'left_factor': left_factor, 'right_factor': right_factor, + 'discontinuity_position': self._discontinuity_position} self._reset(config) def _get_point(self, x): -- GitLab