Select Git revision
Troubled_Cell_Detector.py
Laura Christine Kühle authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Troubled_Cell_Detector.py 17.26 KiB
# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
TODO: Vectorize _get_cells() in Boxplot method
TODO: Introduce lower/upper extreme outliers in Boxplot
(each cell is also checked for neighboring domains if existing)
TODO: Determine max_value for Theoretical only over highest degree
TODO: Check if indexing in wavelets is correct
TODO: Add ThresholdDetector
TODO: Add TC condition to only flag cell if left-adjacent one is flagged as
well (remove this condition)
TODO: Check coarse_projection calculation for indexing errors
TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection
"""
from abc import ABC, abstractmethod
import numpy as np
import torch
import ANN_Model
from projection_utils import Mesh
class TroubledCellDetector(ABC):
    """Abstract base class for troubled-cell detectors.

    A troubled cell is a cell of the mesh containing instabilities.
    Concrete detectors implement `get_cells()`.

    Methods
    -------
    get_name()
        Returns string of class name.
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells.

    """
    def __init__(self, config, basis, mesh):
        """Stores mesh and basis and applies the detector configuration.

        Parameters
        ----------
        config : dict
            Additional parameters for detector. Subclasses read (and pop)
            their entries from it in `_reset()`.
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        """
        self._mesh, self._basis = mesh, basis
        self._reset(config)

    def _reset(self, config):
        """Hook for (re)setting instance variables from the configuration.

        The base implementation does nothing; subclasses extend it.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """

    def get_name(self):
        """Returns string of class name."""
        return type(self).__name__

    @abstractmethod
    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        """

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Contains the projection itself and the data dictionaries of the
        basis and the mesh.
        """
        return dict(projection=projection,
                    basis=self._basis.create_data_dict(),
                    mesh=self._mesh.create_data_dict())
class NoDetection(TroubledCellDetector):
    """Detector that never flags any cell as troubled.

    Methods
    -------
    get_cells(projection)
        Returns no troubled cells.

    """
    def get_cells(self, projection):
        """Returns an empty list regardless of the projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree (ignored).

        Returns
        -------
        list
            List of indices for all detected troubled cells (always empty).

        """
        return list()
class ArtificialNeuralNetwork(TroubledCellDetector):
    """Class for troubled-cell detection using ANNs.

    Attributes
    ----------
    stencil_len : int
        Number of cells in the stencil evaluated for each cell.
    add_reconstructions : bool
        Flag whether the model input contains two additional entries
        (input size is stencil_len + 2 instead of stencil_len).
    model_config : dict
        Configuration passed to the model constructor.
    model : torch.nn.Module
        ANN model instance for evaluation.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.

    """
    def _reset(self, config):
        """Resets instance variables and loads the trained model.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        Raises
        ------
        ValueError
            If the requested model is not defined in `ANN_Model`.

        """
        super()._reset(config)

        self._stencil_len = config.pop('stencil_len', 3)
        self._add_reconstructions = config.pop('add_reconstructions', True)
        model_name = config.pop('model', 'ThreeLayerReLu')

        # One input per stencil cell, plus two if reconstructions are added
        num_datapoints = self._stencil_len
        if self._add_reconstructions:
            num_datapoints += 2

        self._model_config = config.pop('model_config', {
            'input_size': num_datapoints, 'first_hidden_size': 8,
            'second_hidden_size': 4, 'output_size': 2,
            'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
        model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                'models/model__Adam.pt')

        if not hasattr(ANN_Model, model_name):
            raise ValueError('Invalid model: "%s"' % model_name)
        self._model = getattr(ANN_Model, model_name)(self._model_config)

        # Load the model state and set it to evaluation mode.
        # NOTE(review): torch.load unpickles the file, so only load model
        # states from trusted sources.
        self._model.load_state_dict(torch.load(str(model_state)))
        self._model.eval()

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        The first and last column of the projection are dropped and the
        remaining cells are padded with `stencil_len // 2` cells copied
        from the opposite end (wrap-around). Each cell's stencil is then
        turned into one row of model input, and cells for which the model
        predicts class 1 are reported as troubled.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        # Reset ghost cells to adjust for stencil length
        num_ghost_cells = self._stencil_len//2
        projection = projection[:, 1:-1]
        num_cells = projection.shape[1]
        # Use an explicit positive start index: with zero ghost cells a
        # slice like [-0:] would select the whole array instead of nothing.
        projection = np.concatenate(
            (projection[:, num_cells-num_ghost_cells:],
             projection,
             projection[:, :num_ghost_cells]), axis=1)

        # Calculate input data depending on stencil length
        input_data = torch.from_numpy(np.vstack([
            self._basis.calculate_cell_average(
                projection=projection[
                    :, cell-num_ghost_cells:cell+num_ghost_cells+1],
                stencil_length=self._stencil_len,
                add_reconstructions=self._add_reconstructions)
            for cell in range(num_ghost_cells, num_ghost_cells+num_cells)]))

        # Determine troubled cells; pure inference, so skip autograd tracking
        with torch.no_grad():
            model_output = torch.argmax(self._model(input_data.float()),
                                        dim=1)
        return [cell for cell, label in enumerate(model_output.tolist())
                if label == 1]
class WaveletDetector(TroubledCellDetector):
    """Abstract class for wavelet coefficient based troubled-cell detection.

    Multiwavelet coefficients are computed from the projection using the
    basis' multiwavelet projection matrices; subclasses decide in
    `_get_cells()` which cells are troubled based on these coefficients.

    Methods
    -------
    get_cells(projection)
        Calculates troubled cells in a given projection.
    create_data_dict(projection)
        Return dictionary with data necessary to plot troubled cells,
        including multiwavelet coefficients and coarse projection.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # The basis provides a (left, right) pair of projection matrices
        (self._wavelet_projection_left,
         self._wavelet_projection_right) = self._basis.multiwavelet_projection

    def get_cells(self, projection):
        """Calculates troubled cells in a given projection.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        return self._get_cells(self._calculate_wavelet_coeffs(projection),
                               projection)

    def _calculate_wavelet_coeffs(self, projection):
        """Calculates wavelet coefficients used for projection to coarser grid.

        Column i of the result combines columns i and i+1 of the
        projection (one column per grid cell).

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of wavelet coefficients.

        """
        left = self._wavelet_projection_left
        right = self._wavelet_projection_right
        coeffs_per_cell = [
            0.5*(projection[:, idx] @ left + projection[:, idx+1] @ right)
            for idx in range(self._mesh.num_grid_cells)]
        return np.array(coeffs_per_cell).T

    @abstractmethod
    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """

    def _calculate_coarse_projection(self, projection):
        """Calculates coarse projection.

        After dropping the first and last column (ghost cells), each pair
        of neighboring columns (2i, 2i+1) is merged into one coarse cell.

        Parameters
        ----------
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        ndarray
            Matrix of projection on coarse grid for each polynomial degree.

        """
        left, right = self._basis.basis_projection

        inner_projection = projection[:, 1:-1]
        coarse_cells = [
            0.5 * (inner_projection[:, 2*idx] @ left
                   + inner_projection[:, 2*idx+1] @ right)
            for idx in range(self._mesh.num_grid_cells//2)]
        return np.array(coarse_cells).T

    def create_data_dict(self, projection):
        """Return dictionary with data necessary to plot troubled cells.

        Extends the general dictionary by the multiwavelet coefficients and
        the coarse projection.
        """
        data_dict = super().create_data_dict(projection)
        coeffs = self._calculate_wavelet_coeffs(projection)
        coarse = self._calculate_coarse_projection(projection)
        data_dict.update({'multiwavelet_coeffs': coeffs,
                          'coarse_projection': coarse})
        return data_dict
class Boxplot(WaveletDetector):
    """Class for troubled-cell detection based on Boxplots.

    The first row of the multiwavelet coefficients is split into folds of
    consecutive cells (indices wrap around at the mesh boundary). For each
    fold a Boxplot is built, and cells whose coefficient lies outside the
    fences are flagged as troubled.

    Attributes
    ----------
    fold_len : int
        Length of folds considered in one Boxplot (excluding overlaps).
        Reduced to the number of grid cells if the mesh is smaller.
    whisker_len : int
        Length of Boxplot whiskers, as a multiple of the interquartile
        range.
    adjust_outer_fences : bool
        Flag whether outer fences should be adjusted using global mean.
    num_overlapping_cells : int
        Number of cells overlapping with adjacent folds.
    folds : ndarray
        Array with indices for elements of each fold (including
        overlaps).

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations
        self._fold_len = config.pop('fold_len', 16)
        self._whisker_len = config.pop('whisker_len', 3)
        self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
        self._num_overlapping_cells = config.pop('num_overlapping_cells', 1)

        # A fold cannot be longer than the mesh. This has to happen before
        # the folds are built, otherwise small meshes would get no fold.
        num_grid_cells = self._mesh.num_grid_cells
        if 0 < num_grid_cells < self._fold_len:
            self._fold_len = num_grid_cells

        # Row k holds the indices k*fold_len - overlap, ...,
        # (k+1)*fold_len + overlap - 1, wrapped around the mesh.
        # NOTE(review): if num_grid_cells is not a multiple of fold_len,
        # the trailing cells are only covered through overlaps (if at all).
        num_folds = num_grid_cells//self._fold_len
        offsets = np.arange(-self._num_overlapping_cells,
                            self._fold_len + self._num_overlapping_cells)
        fold_starts = np.arange(num_folds) * self._fold_len
        self._folds = ((fold_starts[:, np.newaxis] + offsets)
                       % max(num_grid_cells, 1)).astype(int)

    def _calculate_fences(self, fold_coeffs):
        """Calculates lower and upper Boxplot fences for one fold.

        Parameters
        ----------
        fold_coeffs : ndarray
            Coefficients of all cells in the fold (including overlaps).

        Returns
        -------
        tuple
            Lower and upper bound; values outside are outliers.

        """
        sorted_fold = sorted(fold_coeffs)

        # Quartiles are interpolated between neighboring sorted values.
        # NOTE(review): the positions are derived from fold_len, not from
        # the fold length including overlaps, and the formula needs
        # fold_len >= 4 (otherwise boundary_index-1 wraps to the end).
        boundary_index = self._fold_len//4
        balance_factor = self._fold_len/4.0 - boundary_index
        first_quartile = (1-balance_factor) \
            * sorted_fold[boundary_index-1] \
            + balance_factor * sorted_fold[boundary_index]
        third_quartile = (1-balance_factor) \
            * sorted_fold[3*boundary_index-1] \
            + balance_factor * sorted_fold[3*boundary_index]

        interquartile_range = third_quartile - first_quartile
        lower_bound = first_quartile - self._whisker_len * interquartile_range
        upper_bound = third_quartile + self._whisker_len * interquartile_range
        return lower_bound, upper_bound

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            Sorted list of unique indices for all detected troubled cells.

        """
        coeffs = multiwavelet_coeffs[0]
        if len(self._folds) == 0:
            return []

        # Loop-invariant: mean absolute coefficient over the whole mesh
        global_mean = np.mean(abs(coeffs)) \
            if self._adjust_outer_fences else None

        # A set, because overlapping folds can flag the same cell twice
        troubled_cells = set()
        for fold in self._folds:
            lower_bound, upper_bound = self._calculate_fences(coeffs[fold])

            # Adjust outer fences if flag is set
            if self._adjust_outer_fences:
                lower_bound = min(-global_mean, lower_bound)
                upper_bound = max(global_mean, upper_bound)

            # Check for extreme outlier and add respective cells
            troubled_cells.update(
                int(cell) for cell in fold
                if coeffs[cell] > upper_bound or coeffs[cell] < lower_bound)

        return sorted(troubled_cells)
class Theoretical(WaveletDetector):
    """Class for troubled-cell detection based on the projection averages and
    a cutoff factor.

    A cell is flagged if its largest absolute multiwavelet coefficient,
    normalized by the (scaled) largest cell average, exceeds the cutoff
    factor divided by cell_len * num_grid_cells.

    Attributes
    ----------
    cutoff_factor : float
        Cutoff factor above which a cell is considered troubled.

    """
    def _reset(self, config):
        """Resets instance variables.

        Parameters
        ----------
        config : dict
            Additional parameters for detector.

        """
        super()._reset(config)

        # Unpack necessary configurations.
        # Original note: alternatively "2 or 3" (presumably in place of
        # np.sqrt(2)) -- confirm with the author.
        default_cutoff = np.sqrt(2) * self._mesh.cell_len
        self._cutoff_factor = config.pop('cutoff_factor', default_cutoff)

    def _get_cells(self, multiwavelet_coeffs, projection):
        """Calculates troubled cells using multiwavelet coefficients.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        projection : ndarray
            Matrix of projection for each polynomial degree.

        Returns
        -------
        list
            List of indices for all detected troubled cells.

        """
        num_cells = self._mesh.num_grid_cells

        # Largest absolute value in columns 1..num_cells of the first row,
        # but at least 1, scaled by sqrt(1/2)
        largest_average = max(abs(projection[0][column])
                              for column in range(1, num_cells+1))
        max_avg = np.sqrt(0.5) * max(1, largest_average)

        return [idx for idx in range(num_cells)
                if self._is_troubled_cell(multiwavelet_coeffs, idx, max_avg)]

    def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
        """Checks whether a cell is troubled.

        Parameters
        ----------
        multiwavelet_coeffs : ndarray
            Matrix of multiwavelet coefficients.
        cell : int
            Index of cell.
        max_avg : float
            Maximum average of projection.

        Returns
        -------
        bool
            Flag whether cell is troubled.

        """
        degrees = range(self._basis.polynomial_degree+1)
        largest_coeff = max(abs(multiwavelet_coeffs[degree][cell])
                            for degree in degrees)
        max_value = largest_coeff / max_avg

        total_len = self._mesh.cell_len*self._mesh.num_grid_cells
        eps = self._cutoff_factor / total_len
        return max_value > eps