diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index ce9de57e5350f8db33217aa6590df37ba667df7a..ea478462e860f265d812d7627bb78b689fbe0a7e 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -3,10 +3,11 @@ @author: Laura C. Kühle, Soraya Terrab (sorayaterrab) TODO: Give option to choose from multiwavelet degrees (first, last or - highest magnitude) + highest magnitude) -> Done TODO: Change method of calculating quartiles TODO: Include overlapping cells in quartile calculation (if needed) TODO: Determine max_value for Theoretical only over highest degree + -> Done (optional) TODO: Check if indexing in wavelets is correct TODO: Combine get_cells() and _get_cells() TODO: Add TC condition to only flag cell if left-adjacent one is flagged as @@ -219,9 +220,19 @@ class WaveletDetector(TroubledCellDetector): config : dict Additional parameters for detector. + Raises + ------ + ValueError + If multiwavelet degree is not in ['first', 'last', 'max']. + """ super()._reset(config) + self._wavelet_degree = config.pop('multiwavelet_degree', 'first') + if self._wavelet_degree not in ['first', 'last', 'max']: + raise ValueError('Invalid entry for multiwavelet degree. It must ' + 'be either "first", "last", or "max".') + # Set wavelet projections self._wavelet_projection_left, self._wavelet_projection_right \ = self._basis.multiwavelet_projection @@ -241,6 +252,7 @@ class WaveletDetector(TroubledCellDetector): """ multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection) + multiwavelet_coeffs = self._select_degree(multiwavelet_coeffs) return self._get_cells(multiwavelet_coeffs, projection) def _calculate_wavelet_coeffs(self, projection): @@ -254,7 +266,7 @@ class WaveletDetector(TroubledCellDetector): Returns ------- ndarray - Matrix of wavelet coefficients. + Matrix of multiwavelet coefficients. """ output_matrix = [] @@ -265,6 +277,32 @@ class WaveletDetector(TroubledCellDetector): output_matrix.append(new_entry) return np.transpose(np.array(output_matrix)) + def _select_degree(self, wavelet_matrix): + """Select degree of wavelet coefficients for troubled cell detection. + + Select either the first, last, or highest megnitude degree for each + cell from the multiwavelet coefficients. + + Parameters + ---------- + wavelet_matrix : ndarray + Matrix of multiwavelet coefficients. + + Returns + ------- + ndarray + Matrix of multiwavelet coefficients of selected degree. + + """ + if self._wavelet_degree == 'first': + return wavelet_matrix[0] + elif self._wavelet_degree == 'last': + return wavelet_matrix[-1] + else: + max_values = np.max(wavelet_matrix, axis=0) + min_values = np.min(wavelet_matrix, axis=0) + return np.where(-min_values > max_values, min_values, max_values) + @abstractmethod def _get_cells(self, multiwavelet_coeffs, projection): """Calculates troubled cells using multiwavelet coefficients. @@ -378,12 +416,12 @@ class Boxplot(WaveletDetector): fold * self._fold_len - num_overlapping_cells, (fold+1) * self._fold_len + num_overlapping_cells)]) - def _get_cells(self, multiwavelet_coeffs, projection): + def _get_cells(self, coeffs, projection): """Calculate troubled cells using multiwavelet coefficients. Parameters ---------- - multiwavelet_coeffs : ndarray + coeffs : ndarray Matrix of multiwavelet coefficients. projection : ndarray Matrix of projection for each polynomial degree. @@ -395,7 +433,6 @@ class Boxplot(WaveletDetector): """ # Select and sort fold domains - coeffs = multiwavelet_coeffs[0] folds = coeffs[self._fold_indices] folds.sort() @@ -477,12 +514,12 @@ class Theoretical(WaveletDetector): np.sqrt(2) * self._mesh.cell_len) # comment to line above: or 2 or 3 - def _get_cells(self, multiwavelet_coeffs, projection): + def _get_cells(self, coeffs, projection): """Calculates troubled cells using multiwavelet coefficients. Parameters ---------- - multiwavelet_coeffs : ndarray + coeffs : ndarray Matrix of multiwavelet coefficients. projection : ndarray Matrix of projection for each polynomial degree. @@ -499,17 +536,17 @@ class Theoretical(WaveletDetector): for cell in range(self._mesh.num_grid_cells))) for cell in range(self._mesh.num_grid_cells): - if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg): + if self._is_troubled_cell(coeffs, cell, max_avg): troubled_cells.append(cell) return troubled_cells - def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg): + def _is_troubled_cell(self, coeffs, cell, max_avg): """Checks whether a cell is troubled. Parameters ---------- - multiwavelet_coeffs : ndarray + coeffs : ndarray Matrix of multiwavelet coefficients. cell : int Index of cell. @@ -522,9 +559,7 @@ class Theoretical(WaveletDetector): Flag whether cell is troubled. """ - max_value = max(abs(multiwavelet_coeffs[degree][cell]) - for degree in range( - self._basis.polynomial_degree+1))/max_avg + max_value = abs(coeffs[cell])/max_avg eps = self._cutoff_factor\ / (self._mesh.cell_len*self._mesh.num_grid_cells)