Skip to content
Snippets Groups Projects
Select Git revision
  • 978b159457a8f031995406fb0b038024529426ba
  • master default protected
2 results

Troubled_Cell_Detector.py

  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Troubled_Cell_Detector.py 21.79 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead
        of every second)
    TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
    TODO: Give detailed description of wavelet detection
    
    """
    import numpy as np
    import matplotlib
    from matplotlib import pyplot as plt
    import seaborn as sns
    import torch
    from sympy import Symbol
    
    import ANN_Model
    from Plotting import plot_solution_and_approx, plot_semilog_error, \
        plot_error, plot_shock_tube, plot_details, \
        calculate_approximate_solution, calculate_exact_solution
    from projection_utils import calculate_cell_average
    
    # Select matplotlib's non-interactive 'Agg' backend for all plotting
    matplotlib.use('Agg')
    # Module-level SymPy symbols
    x, z = Symbol('x'), Symbol('z')
    
    
    class TroubledCellDetector:
        """Base class for troubled-cell detection.

        A troubled cell is a cell in the mesh whose approximation contains
        instabilities. This base class stores the discretization parameters
        and provides the shared plotting routines; subclasses override
        ``get_cells`` (and, where needed, ``_reset`` and ``_check_colors``)
        to implement a concrete detector.

        Attributes
        ----------
        interval_len : float
            Length of the interval between left and right boundary.
        cell_len : float
            Length of a cell in mesh.

        Methods
        -------
        get_name()
            Returns string of class name.
        get_cells(projection)
            Calculates troubled cells in a given projection.
        plot_results(projection, troubled_cell_history, time_history)
            Plots results and troubled cells of a projection.

        """
        def __init__(self, config, init_cond, quadrature, basis, mesh,
                     wave_speed=1, polynomial_degree=2, num_grid_cells=64,
                     final_time=1, left_bound=-1, right_bound=1):
            """Initializes TroubledCellDetector.

            Parameters
            ----------
            config : dict
                Additional parameters for detector. Keys read by the detector
                are popped from this dict.
            init_cond : InitialCondition object
                Initial condition for evaluation.
            quadrature : Quadrature object
                Quadrature for evaluation.
            basis : Basis object
                Basis for calculation.
            mesh : ndarray
                List of mesh valuation points.
            wave_speed : float, optional
                Speed of wave in rightward direction. Default: 1.
            polynomial_degree : int, optional
                Polynomial degree. Default: 2.
            num_grid_cells : int, optional
                Number of cells in the mesh. Usually exponential of 2. Default: 64.
            final_time : float, optional
                Final time for which approximation is calculated. Default: 1.
            left_bound : float, optional
                Left boundary of interval. Default: -1.
            right_bound : float, optional
                Right boundary of interval. Default: 1.

            """
            # Geometry of the domain and its discretization
            self._left_bound = left_bound
            self._right_bound = right_bound
            self._interval_len = right_bound - left_bound
            self._num_grid_cells = num_grid_cells
            self._cell_len = self._interval_len / num_grid_cells
            self._mesh = mesh

            # Problem and solver ingredients
            self._wave_speed = wave_speed
            self._final_time = final_time
            self._polynomial_degree = polynomial_degree
            self._init_cond = init_cond
            self._quadrature = quadrature
            self._basis = basis

            # Consume the plot colors; the remaining config keys go to _reset
            self._colors = config.pop('colors', {})
            self._check_colors()
            self._reset(config)

        def _check_colors(self):
            """Fills in a default for every plot color that was not given."""
            self._colors.setdefault('exact', 'k-')
            self._colors.setdefault('approx', 'y')

        def _reset(self, config):
            """Resets instance variables.

            The base implementation only applies the seaborn plot theme; the
            ``config`` argument is consumed by subclass overrides.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            sns.set()

        def get_name(self):
            """Returns string of class name."""
            return type(self).__name__

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            The base class performs no detection and returns None; concrete
            detectors override this method.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            """
            return None

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plots results and troubled cells of a projection.

            Draws the troubled-cell history, then the solution and error plots,
            and finally prints the polynomial degree, the number of cells and
            the maximum error.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.
            troubled_cell_history : list
                List of detected troubled cells for each time step.
            time_history : list
                List of value of each time step.

            """
            plot_shock_tube(self._num_grid_cells, troubled_cell_history,
                            time_history)
            max_error = self._plot_mesh(projection)

            summary = (('p =', self._polynomial_degree),
                       ('N =', self._num_grid_cells),
                       ('maximum error =', max_error))
            for label, value in summary:
                print(label, value)

        def _plot_mesh(self, projection):
            """Plots exact and approximate solution as well as errors.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.

            Returns
            -------
            max_error : float
                Maximum error between exact and approximate solution.

            """
            interior_mesh = self._mesh[2:-2]
            grid, exact = calculate_exact_solution(
                interior_mesh, self._cell_len, self._wave_speed,
                self._final_time, self._interval_len, self._quadrature,
                self._init_cond)

            # Evaluate the approximation on the interior cells only
            eval_points = self._quadrature.get_eval_points()
            basis_vector = self._basis.get_basis_vector()
            interior_projection = projection[:, 1:-1]
            approx = calculate_approximate_solution(
                interior_projection, eval_points, self._polynomial_degree,
                basis_vector)

            pointwise_error = np.abs(exact - approx)
            max_error = np.max(pointwise_error)

            exact_color = self._colors['exact']
            approx_color = self._colors['approx']
            plot_solution_and_approx(grid, exact, approx, exact_color,
                                     approx_color)
            plt.legend(['Exact', 'Approx'])
            plot_semilog_error(grid, pointwise_error)
            plot_error(grid, exact, approx)

            return max_error
    
    
    class NoDetection(TroubledCellDetector):
        """Detector that reports no troubled cells, whatever the projection.

        Methods
        -------
        get_cells(projection)
            Returns no troubled cells.

        """
        def get_cells(self, projection):
            """Returns no troubled cells.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree (ignored).

            Returns
            -------
            list
                List of indices for all detected troubled cells (always a new,
                empty list).

            """
            no_troubled_cells = list()
            return no_troubled_cells
    
    
    class ArtificialNeuralNetwork(TroubledCellDetector):
        """Class for troubled-cell detection using ANNs.

        Each cell is classified by a trained neural network. The network input
        for a cell is computed from a stencil of neighbouring cells, optionally
        extended by reconstructions. Cells whose predicted class is 1 are
        reported as troubled.

        Attributes
        ----------
        stencil_len : int
            Number of cells in the stencil used to build the model input.
        add_reconstructions : bool
            Flag whether reconstructions are added to the input data.
        model : torch.nn.Module
            ANN model instance for evaluation.

        Methods
        -------
        get_cells(projection)
            Calculates troubled cells in a given projection.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            Raises
            ------
            ValueError
                If the configured model is not defined in ANN_Model.

            """
            super()._reset(config)

            self._stencil_len = config.pop('stencil_len', 3)
            self._add_reconstructions = config.pop('add_reconstructions', True)
            model_name = config.pop('model', 'ThreeLayerReLu')

            # Reconstructions contribute two additional input values
            num_datapoints = self._stencil_len
            if self._add_reconstructions:
                num_datapoints += 2
            self._model_config = config.pop('model_config', {
                'input_size': num_datapoints, 'first_hidden_size': 8,
                'second_hidden_size': 4, 'output_size': 2,
                'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
            model_state = config.pop('model_state', 'Snakemake-Test/trained '
                                                    'models/model__Adam.pt')

            if not hasattr(ANN_Model, model_name):
                raise ValueError('Invalid model: "%s"' % model_name)
            self._model = getattr(ANN_Model, model_name)(self._model_config)

            # Load the model state and set it to evaluation mode.
            # NOTE(review): torch.load unpickles the file, so only load model
            # states from trusted sources.
            self._model.load_state_dict(torch.load(str(model_state)))
            self._model.eval()

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Replace the given ghost cells by periodic padding that is as wide
            # as the stencil reaches beyond its centre cell
            num_ghost_cells = self._stencil_len//2
            interior = projection[:, 1:-1]
            num_cells = interior.shape[1]
            # Use an explicit start index: slicing with `-num_ghost_cells:`
            # would select the whole array when num_ghost_cells is 0
            padded = np.concatenate(
                (interior[:, num_cells-num_ghost_cells:],
                 interior,
                 interior[:, :num_ghost_cells]), axis=1)

            # One row of input data per interior cell, built from its stencil
            input_data = torch.from_numpy(np.vstack([calculate_cell_average(
                projection=padded[
                           :, cell-num_ghost_cells:cell+num_ghost_cells+1],
                stencil_length=self._stencil_len, basis=self._basis,
                polynomial_degree=self._polynomial_degree if
                self._add_reconstructions else -1)
                    for cell in range(num_ghost_cells,
                                      num_cells+num_ghost_cells)]))

            # Inference only: skip building the autograd graph
            with torch.no_grad():
                model_output = torch.argmax(self._model(input_data.float()),
                                            dim=1)

            # Class 1 marks a troubled cell
            return [cell for cell, label in enumerate(model_output.tolist())
                    if label == 1]
    
    
    class WaveletDetector(TroubledCellDetector):
        """Class for troubled-cell detection based on wavelet coefficients.

        Each pair of neighbouring fine cells (2i, 2i+1) is combined into the
        coarse cell i. The multiwavelet projection matrices of the basis are
        applied to each such pair to obtain the multiwavelet coefficients of
        the coarse cell. Subclasses evaluate these coefficients in
        ``_get_cells``; this base implementation flags no cells. Results are
        plotted on both the fine and the coarse mesh.

        Methods
        -------
        get_cells(projection)
            Calculates troubled cells in a given projection.
        plot_results(projection, troubled_cell_history, time_history)
            Plots results and troubled cells of a projection on the fine and
            the coarse mesh.

        """
        def _check_colors(self):
            """Checks plot colors.

            Checks whether colors for the fine and coarse mesh plots were given
            and sets them if required.

            """
            self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.')
            self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.')
            self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-')
            self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y')

        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Every coarse cell is formed by two neighbouring fine cells
            self._num_coarse_grid_cells = self._num_grid_cells//2
            self._wavelet_projection_left, self._wavelet_projection_right \
                = self._basis.get_multiwavelet_projections()

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Ghost cells are excluded from the wavelet coefficients
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(
                projection[:, 1: -1])
            return self._get_cells(multiwavelet_coeffs, projection)

        def _combine_cell_pairs(self, projection, left_matrix, right_matrix):
            """Combines each pair of neighbouring fine cells into a coarse cell.

            For coarse cell i, the fine cells 2i and 2i+1 are multiplied by
            ``left_matrix`` and ``right_matrix`` respectively and averaged.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, without ghost
                cells.
            left_matrix, right_matrix : array_like
                Projection matrices (provided by the basis) for the left and
                right fine cell of each pair.

            Returns
            -------
            ndarray
                Matrix with one column per coarse cell.

            """
            coarse_columns = [
                0.5*(projection[:, 2*i] @ left_matrix
                     + projection[:, 2*i+1] @ right_matrix)
                for i in range(self._num_coarse_grid_cells)]
            return np.transpose(np.array(coarse_columns))

        def _calculate_wavelet_coeffs(self, projection):
            """Calculates wavelet coefficients used for projection to coarser grid.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, without ghost
                cells.

            Returns
            -------
            ndarray
                Matrix of wavelet coefficients.

            """
            return self._combine_cell_pairs(projection,
                                            self._wavelet_projection_left,
                                            self._wavelet_projection_right)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            The base implementation detects no troubled cells; subclasses
            override it.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            return []

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plots results and troubled cells of a projection.

            Plots results on coarse and fine grid, errors, troubled cells,
            and coefficient details given the projections evaluation history.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.
            troubled_cell_history : list
                List of detected troubled cells for each time step.
            time_history : list
                List of value of each time step.

            """
            # Strip the ghost cells before pairing fine cells, as get_cells()
            # and _calculate_coarse_projection() do; otherwise every pair is
            # shifted by one cell
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(
                projection[:, 1:-1])
            coarse_projection = self._calculate_coarse_projection(projection)
            plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection,
                         self._basis.get_basis_vector(),
                         self._basis.get_wavelet_vector(), multiwavelet_coeffs,
                         self._num_coarse_grid_cells,
                         self._polynomial_degree)
            super().plot_results(projection, troubled_cell_history, time_history)

        def _calculate_coarse_projection(self, projection):
            """Calculates coarse projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.

            Returns
            -------
            ndarray
                Matrix of projection on coarse grid for each polynomial degree.

            """
            basis_projection_left, basis_projection_right \
                = self._basis.get_basis_projections()

            # Remove ghost cells, then project pairs of fine cells onto the
            # coarse mesh
            return self._combine_cell_pairs(projection[:, 1:-1],
                                            basis_projection_left,
                                            basis_projection_right)

        def _plot_mesh(self, projection):
            """Plots exact and approximate solution as well as errors.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            max_error : float
                Maximum error between exact and approximate solution.

            """
            grid, exact = calculate_exact_solution(
                self._mesh[2:-2], self._cell_len, self._wave_speed,
                self._final_time, self._interval_len, self._quadrature,
                self._init_cond)
            approx = calculate_approximate_solution(
                projection[:, 1:-1], self._quadrature.get_eval_points(),
                self._polynomial_degree, self._basis.get_basis_vector())

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            # Coarse curves are drawn first so the legend order matches
            self._plot_coarse_mesh(projection)
            plot_solution_and_approx(grid, exact, approx,
                                     self._colors['fine_exact'],
                                     self._colors['fine_approx'])
            plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)',
                        'Approx (Fine)'])
            plot_semilog_error(grid, pointwise_error)
            plot_error(grid, exact, approx)

            return max_error

        def _plot_coarse_mesh(self, projection):
            """Plots exact and approximate solution for a coarse projection.

            Parameters
            ----------
            projection : ndarray
                Matrix of projection for each polynomial degree.

            """
            coarse_cell_len = 2*self._cell_len
            # Centres of the coarse cells, built from integer offsets: a
            # float-step np.arange can yield one point too many or too few
            # through rounding
            coarse_cell_centers = self._left_bound + coarse_cell_len * (
                np.arange(self._num_coarse_grid_cells) + 0.5)

            coarse_projection = self._calculate_coarse_projection(projection)

            # Plot exact and approximate solutions for coarse mesh
            grid, exact = calculate_exact_solution(
                coarse_cell_centers, coarse_cell_len, self._wave_speed,
                self._final_time, self._interval_len, self._quadrature,
                self._init_cond)
            approx = calculate_approximate_solution(
                coarse_projection, self._quadrature.get_eval_points(),
                self._polynomial_degree, self._basis.get_basis_vector())
            plot_solution_and_approx(
                grid, exact, approx, self._colors['coarse_exact'],
                self._colors['coarse_approx'])
    
    
    class Boxplot(WaveletDetector):
        """Class for troubled-cell detection based on Boxplots.

        The coarse cells are split into consecutive folds of ``fold_len``
        cells. Within each fold, the first multiwavelet coefficients are
        sorted and their quartiles Q1 and Q3 are estimated. Cells whose
        coefficient lies below ``Q1 - whisker_len*(Q3-Q1)`` or above
        ``Q3 + whisker_len*(Q3-Q1)`` are flagged as troubled.

        Attributes
        ----------
        fold_len : int
            Length of folds considered in one Boxplot.
        whisker_len : int
            Length of Boxplot whiskers.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Unpack necessary configurations
            self._fold_len = config.pop('fold_len', 16)
            self._whisker_len = config.pop('whisker_len', 3)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Only complete folds are examined: coarse cells after the last full
            fold are not checked.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                Sorted list of indices for all detected troubled cells.

            """
            num_coarse_cells = self._num_coarse_grid_cells
            if num_coarse_cells == 0:
                # Nothing to examine (and a zero fold length would divide by 0)
                return []

            # A fold cannot be longer than the coarse mesh. A local is used so
            # that detection does not alter the configured fold length.
            fold_len = min(self._fold_len, num_coarse_cells)
            num_folds = num_coarse_cells//fold_len

            # Quartiles are linearly interpolated between neighbouring sorted
            # entries; the positions are identical for every fold.
            # NOTE(review): for fold_len < 4, boundary_index is 0 and the index
            # -1 below wraps around to the largest entry -- confirm how such
            # short folds should be handled.
            boundary_index = fold_len//4
            balance_factor = fold_len/4.0 - boundary_index

            # Keep each coefficient next to its cell index so it survives sorting
            indexed_coeffs = [[multiwavelet_coeffs[0, cell], cell]
                              for cell in range(num_coarse_cells)]

            troubled_cells = []
            for fold in range(num_folds):
                sorted_fold = sorted(
                    indexed_coeffs[fold*fold_len:(fold+1)*fold_len])

                first_quartile = (1-balance_factor) \
                    * sorted_fold[boundary_index-1][0] \
                    + balance_factor * sorted_fold[boundary_index][0]
                third_quartile = (1-balance_factor) \
                    * sorted_fold[3*boundary_index-1][0] \
                    + balance_factor * sorted_fold[3*boundary_index][0]

                whisker = self._whisker_len * (third_quartile-first_quartile)
                lower_bound = first_quartile - whisker
                upper_bound = third_quartile + whisker

                # Check for lower extreme outliers: walk upwards from the
                # smallest coefficient until the first inlier
                for coeff, cell in sorted_fold:
                    if coeff < lower_bound:
                        troubled_cells.append(cell)
                    else:
                        break

                # Check for upper extreme outliers: walk downwards from the
                # largest coefficient until the first inlier
                for coeff, cell in reversed(sorted_fold):
                    if coeff > upper_bound:
                        troubled_cells.append(cell)
                    else:
                        break

            return sorted(troubled_cells)
    
    
    class Theoretical(WaveletDetector):
        """Class for troubled-cell detection based on the projection averages and
        a cutoff factor.

        A coarse cell is flagged as troubled if its largest absolute multiwavelet
        coefficient (over all degrees), normalized by the maximum cell average
        of the projection, exceeds a threshold derived from the cutoff factor.

        Attributes
        ----------
        cutoff_factor : float
            Cutoff factor above which a cell is considered troubled.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Unpack necessary configurations.
            # NOTE(review): the author noted 2 or 3 as alternatives to the
            # sqrt(2) factor in this default -- confirm which is intended.
            self._cutoff_factor = config.pop('cutoff_factor',
                                             np.sqrt(2) * self._cell_len)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            projection : ndarray
                Matrix of projection for each polynomial degree, including one
                ghost cell on each side.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Largest absolute degree-0 coefficient over all fine cells; the
            # offset of 1 skips the left ghost cell. This covers the whole
            # domain, not only the first num_coarse_grid_cells fine cells.
            largest_avg_coeff = max(
                (abs(projection[0][cell+1])
                 for cell in range(self._num_grid_cells)),
                default=0)
            max_avg = np.sqrt(0.5) * max(1, largest_avg_coeff)

            return [cell for cell in range(self._num_coarse_grid_cells)
                    if self._is_troubled_cell(multiwavelet_coeffs, cell,
                                              max_avg)]

        def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
            """Checks whether a cell is troubled.

            Parameters
            ----------
            multiwavelet_coeffs : ndarray
                Matrix of multiwavelet coefficients.
            cell : int
                Index of coarse cell.
            max_avg : float
                Maximum average of projection.

            Returns
            -------
            bool
                Flag whether cell is troubled.

            """
            largest_coeff = max(abs(multiwavelet_coeffs[degree][cell])
                                for degree in range(self._polynomial_degree+1))
            max_value = largest_coeff/max_avg
            eps = self._cutoff_factor\
                / (self._cell_len*self._num_coarse_grid_cells*2)

            return max_value > eps