diff --git a/Plotting.py b/Plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..5267856aea3023bd478e297780d55cfcf4cc330a
--- /dev/null
+++ b/Plotting.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+"""
+@author: Laura C. Kühle
+
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sympy import Symbol
+
+
+x = Symbol('x')
+z = Symbol('z')
+sns.set()
+
+
+def plot_solution_and_approx(grid, exact, approx, color_exact, color_approx):
+    print(color_exact, color_approx)
+    plt.figure('exact_and_approx')
+    plt.plot(grid[0], exact[0], color_exact)
+    plt.plot(grid[0], approx[0], color_approx)
+    plt.xlabel('x')
+    plt.ylabel('u(x,t)')
+    plt.title('Solution and Approximation')
+
+
+def plot_semilog_error(grid, pointwise_error):
+    plt.figure('semilog_error')
+    plt.semilogy(grid[0], pointwise_error[0])
+    plt.xlabel('x')
+    plt.ylabel('|u(x,t)-uh(x,t)|')
+    plt.title('Semilog Error plotted at Evaluation points')
+
+
+def plot_error(grid, exact, approx):
+    plt.figure('error')
+    plt.plot(grid[0], exact[0]-approx[0])
+    plt.xlabel('X')
+    plt.ylabel('u(x,t)-uh(x,t)')
+    plt.title('Errors')
+
+
+def plot_shock_tube(num_grid_cells, troubled_cell_history, time_history):
+    plt.figure('shock_tube')
+    for pos in range(len(time_history)):
+        current_cells = troubled_cell_history[pos]
+        for cell in current_cells:
+            plt.plot(cell, time_history[pos], 'k.')
+    plt.xlim((0, num_grid_cells // 2))
+    plt.xlabel('Cell')
+    plt.ylabel('Time')
+    plt.title('Shock Tubes')
+
+
+def plot_details(fine_projection, fine_mesh, coarse_projection, basis, wavelet, multiwavelet_coeffs,
+                 num_coarse_grid_cells, polynomial_degree):
+    averaged_projection = [[coarse_projection[degree][cell] * basis[degree].subs(x, value)
+                            for cell in range(num_coarse_grid_cells)
+                            for value in [-0.5, 0.5]]
+                           for degree in range(polynomial_degree + 1)]
+
+    wavelet_projection = [[multiwavelet_coeffs[degree][cell] * wavelet[degree].subs(z, 0.5) * value
+                           for cell in range(num_coarse_grid_cells)
+                           for value in [(-1) ** (polynomial_degree + degree + 1), 1]]
+                          for degree in range(polynomial_degree + 1)]
+
+    projected_coarse = np.sum(averaged_projection, axis=0)
+    projected_fine = np.sum([fine_projection[degree] * basis[degree].subs(x, 0)
+                             for degree in range(polynomial_degree + 1)], axis=0)
+    projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)
+
+    plt.figure('coeff_details')
+    plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
+    plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
+    plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
+    plt.xlabel('X')
+    plt.ylabel('Detail Coefficients')
+    plt.title('Wavelet Coefficients')
+
+
+def calculate_approximate_solution(projection, points, polynomial_degree, basis):
+    num_points = len(points)
+
+    basis_matrix = [[basis[degree].subs(x, points[point]) for point in range(num_points)]
+                    for degree in range(polynomial_degree+1)]
+
+    approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
+                   for degree in range(polynomial_degree+1))
+               for point in range(num_points)]
+              for cell in range(len(projection[0]))]
+
+    return np.reshape(np.array(approx), (1, len(approx) * num_points))
+
+
+def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_len, quadrature, init_cond):
+    grid = []
+    exact = []
+    num_periods = np.floor(wave_speed * final_time / interval_len)
+
+    for cell in range(len(mesh)):
+        eval_points = mesh[cell]+cell_len / 2 * quadrature.get_eval_points()
+
+        eval_values = []
+        for point in range(len(eval_points)):
+            new_entry = init_cond.calculate(eval_points[point] - wave_speed * final_time + num_periods * interval_len)
+            eval_values.append(new_entry)
+
+        grid.append(eval_points)
+        exact.append(eval_values)
+
+    exact = np.reshape(np.array(exact), (1, len(exact) * len(exact[0])))
+    grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0])))
+
+    return grid, exact
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index 35c4d241427bc7304b2f87a6f3b40a6a193cef4b..35a0f5005c7fb71d11db9fbfcc17e5a65986dfe3 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -2,7 +2,7 @@
 """
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
 
-TODO: Move plotting to separate file (try to adjust for different equations)
+TODO: Move plotting to separate file (try to adjust for different equations later) -> Done
 TODO: Improve figure identifiers -> Done
 
 """
@@ -13,6 +13,8 @@ import torch
 from sympy import Symbol
 
 import ANN_Model
+from Plotting import plot_solution_and_approx, plot_semilog_error, plot_error, plot_shock_tube, plot_details, \
+    calculate_approximate_solution, calculate_exact_solution
 
 x = Symbol('x')
 z = Symbol('z')
@@ -60,110 +62,40 @@ class TroubledCellDetector(object):
 
         Here come some parameter.
         """
-        cell_averages = self._calculate_approximate_solution(projection, [0], 0)
-        left_reconstructions = self._calculate_approximate_solution(projection, [-1], self._polynomial_degree)
-        right_reconstructions = self._calculate_approximate_solution(projection, [1], self._polynomial_degree)
+        cell_averages = calculate_approximate_solution(projection, [0], 0, self._basis.get_basis_vector())
+        left_reconstructions = calculate_approximate_solution(projection, [-1], self._polynomial_degree,
+                                                              self._basis.get_basis_vector())
+        right_reconstructions = calculate_approximate_solution(projection, [1], self._polynomial_degree,
+                                                               self._basis.get_basis_vector())
         middle_idx = stencil_length//2
         return np.array(list(map(np.float64, zip(cell_averages[:, :middle_idx],
                         left_reconstructions[:, middle_idx], cell_averages[:, middle_idx],
                         right_reconstructions[:, middle_idx], cell_averages[:, middle_idx+1:]))))
 
     def plot_results(self, projection, troubled_cell_history, time_history):
-        self._plot_shock_tube(troubled_cell_history, time_history)
+        plot_shock_tube(self._num_grid_cells, troubled_cell_history, time_history)
         max_error = self._plot_mesh(projection)
 
         print('p =', self._polynomial_degree)
         print('N =', self._num_grid_cells)
         print('maximum error =', max_error)
 
-    def _plot_shock_tube(self, troubled_cell_history, time_history):
-        plt.figure('shock_tube')
-        for pos in range(len(time_history)):
-            current_cells = troubled_cell_history[pos]
-            for cell in current_cells:
-                plt.plot(cell, time_history[pos], 'k.')
-        plt.xlim((0, self._num_grid_cells//2))
-        plt.xlabel('Cell')
-        plt.ylabel('Time')
-        plt.title('Shock Tubes')
-
     def _plot_mesh(self, projection):
-        grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
-        approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
-                                                      self._polynomial_degree)
+        grid, exact = calculate_exact_solution(self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
+                                               self._interval_len, self._quadrature, self._init_cond)
+        approx = calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
+                                                self._polynomial_degree, self._basis.get_basis_vector())
 
         pointwise_error = np.abs(exact-approx)
         max_error = np.max(pointwise_error)
 
-        self._plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
+        plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
         plt.legend(['Exact', 'Approx'])
-        self._plot_semilog_error(grid, pointwise_error)
-        self._plot_error(grid, exact, approx)
+        plot_semilog_error(grid, pointwise_error)
+        plot_error(grid, exact, approx)
 
         return max_error
 
-    @staticmethod
-    def _plot_solution_and_approx(grid, exact, approx, color_exact, color_approx):
-        print(color_exact, color_approx)
-        plt.figure('exact_and_approx')
-        plt.plot(grid[0], exact[0], color_exact)
-        plt.plot(grid[0], approx[0], color_approx)
-        plt.xlabel('x')
-        plt.ylabel('u(x,t)')
-        plt.title('Solution and Approximation')
-
-    @staticmethod
-    def _plot_semilog_error(grid, pointwise_error):
-        plt.figure('semilog_error')
-        plt.semilogy(grid[0], pointwise_error[0])
-        plt.xlabel('x')
-        plt.ylabel('|u(x,t)-uh(x,t)|')
-        plt.title('Semilog Error plotted at Evaluation points')
-
-    @staticmethod
-    def _plot_error(grid, exact, approx):
-        plt.figure('error')
-        plt.plot(grid[0], exact[0]-approx[0])
-        plt.xlabel('X')
-        plt.ylabel('u(x,t)-uh(x,t)')
-        plt.title('Errors')
-
-    def _calculate_exact_solution(self, mesh, cell_len):
-        grid = []
-        exact = []
-        num_periods = np.floor(self._wave_speed * self._final_time / self._interval_len)
-
-        for cell in range(len(mesh)):
-            eval_points = mesh[cell] + cell_len/2 * self._quadrature.get_eval_points()
-
-            eval_values = []
-            for point in range(len(eval_points)):
-                new_entry = self._init_cond.calculate(eval_points[point] - self._wave_speed*self._final_time
-                                                      + num_periods*self._interval_len)
-                eval_values.append(new_entry)
-
-            grid.append(eval_points)
-            exact.append(eval_values)
-
-        exact = np.reshape(np.array(exact), (1, len(exact)*len(exact[0])))
-        grid = np.reshape(np.array(grid), (1, len(grid)*len(grid[0])))
-
-        return grid, exact
-
-    def _calculate_approximate_solution(self, projection, points, polynomial_degree):
-        num_points = len(points)
-        basis = self._basis.get_basis_vector()
-
-        basis_matrix = [[basis[degree].subs(x, points[point]) for point in range(num_points)]
-                        for degree in range(polynomial_degree+1)]
-
-        approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
-                       for degree in range(polynomial_degree+1))
-                   for point in range(num_points)]
-                  for cell in range(len(projection[0]))]
-
-        return np.reshape(np.array(approx), (1, len(approx) * num_points))
-
 
 class NoDetection(TroubledCellDetector):
     def get_cells(self, projection):
@@ -237,40 +169,12 @@ class WaveletDetector(TroubledCellDetector):
         return []
 
     def plot_results(self, projection, troubled_cell_history, time_history):
-        self._plot_details(projection)
-        super().plot_results(projection, troubled_cell_history, time_history)
-
-    def _plot_details(self, projection):
-        fine_mesh = self._mesh[2:-2]
-
-        fine_projection = projection[:, 1:-1]
-        coarse_projection = self._calculate_coarse_projection(projection)
         multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
-        basis = self._basis.get_basis_vector()
-        wavelet = self._basis.get_wavelet_vector()
-
-        averaged_projection = [[coarse_projection[degree][cell] * basis[degree].subs(x, value)
-                                for cell in range(self._num_coarse_grid_cells)
-                                for value in [-0.5, 0.5]]
-                               for degree in range(self._polynomial_degree + 1)]
-
-        wavelet_projection = [[multiwavelet_coeffs[degree][cell] * wavelet[degree].subs(z, 0.5) * value
-                               for cell in range(self._num_coarse_grid_cells)
-                               for value in [(-1) ** (self._polynomial_degree + degree + 1), 1]]
-                              for degree in range(self._polynomial_degree + 1)]
-
-        projected_coarse = np.sum(averaged_projection, axis=0)
-        projected_fine = np.sum([fine_projection[degree] * basis[degree].subs(x, 0)
-                                 for degree in range(self._polynomial_degree + 1)], axis=0)
-        projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)
-
-        plt.figure('coeff_details')
-        plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
-        plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
-        plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
-        plt.xlabel('X')
-        plt.ylabel('Detail Coefficients')
-        plt.title('Wavelet Coefficients')
+        coarse_projection = self._calculate_coarse_projection(projection)
+        plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection, self._basis.get_basis_vector(),
+                     self._basis.get_wavelet_vector(), multiwavelet_coeffs, self._num_coarse_grid_cells,
+                     self._polynomial_degree)
+        super().plot_results(projection, troubled_cell_history, time_history)
 
     def _calculate_coarse_projection(self, projection):
         basis_projection_left, basis_projection_right = self._basis.get_basis_projections()
@@ -289,18 +193,19 @@ class WaveletDetector(TroubledCellDetector):
         return coarse_projection
 
     def _plot_mesh(self, projection):
-        grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
-        approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
-                                                      self._polynomial_degree)
+        grid, exact = calculate_exact_solution(self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
+                                               self._interval_len, self._quadrature, self._init_cond)
+        approx = calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
+                                                self._polynomial_degree, self._basis.get_basis_vector())
 
         pointwise_error = np.abs(exact-approx)
         max_error = np.max(pointwise_error)
 
         self._plot_coarse_mesh(projection)
-        self._plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'], self._colors['fine_approx'])
+        plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'], self._colors['fine_approx'])
         plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', 'Approx (Fine)'])
-        self._plot_semilog_error(grid, pointwise_error)
-        self._plot_error(grid, exact, approx)
+        plot_semilog_error(grid, pointwise_error)
+        plot_error(grid, exact, approx)
 
         return max_error
 
@@ -312,10 +217,11 @@ class WaveletDetector(TroubledCellDetector):
         coarse_projection = self._calculate_coarse_projection(projection)
 
         # Plot exact and approximate solutions for coarse mesh
-        grid, exact = self._calculate_exact_solution(coarse_mesh[1:-1], coarse_cell_len)
-        approx = self._calculate_approximate_solution(coarse_projection, self._quadrature.get_eval_points(),
-                                                      self._polynomial_degree)
-        self._plot_solution_and_approx(grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])
+        grid, exact = calculate_exact_solution(coarse_mesh[1:-1], coarse_cell_len, self._wave_speed, self._final_time,
+                                               self._interval_len, self._quadrature, self._init_cond)
+        approx = calculate_approximate_solution(coarse_projection, self._quadrature.get_eval_points(),
+                                                self._polynomial_degree, self._basis.get_basis_vector())
+        plot_solution_and_approx(grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])
 
 
 class Boxplot(WaveletDetector):