Skip to content
Snippets Groups Projects
Commit 7c2c0818 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Extracted solution calculations from 'Plotting'.

parent 978b1594
Branches
No related tags found
No related merge requests found
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
TODO: Give option to select plotting color TODO: Give option to select plotting color
""" """
from typing import Tuple
import numpy as np import numpy as np
import matplotlib import matplotlib
...@@ -14,9 +13,6 @@ import seaborn as sns ...@@ -14,9 +13,6 @@ import seaborn as sns
from numpy import ndarray from numpy import ndarray
from sympy import Symbol from sympy import Symbol
from Quadrature import Quadrature
from Initial_Condition import InitialCondition
matplotlib.use('Agg') matplotlib.use('Agg')
x = Symbol('x') x = Symbol('x')
...@@ -171,96 +167,6 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray, ...@@ -171,96 +167,6 @@ def plot_details(fine_projection: ndarray, fine_mesh: ndarray,
plt.title('Wavelet Coefficients') plt.title('Wavelet Coefficients')
def calculate_approximate_solution(
projection: ndarray, points: ndarray, polynomial_degree: int,
basis: ndarray) -> ndarray:
"""Calculates approximate solution.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
points : ndarray
List of evaluation points for mesh.
polynomial_degree : int
Polynomial degree.
basis : ndarray
Basis vector for calculation.
Returns
-------
ndarray
Array containing approximate evaluation of a function.
"""
num_points = len(points)
basis_matrix = [[basis[degree].subs(x, points[point])
for point in range(num_points)]
for degree in range(polynomial_degree+1)]
approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
for degree in range(polynomial_degree+1))
for point in range(num_points)]
for cell in range(len(projection[0]))]
return np.reshape(np.array(approx), (1, len(approx) * num_points))
def calculate_exact_solution(
mesh: ndarray, cell_len: float, wave_speed: float, final_time:
float, interval_len: float, quadrature: Quadrature, init_cond:
InitialCondition) -> Tuple[ndarray, ndarray]:
"""Calculates exact solution.
Parameters
----------
mesh : ndarray
List of mesh valuation points.
cell_len : float
Length of a cell in mesh.
wave_speed : float
Speed of wave in rightward direction.
final_time : float
Final time for which approximation is calculated.
interval_len : float
Length of the interval between left and right boundary.
quadrature : Quadrature object
Quadrature for evaluation.
init_cond : InitialCondition object
Initial condition for evaluation.
Returns
-------
grid : ndarray
Array containing evaluation grid for a function.
exact : ndarray
Array containing exact evaluation of a function.
"""
grid = []
exact = []
num_periods = np.floor(wave_speed * final_time / interval_len)
for cell in range(len(mesh)):
eval_points = mesh[cell]+cell_len / 2 * quadrature.get_eval_points()
eval_values = []
for point in range(len(eval_points)):
new_entry = init_cond.calculate(eval_points[point]
- wave_speed * final_time
+ num_periods * interval_len)
eval_values.append(new_entry)
grid.append(eval_points)
exact.append(eval_values)
exact = np.reshape(np.array(exact), (1, len(exact) * len(exact[0])))
grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0])))
return grid, exact
def plot_classification_barplot(evaluation_dict: dict, colors: dict) -> None: def plot_classification_barplot(evaluation_dict: dict, colors: dict) -> None:
"""Plots classification accuracy. """Plots classification accuracy.
......
...@@ -13,17 +13,14 @@ import matplotlib ...@@ -13,17 +13,14 @@ import matplotlib
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import seaborn as sns import seaborn as sns
import torch import torch
from sympy import Symbol
import ANN_Model import ANN_Model
from Plotting import plot_solution_and_approx, plot_semilog_error, \ from Plotting import plot_solution_and_approx, plot_semilog_error, \
plot_error, plot_shock_tube, plot_details, \ plot_error, plot_shock_tube, plot_details
from projection_utils import calculate_cell_average,\
calculate_approximate_solution, calculate_exact_solution calculate_approximate_solution, calculate_exact_solution
from projection_utils import calculate_cell_average
matplotlib.use('Agg') matplotlib.use('Agg')
x = Symbol('x')
z = Symbol('z')
class TroubledCellDetector: class TroubledCellDetector:
......
...@@ -3,9 +3,107 @@ ...@@ -3,9 +3,107 @@
@author: Laura C. Kühle @author: Laura C. Kühle
""" """
from typing import Tuple
import numpy as np import numpy as np
from numpy import ndarray
from sympy import Symbol
from Quadrature import Quadrature
from Initial_Condition import InitialCondition
x = Symbol('x')
def calculate_approximate_solution(
projection: ndarray, points: ndarray, polynomial_degree: int,
basis: ndarray) -> ndarray:
"""Calculate approximate solution.
Parameters
----------
projection : ndarray
Matrix of projection for each polynomial degree.
points : ndarray
List of evaluation points for mesh.
polynomial_degree : int
Polynomial degree.
basis : ndarray
Basis vector for calculation.
Returns
-------
ndarray
Array containing approximate evaluation of a function.
"""
num_points = len(points)
basis_matrix = [[basis[degree].subs(x, points[point])
for point in range(num_points)]
for degree in range(polynomial_degree+1)]
approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
for degree in range(polynomial_degree+1))
for point in range(num_points)]
for cell in range(len(projection[0]))]
return np.reshape(np.array(approx), (1, len(approx) * num_points))
def calculate_exact_solution(
mesh: ndarray, cell_len: float, wave_speed: float, final_time:
float, interval_len: float, quadrature: Quadrature, init_cond:
InitialCondition) -> Tuple[ndarray, ndarray]:
"""Calculate exact solution.
Parameters
----------
mesh : ndarray
List of mesh evaluation points.
cell_len : float
Length of a cell in mesh.
wave_speed : float
Speed of wave in rightward direction.
final_time : float
Final time for which approximation is calculated.
interval_len : float
Length of the interval between left and right boundary.
quadrature : Quadrature object
Quadrature for evaluation.
init_cond : InitialCondition object
Initial condition for evaluation.
Returns
-------
grid : ndarray
Array containing evaluation grid for a function.
exact : ndarray
Array containing exact evaluation of a function.
"""
grid = []
exact = []
num_periods = np.floor(wave_speed * final_time / interval_len)
for cell_center in mesh:
eval_points = cell_center+cell_len / 2 * quadrature.get_eval_points()
eval_values = []
for eval_point in eval_points:
new_entry = init_cond.calculate(eval_point
- wave_speed * final_time
+ num_periods * interval_len)
eval_values.append(new_entry)
grid.append(eval_points)
exact.append(eval_values)
exact = np.reshape(np.array(exact), (1, len(exact) * len(exact[0])))
grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0])))
from Plotting import calculate_approximate_solution return grid, exact
def calculate_cell_average(projection, basis, stencil_length, def calculate_cell_average(projection, basis, stencil_length,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment