From 4dce056508647ec675db1995b39017260b8f3f6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Fri, 23 Sep 2022 13:51:31 +0200
Subject: [PATCH] Added option to set discontinuity to cell boundary for ANN
 training.

---
 ANN_Data_Generator.py |  4 ++++
 Initial_Condition.py  | 27 +++++++++++++++++++++++++--
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py
index c5f0789..8228b95 100644
--- a/ANN_Data_Generator.py
+++ b/ANN_Data_Generator.py
@@ -221,6 +221,10 @@ class TrainingDataGenerator:
             # Induce adjustment to capture troubled cells
             adjustment = 0 if initial_condition.is_smooth() \
                 else mesh.non_ghost_cells[stencil_length//2]
+            if initial_condition.discontinuity_position == 'left':
+                adjustment -= mesh.cell_len/2
+            elif initial_condition.discontinuity_position == 'right':
+                adjustment += mesh.cell_len/2
             initial_condition.induce_adjustment(-mesh.cell_len/3)
 
             # Calculate basis coefficients for stencil
diff --git a/Initial_Condition.py b/Initial_Condition.py
index cb03f31..4c66fb0 100644
--- a/Initial_Condition.py
+++ b/Initial_Condition.py
@@ -10,6 +10,13 @@ import numpy as np
 class InitialCondition(ABC):
     """Abstract class for initial condition function.
 
+    Attributes
+    ----------
+    discontinuity_position : str
+        Position where discontinuity is located relative to its cell during
+        training data generation. Can be either at the left boundary ('left'),
+        cell center ('middle') or right boundary ('right').
+
     Methods
     -------
     get_name()
@@ -43,8 +50,23 @@ class InitialCondition(ABC):
         config : dict
             Additional parameters for initial condition.
 
+        Raises
+        ------
+        ValueError
+            If the discontinuity position is not in ['left', 'middle',
+            'right'].
+
         """
-        pass
+        self._discontinuity_position = config.pop('discontinuity_position',
+                                                  'middle')
+        if self._discontinuity_position not in ['left', 'middle', 'right']:
+            raise ValueError(f'The discontinuity position has to be '
+                             f'either "left", "middle", or "right".')
+
+    @property
+    def discontinuity_position(self) -> str:
+        """Return discontinuity position."""
+        return self._discontinuity_position
 
     def get_name(self):
         """Returns string of class name."""
@@ -676,7 +698,8 @@ class HeavisideOneSided(InitialCondition):
         left_factor = config.pop('left_factor', np.random.choice([-100, 100]))
         right_factor = config.pop('right_factor',
                                   np.random.choice([-100, 100]))
-        config = {'left_factor': left_factor, 'right_factor': right_factor}
+        config = {'left_factor': left_factor, 'right_factor': right_factor,
+                  'discontinuity_position': self._discontinuity_position}
         self._reset(config)
 
     def _get_point(self, x):
-- 
GitLab