Skip to content
Snippets Groups Projects
Select Git revision
  • b325651deb1c4ddfd93ecabcc34985125062f4c2
  • master default protected
2 results

Troubled_Cell_Detector.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Troubled_Cell_Detector.py 17.26 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Vectorize _get_cells() in Boxplot method
    TODO: Introduce lower/upper extreme outliers in Boxplot
        (each cell is also checked for neighboring domains if existing)
    TODO: Determine max_value for Theoretical only over highest degree
    TODO: Check if indexing in wavelets is correct
    TODO: Add ThresholdDetector
    TODO: Add TC condition to only flag cell if left-adjacent one is flagged as
        well (remove this condition)
    TODO: Check coarse_projection calculation for indexing errors
    TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
    TODO: Give detailed description of wavelet detection
    
    """
    from abc import ABC, abstractmethod
    import numpy as np
    import torch
    
    import ANN_Model
    from projection_utils import Mesh
    
    
    class TroubledCellDetector(ABC):
        """Abstract base class for troubled-cell detectors.

        A troubled cell is a cell of the mesh whose solution contains
        instabilities. Concrete detectors implement `get_cells()`; this base
        class only stores the basis and the mesh and exports plotting data.

        Methods
        -------
        get_name()
            Returns string of class name.
        get_cells(projection)
            Calculates troubled cells in a given projection.
        create_data_dict(projection)
            Return dictionary with data necessary to plot troubled cells.

        """
        def __init__(self, config, basis, mesh):
            """Stores basis and mesh, then hands `config` to `_reset()`.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.
            basis : Basis object
                Basis for calculation.
            mesh : Mesh
                Mesh for calculation.

            """
            self._basis = basis
            self._mesh = mesh
            self._reset(config)

        def _reset(self, config):
            """Hook for (re-)initializing detector-specific state.

            The base implementation ignores `config`; subclasses override it
            to read their own settings.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            return None

        def get_name(self):
            """Returns string of class name."""
            cls = self.__class__
            return cls.__name__

        @abstractmethod
        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            """
            ...

        def create_data_dict(self, projection):
            """Return dictionary with data necessary to plot troubled cells.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            dict
                Dictionary with keys 'projection', 'basis' and 'mesh'; the
                latter two hold the data dictionaries of basis and mesh.

            """
            basis_data = self._basis.create_data_dict()
            mesh_data = self._mesh.create_data_dict()
            return dict(projection=projection, basis=basis_data, mesh=mesh_data)
    
    
    class NoDetection(TroubledCellDetector):
        """Detector that never flags a cell.

        Methods
        -------
        get_cells(projection)
            Returns no troubled cells.

        """
        def get_cells(self, projection):
            """Returns an empty list regardless of the given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree (unused).

            Returns
            -------
            list
                Empty list, since no cell is ever considered troubled.

            """
            no_cells = list()
            return no_cells
    
    
    class ArtificialNeuralNetwork(TroubledCellDetector):
        """Class for troubled-cell detection using ANNs.

        For every cell, the basis computes input data from the stencil of
        `stencil_len` cells centered on it (with periodic ghost cells). The
        data is evaluated by a trained ANN, and cells whose predicted class
        (argmax of the model output) is 1 are flagged as troubled.

        Attributes
        ----------
        stencil_len : int
            Number of cells in the stencil evaluated for each cell.
        add_reconstructions : bool
            Flag whether reconstructions are added to the input data.
        model : torch.nn.Module
            ANN model instance for evaluation.
        model_config : dict
            Configuration used to instantiate the model.

        Methods
        -------
        get_cells(projection)
            Calculates troubled cells in a given projection.

        """
        def _reset(self, config):
            """Resets instance variables and loads the trained model.

            Parameters
            ----------
            config : dict
                Additional parameters for detector. The keys 'stencil_len',
                'add_reconstructions', 'model', 'model_config' and
                'model_state' are popped from it.

            Raises
            ------
            ValueError
                If 'model' does not name an attribute of `ANN_Model`.

            """
            super()._reset(config)

            self._stencil_len = config.pop('stencil_len', 3)
            self._add_reconstructions = config.pop('add_reconstructions', True)
            model_name = config.pop('model', 'ThreeLayerReLu')

            # Default input size: one value per stencil cell plus two more if
            # reconstructions are added
            num_datapoints = self._stencil_len
            if self._add_reconstructions:
                num_datapoints += 2
            self._model_config = config.pop('model_config', {
                'input_size': num_datapoints, 'first_hidden_size': 8,
                'second_hidden_size': 4, 'output_size': 2,
                'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
            model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                    'models/model__Adam.pt')

            if not hasattr(ANN_Model, model_name):
                raise ValueError('Invalid model: "%s"' % model_name)
            self._model = getattr(ANN_Model, model_name)(self._model_config)

            # Load the model state and set it to evaluation mode.
            # map_location='cpu' allows loading states saved on a GPU on
            # CPU-only machines; load_state_dict() copies the tensors into the
            # model's own parameters either way.
            # NOTE(review): torch.load unpickles the file -- only load model
            # states from trusted sources.
            self._model.load_state_dict(
                torch.load(str(model_state), map_location='cpu'))
            self._model.eval()

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, with one ghost
                cell at each end.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Replace the ghost cells with periodic ones matching the stencil
            num_ghost_cells = self._stencil_len//2
            projection = projection[:, 1:-1]
            num_cells = projection.shape[1]
            # Use an explicit start index: for num_ghost_cells == 0 the slice
            # [-0:] would select all cells instead of none
            left_ghost_cells = projection[:, max(num_cells-num_ghost_cells, 0):]
            right_ghost_cells = projection[:, :num_ghost_cells]
            projection = np.concatenate(
                (left_ghost_cells, projection, right_ghost_cells), axis=1)

            # Calculate input data depending on stencil length
            input_data = torch.from_numpy(np.vstack([
                self._basis.calculate_cell_average(
                    projection=projection[
                               :, cell-num_ghost_cells:cell+num_ghost_cells+1],
                    stencil_length=self._stencil_len,
                    add_reconstructions=self._add_reconstructions)
                for cell in range(num_ghost_cells,
                                  projection.shape[1]-num_ghost_cells)]))

            # Determine troubled cells; inference needs no gradient tracking
            with torch.no_grad():
                model_output = torch.argmax(self._model(input_data.float()),
                                            dim=1)
            return torch.nonzero(model_output == 1, as_tuple=True)[0].tolist()
    
    
    class WaveletDetector(TroubledCellDetector):
        """Abstract class for wavelet coefficient based troubled-cell detection.

        `get_cells()` computes multiwavelet coefficients of the projection
        using the basis' left/right multiwavelet projection matrices. It then
        passes them, together with the projection, to `_get_cells()`, which
        subclasses implement to decide which cells are troubled.
        `create_data_dict()` additionally exports the multiwavelet coefficients
        and the projection on a mesh with half as many cells.

        """

        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Set wavelet projections
            self._wavelet_projection_left, self._wavelet_projection_right \
                = self._basis.multiwavelet_projection

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
            return self._get_cells(multiwavelet_coeffs, projection)

        def _calculate_wavelet_coeffs(self, projection):
            """Calculates wavelet coefficients used for projection to coarser grid.

            The coefficient vector of cell ``i`` is
            ``0.5*(projection[:, i] @ P_left + projection[:, i+1] @ P_right)``,
            i.e. it combines projection columns ``i`` and ``i+1``.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree; needs at least
                ``num_grid_cells + 1`` columns.

            Returns
            -------
            ndarray
                Matrix of wavelet coefficients with one column per grid cell.

            """
            num_cells = self._mesh.num_grid_cells
            # Evaluate all cells at once: row i of each product equals the
            # per-cell row-vector product projection[:, i] @ matrix.
            # NOTE(review): assumes the projection matrices are 2D ndarrays.
            left_part = projection[:, :num_cells].T \
                @ self._wavelet_projection_left
            right_part = projection[:, 1:num_cells+1].T \
                @ self._wavelet_projection_right
            return np.transpose(0.5*(left_part + right_part))

        @abstractmethod
        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            pass

        def _calculate_coarse_projection(self, projection):
            """Calculates coarse projection.

            Coarse cell ``i`` combines the fine cells ``2i`` (left) and
            ``2i+1`` (right); with an odd number of cells the last fine cell is
            not used.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, with one ghost
                cell at each end.

            Returns
            -------
            ndarray
                Matrix of projection on coarse grid for each polynomial degree.

            """
            basis_projection_left, basis_projection_right\
                = self._basis.basis_projection

            # Remove ghost cells
            projection = projection[:, 1:-1]

            # Calculate projection on coarse mesh for all coarse cells at once
            num_coarse_cells = self._mesh.num_grid_cells//2
            left_cells = projection[:, 0:2*num_coarse_cells:2]
            right_cells = projection[:, 1:2*num_coarse_cells:2]
            return np.transpose(0.5*(left_cells.T @ basis_projection_left
                                     + right_cells.T @ basis_projection_right))

        def create_data_dict(self, projection):
            """Return dictionary with data necessary to plot troubled cells.

            Extends the general dictionary by the keys 'multiwavelet_coeffs'
            and 'coarse_projection'.

            """
            # Create general directory
            data_dict = super().create_data_dict(projection)

            # Save multiwavelet-specific data in dictionary
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection)
            coarse_projection = self._calculate_coarse_projection(projection)
            data_dict['multiwavelet_coeffs'] = multiwavelet_coeffs
            data_dict['coarse_projection'] = coarse_projection

            return data_dict
    
    
    class Boxplot(WaveletDetector):
        """Class for troubled-cell detection based on Boxplots.

        The cells are grouped into folds of `fold_len` consecutive cells, each
        extended periodically by `num_overlapping_cells` cells on both sides.
        For every fold, a boxplot of the first row of the multiwavelet
        coefficients is evaluated. Cells whose coefficient lies outside
        ``[Q1 - whisker_len*IQR, Q3 + whisker_len*IQR]`` are flagged. If
        `adjust_outer_fences` is set, the fences are widened to at least
        ``[-mean(|coeffs|), mean(|coeffs|)]``.

        Attributes
        ----------
        fold_len : int
            Length of folds considered in one Boxplot.
        whisker_len : int
            Length of Boxplot whiskers (multiple of the interquartile range).
        adjust_outer_fences : bool
            Flag whether outer fences should be adjusted using global mean.
        num_overlapping_cells : int
            Number of cells overlapping with adjacent folds.
        folds : ndarray
            Array with indices for elements of each fold (including
            overlaps).

        """
        def _reset(self, config):
            """Resets instance variables and builds the folds.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Unpack necessary configurations
            self._fold_len = config.pop('fold_len', 16)
            self._whisker_len = config.pop('whisker_len', 3)
            self._adjust_outer_fences = config.pop('adjust_outer_fences', True)
            self._num_overlapping_cells = config.pop('num_overlapping_cells', 1)

            # Use one fold covering the whole mesh if the mesh is smaller than
            # a fold; this must happen before the folds are built, otherwise
            # there would be no fold to evaluate
            num_cells = self._mesh.num_grid_cells
            if 0 < num_cells < self._fold_len:
                self._fold_len = num_cells

            # Each fold holds fold_len cells plus num_overlapping_cells cells
            # of each neighboring fold (indices wrap around periodically)
            num_folds = num_cells//self._fold_len
            self._folds = np.zeros([num_folds, self._fold_len
                                    + 2 * self._num_overlapping_cells]).astype(int)
            for fold in range(num_folds):
                self._folds[fold] = np.array(
                    [i % num_cells for i in range(
                        fold * self._fold_len - self._num_overlapping_cells,
                        (fold+1) * self._fold_len + self._num_overlapping_cells)])

        def _interpolate_quartile(self, sorted_fold, quartile):
            """Returns a quartile of a sorted fold by linear interpolation.

            The quartile lies at the 1-based position
            ``quartile * fold_len / 4`` and is interpolated between the two
            neighboring entries. Positions below 1 are clamped to the first
            entry instead of wrapping around to the last one.

            Parameters
            ----------
            sorted_fold : sequence
                Coefficients of one fold in ascending order.
            quartile : int
                Number of the quartile (1 or 3).

            Returns
            -------
            float
                Interpolated quartile.

            """
            position = quartile * self._fold_len / 4.0
            index = int(position)
            balance_factor = position - index
            return (1-balance_factor) * sorted_fold[max(index-1, 0)] \
                + balance_factor * sorted_fold[index]

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree (unused).

            Returns
            -------
            list
                Sorted list of unique indices for all detected troubled cells.

            """
            coeffs = multiwavelet_coeffs[0]

            # The global mean is identical for all folds
            global_mean = np.mean(abs(coeffs)) \
                if self._adjust_outer_fences else None

            # A set, since overlap cells belong to two folds and may be
            # flagged by both
            troubled_cells = set()
            # NOTE(review): if num_grid_cells is not a multiple of fold_len,
            # trailing cells are only checked where they fall into an overlap
            for fold_cells in self._folds:
                fold_coeffs = coeffs[fold_cells]
                # NOTE(review): quartile positions are based on fold_len,
                # although the fold also contains the overlap cells
                sorted_fold = sorted(fold_coeffs)
                first_quartile = self._interpolate_quartile(sorted_fold, 1)
                third_quartile = self._interpolate_quartile(sorted_fold, 3)

                lower_bound = first_quartile \
                    - self._whisker_len * (third_quartile-first_quartile)
                upper_bound = third_quartile \
                    + self._whisker_len * (third_quartile-first_quartile)

                # Adjust outer fences if flag is set
                if self._adjust_outer_fences:
                    lower_bound = min(-global_mean, lower_bound)
                    upper_bound = max(global_mean, upper_bound)

                # Add cells whose coefficient is an extreme outlier
                is_outlier = (fold_coeffs > upper_bound) \
                    | (fold_coeffs < lower_bound)
                troubled_cells.update(fold_cells[is_outlier].tolist())

            return sorted(troubled_cells)
    
    
    class Theoretical(WaveletDetector):
        """Class for troubled-cell detection based on the projection averages and
        a cutoff factor.

        A cell is flagged if its largest absolute multiwavelet coefficient
        (over all polynomial degrees) exceeds a threshold. The coefficient is
        divided by ``sqrt(0.5) * max(1, max |degree-0 projection coefficient|)``.
        The threshold is the cutoff factor divided by
        ``cell_len * num_grid_cells``.

        Attributes
        ----------
        cutoff_factor : float
            Cutoff factor above which a cell is considered troubled.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Unpack necessary configurations
            self._cutoff_factor = config.pop('cutoff_factor',
                                             np.sqrt(2) * self._mesh.cell_len)
            # Alternatives for the default cutoff factor: 2 or 3

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree, with one ghost
                cell in front of the mesh cells.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            num_cells = self._mesh.num_grid_cells
            # An empty mesh has no troubled cells (and no maximum average)
            if num_cells == 0:
                return []

            # Scaled maximum of the absolute degree-0 coefficients of the mesh
            # cells (column 0 is skipped, hence the +1 offset), at least 1
            max_avg = np.sqrt(0.5) \
                * max(1, max(abs(projection[0][cell+1])
                             for cell in range(num_cells)))
            # The threshold does not depend on the cell; compute it only once
            eps = self._calculate_threshold()

            return [cell for cell in range(num_cells)
                    if self._is_troubled_cell(multiwavelet_coeffs, cell,
                                              max_avg, eps)]

        def _calculate_threshold(self):
            """Returns the threshold the normalized coefficients must exceed.

            Returns
            -------
            float
                Cutoff factor divided by ``cell_len * num_grid_cells``.

            """
            return self._cutoff_factor \
                / (self._mesh.cell_len*self._mesh.num_grid_cells)

        def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg,
                              eps=None):
            """Checks whether a cell is troubled.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            cell : int
                Index of cell.
            max_avg : float
                Maximum average of projection.
            eps : float, optional
                Precomputed threshold; calculated from the cutoff factor if
                not given.

            Returns
            -------
            bool
                Flag whether cell is troubled.

            """
            max_value = max(abs(multiwavelet_coeffs[degree][cell])
                            for degree in range(
                self._basis.polynomial_degree+1))/max_avg
            if eps is None:
                eps = self._calculate_threshold()

            return max_value > eps