From a258873485af97d100a8d0c0ab1ee400a9e86961 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Sun, 9 Oct 2022 02:57:13 +0200 Subject: [PATCH] Vectorized '_build_multiwavelet_matrix()' in OrthonormalLegendre basis. --- Snakefile | 9 +++++++- scripts/tcd/Basis_Function.py | 42 +++++++++++++++++------------------ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/Snakefile b/Snakefile index aad8e25..b6a5ce1 100644 --- a/Snakefile +++ b/Snakefile @@ -23,7 +23,13 @@ TODO: Discuss name for quadrature mesh (now: grid) TODO: Contemplate using lambdify for basis Urgent: -TODO: Extract object initialization from DGScheme +TODO: Extract object initialization from DGScheme -> Done +TODO: Extract seaborn initialization from DGScheme -> Done +TODO: Vectorize '_normalize_data()' during data generation -> Done +TODO: Vectorize '_build_inverse_mass_matrix()' in OrthonormalLegendre -> Done +TODO: Vectorize '_build_basis_matrix()' in OrthonormalLegendre -> Done +TODO: Vectorize '_build_multiwavelet_matrix()' in OrthonormalLegendre -> Done +TODO: Vectorize '_calculate_reconstructions()' in OrthonormalLegendre TODO: Replace loops/list comprehension with vectorization if feasible TODO: Replace loops with list comprehension if feasible TODO: Rework ICs to allow vector input @@ -37,6 +43,7 @@ TODO: Combine ANN workflows if feasible TODO: Investigate profiling for speed up Critical, but not urgent: +TODO: Check whether all ValueError are set (correctly) TODO: Introduce env files for each SM rule TODO: Add an environment file for Snakemake TODO: Rename files according to standard diff --git a/scripts/tcd/Basis_Function.py b/scripts/tcd/Basis_Function.py index ef334a9..b5463d5 100644 --- a/scripts/tcd/Basis_Function.py +++ b/scripts/tcd/Basis_Function.py @@ -403,15 +403,14 @@ class OrthonormalLegendre(Legendre): Matrix containing the integral of basis products. """ - basis_row = np.array([self.basis[idx].subs(x, first_param) for idx in + basis_col = np.array([self.basis[idx].subs(x, first_param) for idx in range(self._polynomial_degree+1)])[: np.newaxis] - basis_col = np.array([self.basis[idx].subs(x, second_param) for idx in + basis_row = np.array([self.basis[idx].subs(x, second_param) for idx in range(self._polynomial_degree+1)])[: np.newaxis] - basis_matrix = np.matmul(basis_row[:, np.newaxis], - basis_col[:, np.newaxis].T) - basis_matrix = np.float64(np.vectorize( + basis_matrix = np.matmul(basis_col[:, np.newaxis], + basis_row[:, np.newaxis].T) + return np.float64(np.vectorize( lambda y: integrate(y, (z, -1, 1)))(basis_matrix)) - return basis_matrix @property @cache @@ -436,9 +435,9 @@ class OrthonormalLegendre(Legendre): z, 0.5*(z+1), False) return left_wavelet_projection, right_wavelet_projection - def _build_multiwavelet_matrix(self, first_param: float, - second_param: float, is_left_matrix: bool) \ - -> ndarray: + def _build_multiwavelet_matrix( + self, first_param: float, second_param: float, is_left_matrix: + bool) -> ndarray: """Construct a multiwavelet matrix. Parameters @@ -457,18 +456,19 @@ class OrthonormalLegendre(Legendre): vector. """ - matrix = [] - for i in range(self._polynomial_degree+1): - row = [] - for j in range(self._polynomial_degree+1): - entry = integrate(self.basis[i].subs(x, first_param) - * self.wavelet[j].subs(z, second_param), - (z, -1, 1)) - if is_left_matrix: - entry = entry * (-1)**(j + self._polynomial_degree + 1) - row.append(np.float64(entry)) - matrix.append(row) - return np.array(matrix) + basis_col = np.array([ + self.basis[idx].subs(x, first_param) for idx in + range(self._polynomial_degree+1)])[: np.newaxis] + basis_row = np.array([ + self.wavelet[idx].subs(z, second_param) for idx in + range(self._polynomial_degree+1)])[: np.newaxis] + if is_left_matrix: + basis_row *= np.array([(-1)**(idx+self._polynomial_degree+1) for + idx in range(self._polynomial_degree+1)]) + wavelet_matrix = np.matmul(basis_col[:, np.newaxis], + basis_row[:, np.newaxis].T) + return np.float64(np.vectorize( + lambda y: integrate(y, (z, -1, 1)))(wavelet_matrix)) def calculate_cell_average(self, projection: ndarray, stencil_len: int, add_reconstructions: bool = True) -> ndarray: -- GitLab