diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index c5f07896b924a2dc6fefb95a9189c8a2540e473a..8228b951cdc6cd9cf88ba8b0440f5ce655f759a1 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -221,6 +221,10 @@ class TrainingDataGenerator: # Induce adjustment to capture troubled cells adjustment = 0 if initial_condition.is_smooth() \ else mesh.non_ghost_cells[stencil_length//2] + if initial_condition.discontinuity_position == 'left': + adjustment -= mesh.cell_len/2 + elif initial_condition.discontinuity_position == 'right': + adjustment += mesh.cell_len/2 initial_condition.induce_adjustment(-mesh.cell_len/3) # Calculate basis coefficients for stencil diff --git a/Initial_Condition.py b/Initial_Condition.py index cb03f31efaa0f07bf6811d3f0a94552f306de4c8..4c66fb05f0854ad0b1b02271da27e9cb499ae72a 100644 --- a/Initial_Condition.py +++ b/Initial_Condition.py @@ -10,6 +10,13 @@ import numpy as np class InitialCondition(ABC): """Abstract class for initial condition function. + Attributes + ---------- + discontinuity_position : str + Position where discontinuity is located relative to its cell during + training data generation. Can be either at the left boundary ('left'), + cell center ('middle') or right boundary ('right'). + Methods ------- get_name() @@ -43,8 +50,23 @@ class InitialCondition(ABC): config : dict Additional parameters for initial condition. + Raises + ------ + ValueError + If the discontinuity position is not in ['left', 'middle', + 'right']. + """ - pass + self._discontinuity_position = config.pop('discontinuity_position', + 'middle') + if self._discontinuity_position not in ['left', 'middle', 'right']: + raise ValueError(f'The discontinuity position has to be ' + f'either "left", "middle", or "right".') + + @property + def discontinuity_position(self) -> str: + """Return discontinuity position.""" + return self._discontinuity_position def get_name(self): """Returns string of class name.""" @@ -676,7 +698,8 @@ class HeavisideOneSided(InitialCondition): left_factor = config.pop('left_factor', np.random.choice([-100, 100])) right_factor = config.pop('right_factor', np.random.choice([-100, 100])) - config = {'left_factor': left_factor, 'right_factor': right_factor} + config = {'left_factor': left_factor, 'right_factor': right_factor, + 'discontinuity_position': self._discontinuity_position} self._reset(config) def _get_point(self, x):