diff --git a/Basis_Function.py b/Basis_Function.py
index e434fce15dfc0675e281f3c8b47edd072515c5a0..43fba039966184b62cd46287afcc6543fe277aaf 100644
--- a/Basis_Function.py
+++ b/Basis_Function.py
@@ -9,6 +9,8 @@ TODO: Contemplate whether calculating projections during initialization can
 import numpy as np
 from sympy import Symbol, integrate
 
+from projection_utils import calculate_approximate_solution
+
 x = Symbol('x')
 z = Symbol('z')
 
@@ -37,6 +39,8 @@ class Basis:
         Returns basis projections.
     get_wavelet_projections()
         Returns wavelet projections.
+    calculate_cell_average(projection, stencil_length, add_reconstructions)
+        Calculate cell averages for a given projection.
 
     """
     def __init__(self, polynomial_degree):
@@ -116,6 +120,54 @@ class Basis:
         """Returns wavelet projection."""
         pass
 
+    def calculate_cell_average(self, projection, stencil_length,
+                               add_reconstructions=True):
+        """Calculate cell averages for a given projection.
+
+        Calculate the cell averages of all cells in a projection.
+        If desired, reconstructions are calculated for the middle cell
+        and added left and right to it, respectively.
+
+        Parameters
+        ----------
+        projection : ndarray
+            Matrix of projection for each polynomial degree.
+        stencil_length : int
+            Size of data array.
+        add_reconstructions: bool, optional
+            Flag whether reconstructions of the middle cell are included.
+            Default: True.
+
+        Returns
+        -------
+        ndarray
+            Matrix containing cell averages (and reconstructions) for given
+            projection.
+
+        """
+        cell_averages = calculate_approximate_solution(
+            projection, np.array([0]), 0, self._basis)
+
+        if add_reconstructions:
+            middle_idx = stencil_length // 2
+            left_reconstructions, right_reconstructions = \
+                self._calculate_reconstructions(
+                    projection[:, middle_idx:middle_idx+1])
+            return np.array(list(map(
+                np.float64, zip(cell_averages[:, :middle_idx],
+                                left_reconstructions,
+                                cell_averages[:, middle_idx],
+                                right_reconstructions,
+                                cell_averages[:, middle_idx+1:]))))
+        return np.array(list(map(np.float64, cell_averages)))
+
+    def _calculate_reconstructions(self, projection):
+        left_reconstructions = calculate_approximate_solution(
+            projection, np.array([-1]), self._polynomial_degree, self._basis)
+        right_reconstructions = calculate_approximate_solution(
+            projection, np.array([1]), self._polynomial_degree, self._basis)
+        return left_reconstructions, right_reconstructions
+
 
 class Legendre(Basis):
     """Class for Legendre basis."""
diff --git a/DG_Approximation.py b/DG_Approximation.py
index 27ad5ece578748fd8705b370ceef4581c48aefea..6268644f4ae5b65c66902ae5e5b4d499b4382d31 100644
--- a/DG_Approximation.py
+++ b/DG_Approximation.py
@@ -7,7 +7,7 @@ TODO: Contemplate saving 5-CV split and evaluating models separately
 TODO: Contemplate separating cell average and reconstruction calculations
 
 Urgent:
-TODO: Move calculate_cell_average() to Basis
+TODO: Move calculate_cell_average() to Basis -> Done
 TODO: Hard-code simplification of cell average/reconstruction in basis
 TODO: Make basis variables public (if feasible)
 TODO: Contain polynomial degree in basis
@@ -66,7 +66,6 @@ import Limiter
 import Quadrature
 import Update_Scheme
 from Basis_Function import OrthonormalLegendre
-from projection_utils import calculate_cell_average
 from encoding_utils import encode_ndarray
 
 x = Symbol('x')
@@ -322,10 +321,9 @@ class DGScheme:
             left_bound=self._left_bound, right_bound=self._right_bound,
             polynomial_degree=self._polynomial_degree, adjustment=adjustment)
 
-        return calculate_cell_average(
+        return self._basis.calculate_cell_average(
             projection=projection[:, 1:-1], stencil_length=stencil_length,
-            polynomial_degree=self._polynomial_degree if add_reconstructions
-            else -1, basis=self._basis)
+            add_reconstructions=add_reconstructions)
 
 
 def do_initial_projection(initial_condition, basis, quadrature,
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 4d2aa60bb5edc5dac7f58e03a581f2a2f60409e3..00447cf219e0c3906973564fbc9159c83803c84a 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -12,7 +12,6 @@ import numpy as np
 import torch
 
 import ANN_Model
-from projection_utils import calculate_cell_average
 
 
 class TroubledCellDetector:
@@ -215,14 +214,14 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
                                      projection[:, :num_ghost_cells]), axis=1)
 
         # Calculate input data depending on stencil length
-        input_data = torch.from_numpy(np.vstack([calculate_cell_average(
-            projection=projection[
-                       :, cell-num_ghost_cells:cell+num_ghost_cells+1],
-            stencil_length=self._stencil_len, basis=self._basis,
-            polynomial_degree=self._polynomial_degree if
-            self._add_reconstructions else -1)
-                for cell in range(num_ghost_cells,
-                                  len(projection[0])-num_ghost_cells)]))
+        input_data = torch.from_numpy(np.vstack([
+            self._basis.calculate_cell_average(
+                projection=projection[
+                           :, cell-num_ghost_cells:cell+num_ghost_cells+1],
+                stencil_length=self._stencil_len,
+                add_reconstructions=self._add_reconstructions)
+            for cell in range(num_ghost_cells,
+                              len(projection[0])-num_ghost_cells)]))
 
         # Determine troubled cells
         model_output = torch.argmax(self._model(input_data.float()), dim=1)
diff --git a/projection_utils.py b/projection_utils.py
index ebc0d0941545c8c9a9e6c653416f33713cf59de2..7b5d7749ad93382c25fe1e00ff7c43839f29f37c 100644
--- a/projection_utils.py
+++ b/projection_utils.py
@@ -104,49 +104,3 @@ def calculate_exact_solution(
     grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0])))
 
     return grid, exact
-
-
-def calculate_cell_average(projection, basis, stencil_length,
-                           polynomial_degree=-1):
-    """Calculate cell averages for a given projection.
-
-    Calculate the cell averages of all cells in a projection.
-    If desired, reconstructions are calculated for the middle cell
-    and added left and right to it, respectively.
-
-    Parameters
-    ----------
-    projection : ndarray
-        Matrix of projection for each polynomial degree.
-    basis : Basis object
-        Basis for calculation.
-    stencil_length : int
-        Size of data array.
-    polynomial_degree : int, optional
-        Polynomial degree for reconstructions of the middle cell. If -1 no
-        reconstructions will be included. Default: -1.
-
-    Returns
-    -------
-    ndarray
-        Matrix containing cell averages (and reconstructions) for initial
-        projection.
-
-    """
-    basis_vector = basis.get_basis_vector()
-    cell_averages = calculate_approximate_solution(
-        projection, np.array([0]), 0, basis_vector)
-
-    if polynomial_degree != -1:
-        left_reconstructions = calculate_approximate_solution(
-            projection, np.array([-1]), polynomial_degree, basis_vector)
-        right_reconstructions = calculate_approximate_solution(
-            projection, np.array([1]), polynomial_degree, basis_vector)
-        middle_idx = stencil_length // 2
-        return np.array(list(map(
-            np.float64, zip(cell_averages[:, :middle_idx],
-                            left_reconstructions[:, middle_idx],
-                            cell_averages[:, middle_idx],
-                            right_reconstructions[:, middle_idx],
-                            cell_averages[:, middle_idx+1:]))))
-    return np.array(list(map(np.float64, cell_averages)))