diff --git a/Plotting.py b/Plotting.py index 0db5713422ac89ee249757f7e0b96186ed61d1f3..5c63e04dcce824b480b00d81ced9b58a2ca65940 100644 --- a/Plotting.py +++ b/Plotting.py @@ -5,7 +5,6 @@ TODO: Give option to select plotting color """ -from typing import Tuple import numpy as np import matplotlib @@ -14,9 +13,6 @@ import seaborn as sns from numpy import ndarray from sympy import Symbol -from Quadrature import Quadrature -from Initial_Condition import InitialCondition - matplotlib.use('Agg') x = Symbol('x') @@ -171,96 +167,6 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, plt.title('Wavelet Coefficients') -def calculate_approximate_solution( - projection: ndarray, points: ndarray, polynomial_degree: int, - basis: ndarray) -> ndarray: - """Calculates approximate solution. - - Parameters - ---------- - projection : ndarray - Matrix of projection for each polynomial degree. - points : ndarray - List of evaluation points for mesh. - polynomial_degree : int - Polynomial degree. - basis : ndarray - Basis vector for calculation. - - Returns - ------- - ndarray - Array containing approximate evaluation of a function. - - """ - num_points = len(points) - - basis_matrix = [[basis[degree].subs(x, points[point]) - for point in range(num_points)] - for degree in range(polynomial_degree+1)] - - approx = [[sum(projection[degree][cell] * basis_matrix[degree][point] - for degree in range(polynomial_degree+1)) - for point in range(num_points)] - for cell in range(len(projection[0]))] - - return np.reshape(np.array(approx), (1, len(approx) * num_points)) - - -def calculate_exact_solution( - mesh: ndarray, cell_len: float, wave_speed: float, final_time: - float, interval_len: float, quadrature: Quadrature, init_cond: - InitialCondition) -> Tuple[ndarray, ndarray]: - """Calculates exact solution. - - Parameters - ---------- - mesh : ndarray - List of mesh valuation points. - cell_len : float - Length of a cell in mesh. - wave_speed : float - Speed of wave in rightward direction. - final_time : float - Final time for which approximation is calculated. - interval_len : float - Length of the interval between left and right boundary. - quadrature : Quadrature object - Quadrature for evaluation. - init_cond : InitialCondition object - Initial condition for evaluation. - - Returns - ------- - grid : ndarray - Array containing evaluation grid for a function. - exact : ndarray - Array containing exact evaluation of a function. - - """ - grid = [] - exact = [] - num_periods = np.floor(wave_speed * final_time / interval_len) - - for cell in range(len(mesh)): - eval_points = mesh[cell]+cell_len / 2 * quadrature.get_eval_points() - - eval_values = [] - for point in range(len(eval_points)): - new_entry = init_cond.calculate(eval_points[point] - - wave_speed * final_time - + num_periods * interval_len) - eval_values.append(new_entry) - - grid.append(eval_points) - exact.append(eval_values) - - exact = np.reshape(np.array(exact), (1, len(exact) * len(exact[0]))) - grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0]))) - - return grid, exact - - def plot_classification_barplot(evaluation_dict: dict, colors: dict) -> None: """Plots classification accuracy. diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index 2d8aaeb737ccff3167aeb76a7890a173ac0c9ac3..f2d8691e935c2251b1c1a504fe8872612299adff 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -13,17 +13,14 @@ import matplotlib from matplotlib import pyplot as plt import seaborn as sns import torch -from sympy import Symbol import ANN_Model from Plotting import plot_solution_and_approx, plot_semilog_error, \ - plot_error, plot_shock_tube, plot_details, \ + plot_error, plot_shock_tube, plot_details +from projection_utils import calculate_cell_average,\ calculate_approximate_solution, calculate_exact_solution -from projection_utils import calculate_cell_average matplotlib.use('Agg') -x = Symbol('x') -z = Symbol('z') class TroubledCellDetector: diff --git a/projection_utils.py b/projection_utils.py index e860b102f99160a7a8400f3ca24beeb0ba8ecd15..ebc0d0941545c8c9a9e6c653416f33713cf59de2 100644 --- a/projection_utils.py +++ b/projection_utils.py @@ -3,9 +3,107 @@ @author: Laura C. Kühle """ + +from typing import Tuple import numpy as np +from numpy import ndarray +from sympy import Symbol + +from Quadrature import Quadrature +from Initial_Condition import InitialCondition + + +x = Symbol('x') + + +def calculate_approximate_solution( + projection: ndarray, points: ndarray, polynomial_degree: int, + basis: ndarray) -> ndarray: + """Calculate approximate solution. + + Parameters + ---------- + projection : ndarray + Matrix of projection for each polynomial degree. + points : ndarray + List of evaluation points for mesh. + polynomial_degree : int + Polynomial degree. + basis : ndarray + Basis vector for calculation. + + Returns + ------- + ndarray + Array containing approximate evaluation of a function. + + """ + num_points = len(points) + + basis_matrix = [[basis[degree].subs(x, points[point]) + for point in range(num_points)] + for degree in range(polynomial_degree+1)] + + approx = [[sum(projection[degree][cell] * basis_matrix[degree][point] + for degree in range(polynomial_degree+1)) + for point in range(num_points)] + for cell in range(len(projection[0]))] + + return np.reshape(np.array(approx), (1, len(approx) * num_points)) + + +def calculate_exact_solution( + mesh: ndarray, cell_len: float, wave_speed: float, final_time: + float, interval_len: float, quadrature: Quadrature, init_cond: + InitialCondition) -> Tuple[ndarray, ndarray]: + """Calculate exact solution. + + Parameters + ---------- + mesh : ndarray + List of mesh evaluation points. + cell_len : float + Length of a cell in mesh. + wave_speed : float + Speed of wave in rightward direction. + final_time : float + Final time for which approximation is calculated. + interval_len : float + Length of the interval between left and right boundary. + quadrature : Quadrature object + Quadrature for evaluation. + init_cond : InitialCondition object + Initial condition for evaluation. + + Returns + ------- + grid : ndarray + Array containing evaluation grid for a function. + exact : ndarray + Array containing exact evaluation of a function. + + """ + grid = [] + exact = [] + num_periods = np.floor(wave_speed * final_time / interval_len) + + for cell_center in mesh: + eval_points = cell_center+cell_len / 2 * quadrature.get_eval_points() + + eval_values = [] + for eval_point in eval_points: + new_entry = init_cond.calculate(eval_point + - wave_speed * final_time + + num_periods * interval_len) + eval_values.append(new_entry) + + grid.append(eval_points) + exact.append(eval_values) + + exact = np.reshape(np.array(exact), (1, len(exact) * len(exact[0]))) + grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0]))) -from Plotting import calculate_approximate_solution + return grid, exact def calculate_cell_average(projection, basis, stencil_length,