From ff8fba4b5d42e29e02feb2d212b1b388dfb96b26 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Fri, 23 Sep 2022 13:04:26 +0200
Subject: [PATCH] Introduced 'training_data_mode' to allow forbidden mesh size
 during ANN training.

---
 projection_utils.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/projection_utils.py b/projection_utils.py
index 27e87af..91ba273 100644
--- a/projection_utils.py
+++ b/projection_utils.py
@@ -41,7 +41,8 @@ class Mesh:
     """
 
     def __init__(self, num_grid_cells: int, num_ghost_cells: int,
-                 left_bound: float, right_bound: float) -> None:
+                 left_bound: float, right_bound: float,
+                 training_data_mode: bool = False) -> None:
         """Initialize Mesh.
 
         Parameters
@@ -55,11 +56,16 @@ class Mesh:
             Left boundary of the mesh interval.
         right_bound : float
             Right boundary of the mesh interval.
+        training_data_mode : bool, optional
+            Flag indicating whether the mesh is used for training data
+            generation. Default: False.
+
         """
         self._num_grid_cells = num_grid_cells
-        if not math.log(self._num_grid_cells, 2).is_integer():
-            raise ValueError('The number of cells in the mesh has to be an '
-                             'exponential of 2')
+        if not training_data_mode:
+            if not math.log(self._num_grid_cells, 2).is_integer():
+                raise ValueError('The number of cells in the mesh has to be '
+                                 'an exponential of 2')
         self._num_ghost_cells = num_ghost_cells
         self._left_bound = left_bound
         self._right_bound = right_bound
@@ -130,7 +136,8 @@ class Mesh:
         # Return new mesh instance
         return Mesh(left_bound=point - stencil_length/2 * grid_spacing,
                     right_bound=point + stencil_length/2 * grid_spacing,
-                    num_grid_cells=stencil_length, num_ghost_cells=2)
+                    num_grid_cells=stencil_length, num_ghost_cells=2,
+                    training_data_mode=True)
 
 
 def calculate_approximate_solution(
-- 
GitLab