# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

TODO: Introduce Adjusted Outer Fence method in Boxplot using global_mean
    -> Done
TODO: Introduce overlapping cells for adjacent folds in Boxplot -> Done
TODO: Extract fold computing from TC checking -> Done
TODO: Vectorize _get_cells() in Boxplot method -> Done
TODO: Introduce lower/upper extreme outliers in Boxplot
    (each cell is also checked for neighboring domains if existing)
TODO: Determine max_value for Theoretical only over highest degree
TODO: Check if indexing in wavelets is correct
TODO: Add ThresholdDetector
TODO: Add TC condition to only flag cell if left-adjacent one is flagged as
    well (remove this condition)
TODO: Check coarse_projection calculation for indexing errors
TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection

"""
from abc import ABC, abstractmethod

import numpy as np
import torch

import ANN_Model
from projection_utils import Mesh


class TroubledCellDetector(ABC):
    """Abstract class for troubled-cell detection.

    Detects troubled cells, i.e., cells in the mesh containing instabilities.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def __init__(self, config, basis, mesh):
        """Initializes TroubledCellDetector.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        """
        self._mesh = mesh
        self._basis = basis
        self._reset(config)

    def _reset(self, config):
        """Resets instance variables.

        Subclasses extend this method and consume their entries from
        `config` via `config.pop`, i.e. the passed dictionary is modified
        in place.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        pass

    def get_name(self):
        """Returns string of class name."""
        return self.__class__.__name__

    @abstractmethod
    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        """
        pass

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells."""
        return {'projection': projection,
                'basis': self._basis.create_data_dict(),
                'mesh': self._mesh.create_data_dict()
                }


class NoDetection(TroubledCellDetector):
    """Class without any troubled-cell detection.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns no troubled cells.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        return []


class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Attributes
    ----------
    stencil_len : int
        Number of cells in the input stencil of each cell.
    model : torch.nn.Module
        ANN model instance for evaluation.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Builds the ANN model named by `config['model']` from `ANN_Model` and
        loads its trained state from `config['model_state']`.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Used entries are removed
            from the dictionary (`config.pop`).

        Raises
        ------
        ValueError
            If `ANN_Model` has no model of the requested name.

        """
        super()._reset(config)

        self._stencil_len = config.pop('stencil_len', 3)
        self._add_reconstructions = config.pop('add_reconstructions', True)
        self._model = config.pop('model', 'ThreeLayerReLu')

        # Each stencil cell contributes one data point; the reconstructions
        # add two more
        num_datapoints = self._stencil_len
        if self._add_reconstructions:
            num_datapoints += 2
        self._model_config = config.pop('model_config', {
            'input_size': num_datapoints, 'first_hidden_size': 8,
            'second_hidden_size': 4, 'output_size': 2,
            'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                'models/model__Adam.pt')

        if not hasattr(ANN_Model, self._model):
            raise ValueError('Invalid model: "%s"' % self._model)
        self._model = getattr(ANN_Model, self._model)(self._model_config)

        # Load the model state and set it to evaluation mode.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # model files.
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree. The first and
            last columns are treated as ghost cells and are replaced.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Replace the single ghost cell on each side by periodic ghost cells
        # matching the stencil. np.pad(mode='wrap') also handles
        # num_ghost_cells == 0, where slicing with [-0:] would wrongly copy
        # the whole array.
        # NOTE(review): for even stencil lengths the window below contains
        # stencil_len + 1 cells — confirm only odd lengths are used.
        num_ghost_cells = self._stencil_len//2
        projection = np.pad(projection[:, 1:-1],
                            ((0, 0), (num_ghost_cells, num_ghost_cells)),
                            mode='wrap')

        # Calculate input data depending on stencil length
        input_data = torch.from_numpy(np.vstack([
            self._basis.calculate_cell_average(
                projection=projection[
                    :, cell-num_ghost_cells:cell+num_ghost_cells+1],
                stencil_length=self._stencil_len,
                add_reconstructions=self._add_reconstructions)
            for cell in range(num_ghost_cells,
                              projection.shape[1]-num_ghost_cells)]))

        # Determine troubled cells: a cell is troubled if the model assigns
        # it to class 1. No gradients are needed for inference.
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()),
                                        dim=1)
        return torch.nonzero(model_output == 1).flatten().tolist()


class WaveletDetector(TroubledCellDetector):
    """Abstract class for wavelet coefficient based troubled-cell detection.

    Computes multiwavelet coefficients from the projection and delegates the
    actual troubled-cell criterion to the subclass method `_get_cells()`.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Set wavelet projections
        self._wavelet_projection_left, self._wavelet_projection_right \
            = self._basis.multiwavelet_projection

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
        return self._get_cells(multiwavelet_coeffs, projection)

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser grid.

        Coefficient `i` combines the projection columns `i` and `i+1`, so
        `projection` needs at least `num_grid_cells + 1` columns (presumably
        including ghost cells — TODO confirm indexing, see module TODOs).

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of wavelet coefficients (one column per grid cell).

        """
        output_matrix = [
            0.5*(projection[:, i] @ self._wavelet_projection_left
                 + projection[:, i+1] @ self._wavelet_projection_right)
            for i in range(self._mesh.num_grid_cells)]
        return np.transpose(np.array(output_matrix))

    @abstractmethod
    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        pass

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        After removing the first and last column (ghost cells), each pair of
        adjacent columns (2i, 2i+1) is merged into one coarse cell.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of projection on coarse grid for each polynomial degree.

        """
        basis_projection_left, basis_projection_right\
            = self._basis.basis_projection

        # Remove ghost cells
        projection = projection[:, 1:-1]

        # Calculate projection on coarse mesh
        output_matrix = [
            0.5 * (projection[:, 2 * i] @ basis_projection_left
                   + projection[:, 2 * i + 1] @ basis_projection_right)
            for i in range(self._mesh.num_grid_cells//2)]
        return np.transpose(np.array(output_matrix))

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells."""
        # Create general directory
        data_dict = super().create_data_dict(projection)

        # Save multiwavelet-specific data in dictionary
        multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
        coarse_projection = self._calculate_coarse_projection(projection)
        data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs
        data_dict['coarse_projection'] = coarse_projection

        return data_dict


class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot.
    whisker_len : int
        Length of Boxplot whiskers.
    adjust_outer_fences : bool
        Flag whether outer fences should be adjusted using global mean.
    num_overlapping_cells : int
        Number of cells overlapping with adjacent folds.
    folds : ndarray
        Array with indices for elements of each fold (including overlaps).

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Used entries are removed
            from the dictionary (`config.pop`).

        """
        super()._reset(config)

        # Unpack necessary configurations
        self._fold_len = config.pop('fold_len', 16)
        self._whisker_len = config.pop('whisker_len', 3)
        self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
        self._num_overlapping_cells = config.pop('num_overlapping_cells', 1)

        # Use a single fold on meshes smaller than one fold. This has to
        # happen before the folds are built; otherwise no fold would exist.
        num_grid_cells = self._mesh.num_grid_cells
        self._fold_len = min(self._fold_len, num_grid_cells)

        # Row k holds the (periodically wrapped) cell indices of fold k,
        # extended by num_overlapping_cells cells on each side.
        # NOTE(review): if num_grid_cells is not a multiple of fold_len, the
        # trailing cells belong to no fold and are only checked if they
        # appear as overlap cells.
        num_folds = num_grid_cells//self._fold_len
        fold_starts = np.arange(num_folds)[:, np.newaxis] * self._fold_len
        fold_offsets = np.arange(
            -self._num_overlapping_cells,
            self._fold_len + self._num_overlapping_cells)
        self._folds = ((fold_starts + fold_offsets)
                       % num_grid_cells).astype(int)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        A cell is troubled if its lowest-degree multiwavelet coefficient lies
        outside the Boxplot fences of any fold containing it.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            Sorted list of unique indices for all detected troubled cells.

        """
        coeffs = multiwavelet_coeffs[0]

        # One row of coefficients per fold (including overlap cells)
        fold_coeffs = coeffs[self._folds]
        sorted_folds = np.sort(fold_coeffs, axis=1)

        # Interpolated quartiles of each fold.
        # NOTE(review): the positions are derived from fold_len and ignore
        # the 2*num_overlapping_cells overlap cells in each fold, and the
        # third quartile reuses the first quartile's balance factor —
        # confirm this convention is intended.
        boundary_index = self._fold_len//4
        balance_factor = self._fold_len/4.0 - boundary_index
        first_quartile = (1-balance_factor) \
            * sorted_folds[:, boundary_index-1] \
            + balance_factor * sorted_folds[:, boundary_index]
        third_quartile = (1-balance_factor) \
            * sorted_folds[:, 3*boundary_index-1] \
            + balance_factor * sorted_folds[:, 3*boundary_index]

        lower_bound = first_quartile \
            - self._whisker_len * (third_quartile-first_quartile)
        upper_bound = third_quartile \
            + self._whisker_len * (third_quartile-first_quartile)

        # Adjust outer fences if flag is set (global mean is the same for
        # every fold, so it is computed once)
        if self._adjust_outer_fences:
            global_mean = np.mean(abs(coeffs))
            lower_bound = np.minimum(-global_mean, lower_bound)
            upper_bound = np.maximum(global_mean, upper_bound)

        # Check for extreme outliers; overlapping folds may flag the same
        # cell twice, so only unique indices are returned
        is_outlier = (fold_coeffs > upper_bound[:, np.newaxis]) \
            | (fold_coeffs < lower_bound[:, np.newaxis])
        return [int(cell) for cell in np.unique(self._folds[is_outlier])]


class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on the projection averages and
    a cutoff factor.

    Attributes
    ----------
    cutoff_factor : float
        Cutoff factor above which a cell is considered troubled.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Used entries are removed
            from the dictionary (`config.pop`).

        """
        super()._reset(config)

        # Unpack necessary configurations.
        # Original note on the default: alternatively 2 or 3.
        self._cutoff_factor = config.pop('cutoff_factor',
                                         np.sqrt(2) * self._mesh.cell_len)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Normalization: largest absolute zeroth-degree projection value over
        # the inner cells (skipping the first column), but at least 1
        max_avg = np.sqrt(0.5) \
            * max(1, max(abs(projection[0][cell+1])
                         for cell in range(self._mesh.num_grid_cells)))

        return [cell for cell in range(self._mesh.num_grid_cells)
                if self._is_troubled_cell(multiwavelet_coeffs, cell,
                                          max_avg)]

    def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
        """Checks whether a cell is troubled.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        cell : int
            Index of cell.
        max_avg : float
            Maximum average of projection.

        Returns
        -------
        bool
            Flag whether cell is troubled.

        """
        # Largest absolute coefficient over all degrees, normalized
        max_value = max(abs(multiwavelet_coeffs[degree][cell])
                        for degree in range(
                            self._basis.polynomial_degree+1))/max_avg
        eps = self._cutoff_factor\
            / (self._mesh.cell_len*self._mesh.num_grid_cells)

        return max_value > eps