# -*- coding: utf-8 -*-
"""Troubled-cell detectors, i.e. classes locating cells with instabilities.

@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

"""
from typing import Tuple
from abc import ABC, abstractmethod
import numpy as np
from numpy import ndarray
import torch

from . import ANN_Model
from .Boundary_Condition import enforce_boxplot_boundaries, \
    enforce_fold_boundaries
from .Mesh import Mesh


def _remove_ghost_cells(projection: ndarray, num_ghost_cells: int) -> ndarray:
    """Return the columns of a projection that lie inside the mesh.

    Parameters
    ----------
    projection : ndarray
        Matrix of projection for each polynomial degree, with
        `num_ghost_cells` ghost-cell columns at both ends.
    num_ghost_cells : int
        Number of ghost cells on each side.

    Returns
    -------
    ndarray
        Projection without the ghost-cell columns.

    Notes
    -----
    An explicit stop index is used because ``projection[:, g:-g]`` yields an
    empty array for ``g == 0``.

    """
    return projection[:, num_ghost_cells:
                      projection.shape[-1] - num_ghost_cells]


class TroubledCellDetector(ABC):
    """Abstract class for troubled-cell detection.

    Detects troubled cells, i.e., cells in the mesh containing instabilities.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def __init__(self, config, basis, mesh):
        """Initializes TroubledCellDetector.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. The detectors in this module
            consume their options with ``dict.pop``, so the passed
            dictionary is modified.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        """
        self._mesh = mesh
        self._basis = basis
        self._reset(config)

    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        pass

    def get_name(self):
        """Returns string of class name."""
        return self.__class__.__name__

    @abstractmethod
    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        """
        pass

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells."""
        return {'projection': projection,
                'basis': self._basis.create_data_dict(),
                'mesh': self._mesh.create_data_dict()
                }


class NoDetection(TroubledCellDetector):
    """Class without any troubled-cell detection.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns no troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        return []


class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Attributes
    ----------
    stencil_len : int
        Number of cells in the stencil used as input for each cell.
    add_reconstructions : bool
        Flag whether reconstructions are added to the input data.
    model : torch.nn.Module
        ANN model instance for evaluation.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        Raises
        ------
        ValueError
            If the requested model is not defined in ANN_Model.

        """
        super()._reset(config)
        self._stencil_len = config.pop('stencil_len', 3)
        self._add_reconstructions = config.pop('add_reconstructions', True)

        # Set mask for stencil sliding window: entry [k, j] is j + k, so
        # column j holds the stencil indices of cell j in the padded
        # projection built in get_cells()
        self._window_mask = np.arange(self._mesh.num_cells)[None, :] + \
            np.arange(self._stencil_len)[:, None]

        self._model = config.pop('model', 'ThreeLayerReLu')
        num_datapoints = self._stencil_len
        if self._add_reconstructions:
            num_datapoints += 2
        self._model_config = config.pop('model_config', {
            'input_size': num_datapoints, 'first_hidden_size': 8,
            'second_hidden_size': 4, 'output_size': 2,
            'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                'models/model__Adam.pt')

        if not hasattr(ANN_Model, self._model):
            raise ValueError('Invalid model: "%s"' % self._model)
        self._model = getattr(ANN_Model, self._model)(self._model_config)

        # Load the model state and set it to evaluation mode.
        # NOTE: torch.load unpickles the file; only load trusted model files.
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree, including the
            mesh ghost cells.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Replace the mesh ghost cells by periodic (wrap-around) padding
        # matching the stencil length
        num_ghost_cells = self._stencil_len//2
        projection = _remove_ghost_cells(projection,
                                         self._mesh.num_ghost_cells)
        projection = np.pad(projection,
                            ((0, 0), (num_ghost_cells, num_ghost_cells)),
                            mode='wrap')

        # Calculate input data depending on stencil length
        projection_window = projection[:, self._window_mask]
        input_data = torch.from_numpy(self._basis.calculate_cell_average(
            projection=projection_window, stencil_len=self._stencil_len,
            add_reconstructions=self._add_reconstructions))

        # Determine troubled cells; no gradients are needed for inference
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()),
                                        dim=1)
        return np.flatnonzero(model_output.numpy()).tolist()


class WaveletDetector(TroubledCellDetector):
    """Abstract class for wavelet coefficient based troubled-cell detection.

    Provides the multiwavelet coefficients of a projection, the selection of
    the coefficient degree used for detection, and the projection onto the
    coarser mesh.

    Attributes
    ----------
    wavelet_degree : str
        Degree of multiwavelet coefficients used for detection. Either
        'first', 'last', or 'max'.

    Methods
    -------
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        Raises
        ------
        ValueError
            If multiwavelet degree is not in ['first', 'last', 'max'].

        """
        super()._reset(config)
        self._wavelet_degree = config.pop('multiwavelet_degree', 'first')
        if self._wavelet_degree not in ['first', 'last', 'max']:
            raise ValueError('Invalid entry for multiwavelet degree. It must '
                             'be either "first", "last", or "max".')

        # Set wavelet projections
        self._wavelet_projection_left, self._wavelet_projection_right \
            = self._basis.multiwavelet_projection

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser mesh.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of multiwavelet coefficients.

        """
        num_ghost_cells = self._mesh.num_ghost_cells
        num_columns = projection.shape[-1]
        # Combine every inner cell with its right neighbour. Explicit stop
        # indices avoid the empty slice that ``-0`` would produce; the
        # right-hand slice reads one cell past the inner mesh, so at least
        # one ghost cell is required.
        left_cells = projection[:, num_ghost_cells:
                                num_columns - num_ghost_cells]
        right_cells = projection[:, num_ghost_cells+1:
                                 num_columns - num_ghost_cells+1]
        return 0.5 * (self._wavelet_projection_left.T @ left_cells +
                      self._wavelet_projection_right.T @ right_cells)

    def _select_degree(self, wavelet_matrix):
        """Select degree of wavelet coefficients for troubled cell detection.

        Select either the first, last, or highest magnitude degree for each
        cell from the multiwavelet coefficients.

        Parameters
        ----------
        wavelet_matrix : ndarray
            Matrix of multiwavelet coefficients.

        Returns
        -------
        ndarray
            Matrix of multiwavelet coefficients of selected degree.

        """
        if self._wavelet_degree == 'first':
            return wavelet_matrix[0]
        elif self._wavelet_degree == 'last':
            return wavelet_matrix[-1]
        else:
            # Signed coefficient with the largest magnitude per cell;
            # ties favour the positive value
            max_values = np.max(wavelet_matrix, axis=0)
            min_values = np.min(wavelet_matrix, axis=0)
            return np.where(-min_values > max_values, min_values, max_values)

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of projection on coarse mesh for each polynomial degree.

        """
        basis_projection_left, basis_projection_right\
            = self._basis.basis_projection

        # Remove ghost cells
        projection = _remove_ghost_cells(projection,
                                         self._mesh.num_ghost_cells)

        # Calculate projection on coarse mesh from disjoint pairs of cells
        return 0.5 * (basis_projection_left.T @ projection[:, ::2] +
                      basis_projection_right.T @ projection[:, 1::2])

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells."""
        # Create general directory
        data_dict = super().create_data_dict(projection)

        # Save multiwavelet-specific data in dictionary
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
        coarse_projection = self._calculate_coarse_projection(projection)
        data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs
        data_dict['coarse_projection'] = coarse_projection

        return data_dict


class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot.
    whisker_len : int
        Length of Boxplot whiskers.
    adjust_outer_fences : bool
        Flag whether outer fences should be adjusted using global mean.
    extreme_outlier_only : bool
        Flag whether outliers also have to be detected in neighbouring folds.
    quantile_method : str
        Method used to calculate quantiles.
    fold_indices : ndarray
        Array with indices for elements of each fold (including overlaps).

    """
    def _reset(self, config):
        """Reset instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        self._fold_len = min(self._mesh.num_cells,
                             config.pop('fold_len', 16))
        self._whisker_len = config.pop('whisker_len', 3)
        self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
        self._extreme_outlier_only = config.pop('extreme_outlier_only',
                                                True)
        self._quantile_method = config.pop('quantile_method', 'weibull')

        num_overlapping_cells = config.pop('num_overlapping_cells', 1)
        self._fold_indices = self._compute_folds(num_overlapping_cells)

    @enforce_fold_boundaries
    def _compute_folds(self, num_overlapping_cells: int) -> ndarray:
        """Compute indices for all folds used in Boxplot calculation.

        Parameters
        ----------
        num_overlapping_cells : int
            Number of cells overlapping between adjacent folds.

        Returns
        -------
        fold_indices : ndarray
            Array of projection indices in each fold.

        """
        num_folds = self._mesh.num_cells//self._fold_len
        # Fold f covers the indices
        # [f*fold_len - overlap, (f+1)*fold_len + overlap); indices outside
        # the mesh are left as-is for the decorator to handle
        fold_starts = np.arange(num_folds) * self._fold_len \
            - num_overlapping_cells
        fold_offsets = np.arange(self._fold_len + 2*num_overlapping_cells)
        return (fold_starts[:, None] + fold_offsets[None, :]).astype(np.int32)

    def get_cells(self, projection):
        """Calculate troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Determine bounds
        coeffs = self._select_degree(self._calculate_wavelet_coeffs(
            projection))
        lower_bounds, upper_bounds = self._compute_bounds(coeffs)

        # Adjust outer fences if flag is set
        if self._adjust_outer_fences:
            global_mean = np.mean(np.abs(coeffs))
            lower_bounds[lower_bounds > -global_mean] = -global_mean
            upper_bounds[upper_bounds < global_mean] = global_mean

        # Select outliers as troubled cells.
        # NOTE(review): each inner bound is repeated fold_len times, so the
        # comparisons presumably require mesh.num_cells to be a multiple of
        # fold_len (otherwise shapes do not broadcast) - confirm.
        lower_outlier = coeffs < np.repeat(lower_bounds[1:-1],
                                           self._fold_len)
        upper_outlier = coeffs > np.repeat(upper_bounds[1:-1],
                                           self._fold_len)

        # Adjust for extreme outliers if flag is set: an outlier must also
        # exceed the bounds of the previous and the next fold
        if self._extreme_outlier_only:
            lower_outlier = np.logical_and(lower_outlier, np.logical_and(
                coeffs < np.repeat(lower_bounds[:-2], self._fold_len),
                coeffs < np.repeat(lower_bounds[2:], self._fold_len)))
            upper_outlier = np.logical_and(upper_outlier, np.logical_and(
                coeffs > np.repeat(upper_bounds[:-2], self._fold_len),
                coeffs > np.repeat(upper_bounds[2:], self._fold_len)))

        troubled_cells = np.flatnonzero(np.logical_or(
            lower_outlier, upper_outlier)).tolist()

        return troubled_cells

    @enforce_boxplot_boundaries
    def _compute_bounds(self, coeffs: ndarray) -> Tuple[ndarray, ndarray]:
        """Compute lower and upper bound for Boxplot outliers.

        Parameters
        ----------
        coeffs : ndarray
            Matrix of multiwavelet coefficients of projection.

        Returns
        -------
        lower_bounds : ndarray
            Array of lower bounds for outlier.
        upper_bounds : ndarray
            Array of upper bounds for outlier.

        """
        # Determine quartiles of folds
        folds = coeffs[self._fold_indices]
        first_quartiles = np.quantile(folds, 0.25, axis=1,
                                      method=self._quantile_method)
        third_quartiles = np.quantile(folds, 0.75, axis=1,
                                      method=self._quantile_method)

        # Determine bounds based on quartiles of a boxplot. The first and
        # last entries stay zero here; presumably they are set by the
        # enforce_boxplot_boundaries decorator - confirm.
        whiskers = self._whisker_len * (third_quartiles-first_quartiles)
        lower_bounds = np.zeros(len(first_quartiles) + 2)
        upper_bounds = np.zeros(len(first_quartiles) + 2)
        lower_bounds[1:-1] = first_quartiles - whiskers
        upper_bounds[1:-1] = third_quartiles + whiskers

        return lower_bounds, upper_bounds


class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on theoretical thresholding.

    Attributes
    ----------
    threshold : float
        Threshold above which a cell is considered troubled.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        cutoff_factor = config.pop('cutoff_factor',
                                   np.sqrt(2) * self._mesh.cell_len)
        self._threshold = cutoff_factor / (
                self._mesh.cell_len*self._mesh.num_cells)

    def get_cells(self, projection):
        """Calculate troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        coeffs = self._select_degree(self._calculate_wavelet_coeffs(
            projection))
        # Normalise by the largest zeroth-degree coefficient (at least 1)
        max_avg = np.sqrt(0.5) * max(1, np.max(np.abs(projection[0])))

        troubled_cells = np.flatnonzero(
            np.abs(coeffs)/max_avg > self._threshold).tolist()

        return troubled_cells