Skip to content
Snippets Groups Projects
Commit cbcf57bf authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Included 'check_wavelet()' for all cases in 'build_training_data().

parent 97f71771
No related branches found
No related tags found
No related merge requests found
...@@ -133,15 +133,9 @@ class TrainingDataGenerator(object): ...@@ -133,15 +133,9 @@ class TrainingDataGenerator(object):
quadrature_config={'num_eval_points': polynomial_degree+1}) quadrature_config={'num_eval_points': polynomial_degree+1})
if initial_condition.is_smooth(): if initial_condition.is_smooth():
basis_coeffs = dg_scheme.build_training_data(0, initial_condition) input_data[i] = dg_scheme.build_training_data(0, initial_condition)
else: else:
basis_coeffs = dg_scheme.build_training_data(centers[1], initial_condition) input_data[i] = dg_scheme.build_training_data(centers[1], initial_condition)
data_center = dg_scheme.check_wavelet(basis_coeffs, self._quadrature_center, 0)
data_left = dg_scheme.check_wavelet(basis_coeffs, self._quadrature_left, polynomial_degree)
data_right = dg_scheme.check_wavelet(basis_coeffs, self._quadrature_right, polynomial_degree)
input_data[i] = np.array(list(zip(data_center[:, 0], data_left[:, 1], data_center[:, 1],
data_right[:, 1], data_center[:, 2])))
# Update Function ID # Update Function ID
if (i % num_function_samples == num_function_samples - 1) and (function_id != len(initial_conditions)-1): if (i % num_function_samples == num_function_samples - 1) and (function_id != len(initial_conditions)-1):
......
...@@ -8,6 +8,8 @@ TODO: Replace loops with list comprehension if feasible ...@@ -8,6 +8,8 @@ TODO: Replace loops with list comprehension if feasible
TODO: Write documentation for all methods TODO: Write documentation for all methods
TODO: Contemplate how to make shock tubes comparable TODO: Contemplate how to make shock tubes comparable
TODO: Fix bug in approximation -> Done (used num_grid_cells instead of cell_len for cfl_number in last step) TODO: Fix bug in approximation -> Done (used num_grid_cells instead of cell_len for cfl_number in last step)
TODO: Include check_wavelet() for all cases in build_training_data() -> Done
TODO: Fix data_left/right/center to create data with an x-point stencil
""" """
import numpy as np import numpy as np
...@@ -188,8 +190,9 @@ class DGScheme(object): ...@@ -188,8 +190,9 @@ class DGScheme(object):
if initial_condition is None: if initial_condition is None:
initial_condition = self._init_cond initial_condition = self._init_cond
projection = self._do_initial_projection(initial_condition, adjustment) projection = self._do_initial_projection(initial_condition, adjustment)
return projection[:, 1:-1]
def check_wavelet(self, projection, quadrature, polynomial_degree): data_center = self._detector.calculate_approximate_solution(projection[:, 1:-1], [0], 0)
projection = self._detector.calculate_approximate_solution(projection, quadrature, polynomial_degree) data_left = self._detector.calculate_approximate_solution(projection[:, 1:-1], [-1], self._polynomial_degree)
return projection data_right = self._detector.calculate_approximate_solution(projection[:, 1:-1], [1], self._polynomial_degree)
return np.array(list(zip(data_center[:, 0], data_left[:, 1], data_center[:, 1], data_right[:, 1],
data_center[:, 2])))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment