diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py
index 4f753dd53ae456b10929797c08a3a8e5d2af3e89..acf0d9535462909648045dc09893b0251b6bb79e 100644
--- a/ANN_Data_Generator.py
+++ b/ANN_Data_Generator.py
@@ -22,50 +22,27 @@ class TrainingDataGenerator:
 
     Generates random training data for given initial conditions.
 
-    Attributes
-    ----------
-    smooth_functions : list
-        List of smooth initial/continuous conditions.
-    troubled_functions : list
-        List of discontinuous initial conditions.
-    data_dir : str
-        Path to directory in which training data is saved.
-
     Methods
     -------
     build_training_data(num_samples)
         Builds random training data.
 
     """
-    def __init__(self, initial_conditions, left_bound=-1, right_bound=1,
-                 balance=0.5, stencil_length=3, directory='test_data',
-                 add_reconstructions=True):
+    def __init__(self, left_bound=-1, right_bound=1, stencil_length=3):
         """Initializes TrainingDataGenerator.
 
         Parameters
         ----------
-        initial_conditions : list
-            List of names of initial conditions for training.
         left_bound : float, optional
             Left boundary of interval. Default: -1.
         right_bound : float, optional
             Right boundary of interval. Default: 1.
-        balance: float, optional
-            Ratio between smooth and discontinuous training data. Default: 0.5.
         stencil_length : int, optional
             Size of training data array. Default: 3.
-        directory : str, optional
-            Path to directory in which training data is saved.
-            Default: 'test_data'.
-        add_reconstructions: bool, optional
-            Flag whether reconstructions of the middle cell are included.
-            Default: True.
 
         """
-        self._balance = balance
         self._left_bound = left_bound
         self._right_bound = right_bound
-        self._add_reconstructions = add_reconstructions
 
         # Set stencil length
         if stencil_length % 2 == 0:
@@ -73,29 +50,26 @@ class TrainingDataGenerator:
                              % stencil_length)
         self._stencil_length = stencil_length
 
-        # Separate smooth and discontinuous initial conditions
-        self._smooth_functions = []
-        self._troubled_functions = []
-        for function in initial_conditions:
-            if function['function'].is_smooth():
-                self._smooth_functions.append(function)
-            else:
-                self._troubled_functions.append(function)
-
-        # Set directory
-        self._data_dir = directory
-        if not os.path.exists(self._data_dir):
-            os.makedirs(self._data_dir)
-
-    def build_training_data(self, num_samples):
+    def build_training_data(self, initial_conditions, num_samples, balance=0.5,
+                            directory='test_data', add_reconstructions=True):
         """Builds random training data.
 
         Creates training data consisting of random ANN input and saves it.
 
         Parameters
         ----------
+        initial_conditions : list
+            List of names of initial conditions for training.
         num_samples : int
             Number of training data samples to generate.
+        balance : float, optional
+            Ratio between smooth and discontinuous training data. Default: 0.5.
+        directory : str, optional
+            Path to directory in which training data is saved.
+            Default: 'test_data'.
+        add_reconstructions : bool, optional
+            Flag whether reconstructions of the middle cell are included.
+            Default: True.
 
         Returns
         -------
@@ -106,15 +80,18 @@ class TrainingDataGenerator:
         """
         tic = time.perf_counter()
         print('Calculating training data...\n')
-        data_dict = self._calculate_data_set(num_samples)
+        data_dict = self._calculate_data_set(initial_conditions,
+                                             num_samples, balance,
+                                             add_reconstructions)
         print('Finished calculating training data!')
 
-        self._save_data(data_dict)
+        self._save_data(directory=directory, data=data_dict)
         toc = time.perf_counter()
         print(f'Total runtime: {toc - tic:0.4f}s')
         return data_dict
 
-    def _calculate_data_set(self, num_samples):
+    def _calculate_data_set(self, initial_conditions, num_samples, balance,
+                            add_reconstructions):
         """Calculates random training data of given stencil length.
 
         Creates training data with a given ratio between smooth and
@@ -122,8 +99,14 @@ class TrainingDataGenerator:
 
         Parameters
         ----------
+        initial_conditions : list
+            List of names of initial conditions for training.
         num_samples : int
             Number of training data samples to generate.
+        balance : float
+            Ratio between smooth and discontinuous training data.
+        add_reconstructions : bool
+            Flag whether reconstructions of the middle cell are included.
 
         Returns
         -------
@@ -132,13 +115,24 @@ class TrainingDataGenerator:
             output data.
 
         """
-        num_smooth_samples = round(num_samples * self._balance)
+        # print(type(initial_conditions))
+        # Separate smooth and discontinuous initial conditions
+        smooth_functions = []
+        troubled_functions = []
+        for function in initial_conditions:
+            if function['function'].is_smooth():
+                smooth_functions.append(function)
+            else:
+                troubled_functions.append(function)
+
+        num_smooth_samples = round(num_samples * balance)
         smooth_input, smooth_output = self._generate_cell_data(
-            num_smooth_samples, self._smooth_functions, True)
+            num_smooth_samples, smooth_functions, add_reconstructions, True)
 
         num_troubled_samples = num_samples - num_smooth_samples
         troubled_input, troubled_output = self._generate_cell_data(
-            num_troubled_samples, self._troubled_functions, False)
+            num_troubled_samples, troubled_functions, add_reconstructions,
+            False)
 
         # Merge Data
         input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
@@ -157,7 +151,8 @@ class TrainingDataGenerator:
         return {'input_data.raw': input_matrix, 'output_data': output_matrix,
                 'input_data.normalized': norm_input_matrix}
 
-    def _generate_cell_data(self, num_samples, initial_conditions, is_smooth):
+    def _generate_cell_data(self, num_samples, initial_conditions,
+                            add_reconstructions, is_smooth):
         """Generates random training input and output.
 
         Generates random training input and output for either smooth or
@@ -170,6 +165,8 @@ class TrainingDataGenerator:
             Number of training data samples to generate.
         initial_conditions : list
             List of names of initial conditions for training.
+        add_reconstructions : bool
+            Flag whether reconstructions of the middle cell are included.
         is_smooth : bool
             Flag whether initial conditions are smooth.
 
@@ -181,13 +178,14 @@ class TrainingDataGenerator:
             Array containing output data.
 
         """
+        # print(type(initial_conditions))
         troubled_indicator = 'without' if is_smooth else 'with'
         print('Calculating data ' + troubled_indicator + ' troubled cells...')
         print('Samples to complete:', num_samples)
         tic = time.perf_counter()
 
         num_datapoints = self._stencil_length
-        if self._add_reconstructions:
+        if add_reconstructions:
             num_datapoints += 2
         input_data = np.zeros((num_samples, num_datapoints))
         num_init_cond = len(initial_conditions)
@@ -212,19 +210,19 @@ class TrainingDataGenerator:
             # Calculate basis coefficients for stencil
             polynomial_degree = np.random.randint(1, high=5)
 
-            basis = basis_list[polynomial_degree]
             mesh = Mesh(num_grid_cells=self._stencil_length, num_ghost_cells=2,
                         left_bound=left_bound, right_bound=right_bound)
             projection = do_initial_projection(
                 initial_condition=initial_condition, mesh=mesh,
-                basis=basis,
+                basis=basis_list[polynomial_degree],
                 quadrature=quadrature_list[polynomial_degree],
                 adjustment=adjustment)
 
-            input_data[i] = basis.calculate_cell_average(
+            input_data[i] = basis_list[
+                polynomial_degree].calculate_cell_average(
                 projection=projection[:, 1:-1],
                 stencil_length=self._stencil_length,
-                add_reconstructions=self._add_reconstructions)
+                add_reconstructions=add_reconstructions)
 
             count += 1
             if count % 1000 == 0:
@@ -298,9 +296,14 @@ class TrainingDataGenerator:
             normalized_input_data.append(entry / max_function_value)
         return np.array(normalized_input_data)
 
-    def _save_data(self, data):
+    @staticmethod
+    def _save_data(directory, data):
         """Saves data."""
+        # Set directory
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
         print('Saving training data.')
         for key in data.keys():
-            name = self._data_dir + '/' + key + '.npy'
+            name = directory + '/' + key + '.npy'
             np.save(name, data[key])
diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk
index df966e1354d4b253d156514afa2c382c71bbd529..e4a33c101d522460a3f192fdf6a658ce59292db3 100644
--- a/workflows/ANN_data.smk
+++ b/workflows/ANN_data.smk
@@ -38,10 +38,9 @@ rule generate_data:
         with open(str(log), 'w') as logfile:
             sys.stdout = logfile
             generator = ANN_Data_Generator.TrainingDataGenerator(
-                initial_conditions=initial_conditions,
                 left_bound=params.left_bound, right_bound=params.right_bound,
-                balance=params.balance,
-                stencil_length=params.stencil_length, directory=DIR,
-                add_reconstructions=params.reconstruction_flag)
-            data = generator.build_training_data(
-                num_samples=params.sample_number)
\ No newline at end of file
+                stencil_length=params.stencil_length)
+            data = generator.build_training_data(balance=params.balance,
+                initial_conditions=initial_conditions, directory=DIR,
+                num_samples=params.sample_number,
+                add_reconstructions=params.reconstruction_flag)
\ No newline at end of file