Skip to content
Snippets Groups Projects
Select Git revision
  • 6fd7daeef95e3dff552d7a0bb8b6a8fdab5a6a69
  • master default protected
2 results

Troubled_Cell_Detector.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Troubled_Cell_Detector.py 21.99 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Adjust TCs for wavelet detectors (sliding window over all cells instead of every second)
    TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
    TODO: Give detailed description of wavelet detection
    TODO: Load ANN state and config in reset
    
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import torch
    from sympy import Symbol
    
    import ANN_Model
    from Plotting import plot_solution_and_approx, plot_semilog_error, plot_error, plot_shock_tube, \
        plot_details, calculate_approximate_solution, calculate_exact_solution
    
    # Module-level sympy symbols; they are not referenced elsewhere in the visible code
    # of this module (presumably kept for importers — TODO confirm).
    x, z = Symbol('x'), Symbol('z')
    
    
    class TroubledCellDetector(object):
        """Class for troubled-cell detection.

        Detects troubled cells, i.e., cells in the mesh containing instabilities.
        This base class stores the discretization parameters and provides helpers
        for plotting and for building detector input data; the actual detection is
        implemented by subclasses overriding ``get_cells``.

        Attributes
        ----------
        interval_len : float
            Length of the interval between left and right boundary.
        cell_len : float
            Length of a cell in mesh.

        Methods
        -------
        get_name()
            Returns string of class name.
        get_cells(projection)
            Calculates troubled cells in a given projection.
        calculate_cell_average_and_reconstructions(projection, stencil_length)
            Calculates cell averages and reconstructions for a given projection.
        plot_results(projection, troubled_cell_history, time_history)
            Plots results and troubled cells of a projection given its evaluation history.

        """
        def __init__(self, config, mesh, wave_speed, polynomial_degree, num_grid_cells, final_time,
                     left_bound, right_bound, basis, init_cond, quadrature):
            """Initializes TroubledCellDetector.

            Parameters
            ----------
            config : dict
                Additional parameters for detector, e.g. 'colors' (dict of plot colors).
                Subclasses consume further keys in ``_reset``. The caller's dict (and
                its 'colors' dict) is not modified.
            mesh : array
                List of mesh valuation points.
            wave_speed : float
                Speed of wave in rightward direction.
            polynomial_degree : int
                Polynomial degree.
            num_grid_cells : int
                Number of cells in the mesh. Usually exponential of 2.
            final_time : float
                Final time for which approximation is calculated.
            left_bound : float
                Left boundary of interval.
            right_bound : float
                Right boundary of interval.
            basis : Basis object
                Basis for calculation.
            init_cond : InitialCondition object
                Initial condition for evaluation.
            quadrature : Quadrature object
                Quadrature for evaluation.

            """
            self._mesh = mesh
            self._wave_speed = wave_speed
            self._polynomial_degree = polynomial_degree
            self._num_grid_cells = num_grid_cells
            self._final_time = final_time
            self._left_bound = left_bound
            self._right_bound = right_bound
            self._interval_len = right_bound - left_bound
            self._cell_len = self._interval_len / num_grid_cells
            self._basis = basis
            self._init_cond = init_cond
            self._quadrature = quadrature

            # Consume a shallow copy: popping keys from (or adding default colors to)
            # the caller's dicts would make a config reused for another detector
            # silently fall back to defaults.
            config = dict(config)

            # Set parameters from config if existing
            self._colors = dict(config.pop('colors', {}))

            self._check_colors()
            self._reset(config)

        def _check_colors(self):
            """Checks plot colors.

            Checks whether colors for plots were given and sets them if required.

            """
            self._colors['exact'] = self._colors.get('exact', 'k-')
            self._colors['approx'] = self._colors.get('approx', 'y')

        def _reset(self, config):
            """Resets instance variables.

            The base implementation only applies the seaborn default theme; subclasses
            extend it to read their own parameters from ``config``.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            sns.set()

        def get_name(self):
            """Returns string of class name."""
            return self.__class__.__name__

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            The base implementation does nothing and returns None; subclasses
            override it.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            """
            pass

        def calculate_cell_average_and_reconstructions(self, projection, stencil_length):
            """Calculates cell averages and reconstructions for a given projection.

            Calculate the cell averages of all cells in a projection. Reconstructions are only
            calculated for the middle cell and added left and right to it, respectively.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.
            stencil_length : int
                Size of data array.

            Returns
            -------
            np.array
                Matrix containing cell averages and reconstructions for initial projection.

            """
            # Evaluate at the reference point 0 with degree 0 (cell averages) and at the
            # reference points -1/1 with the full degree (left/right reconstructions).
            cell_averages = calculate_approximate_solution(
                projection, [0], 0, self._basis.get_basis_vector())
            left_reconstructions = calculate_approximate_solution(
                projection, [-1], self._polynomial_degree, self._basis.get_basis_vector())
            right_reconstructions = calculate_approximate_solution(
                projection, [1], self._polynomial_degree, self._basis.get_basis_vector())
            middle_idx = stencil_length//2
            # Ordering: averages left of the middle cell, middle cell's left reconstruction,
            # its average, its right reconstruction, then averages right of it.
            # NOTE(review): zip iterates over the first axis of the arrays returned by
            # calculate_approximate_solution (Plotting); the exact shapes are not visible here.
            return np.array(list(map(np.float64, zip(cell_averages[:, :middle_idx],
                            left_reconstructions[:, middle_idx], cell_averages[:, middle_idx],
                            right_reconstructions[:, middle_idx], cell_averages[:, middle_idx+1:]))))

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plots results and troubled cells of a projection given its evaluation history.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.
            troubled_cell_history : list
                List of detected troubled cells for each time step.
            time_history:
                List of value of each time step.

            """
            plot_shock_tube(self._num_grid_cells, troubled_cell_history, time_history)
            max_error = self._plot_mesh(projection)

            print('p =', self._polynomial_degree)
            print('N =', self._num_grid_cells)
            print('maximum error =', max_error)

        def _plot_mesh(self, projection):
            """Plots exact and approximate solution as well as errors.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            max_error
                Maximum error between exact and approximate solution.

            """
            # Ghost cells are excluded: two mesh points and one projection column per side.
            grid, exact = calculate_exact_solution(
                self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
                self._interval_len, self._quadrature, self._init_cond)
            approx = calculate_approximate_solution(
                projection[:, 1:-1], self._quadrature.get_eval_points(), self._polynomial_degree,
                self._basis.get_basis_vector())

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
            plt.legend(['Exact', 'Approx'])
            plot_semilog_error(grid, pointwise_error)
            plot_error(grid, exact, approx)

            return max_error
    
    
    class NoDetection(TroubledCellDetector):
        """Detector that never reports any troubled cell.

        Methods
        -------
        get_cells(projection)
            Returns an empty list regardless of the projection.

        """
        def get_cells(self, projection):
            """Reports no troubled cells.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree. It is ignored.

            Returns
            -------
            list
                Always a new, empty list.

            """
            return list()
    
    
    class ArtificialNeuralNetwork(TroubledCellDetector):
        """Class for troubled-cell detection using ANNs.

        For every cell, a stencil of neighbouring cells (periodically extended at the
        boundaries) is turned into cell averages and reconstructions, which are fed to
        an ANN model; cells whose rounded first output equals 1 are reported.

        Attributes
        ----------
        stencil_length : int
            Size of input data array.
        model : ANNModel object
            ANN model instance for evaluation.

        Methods
        -------
        get_cells(projection)
            Calculates troubled cells in a given projection.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector. Keys read here: 'stencil_len',
                'model', 'model_config' and 'model_state'.

            Raises
            ------
            ValueError
                If 'model' does not name an attribute of ANN_Model.

            """
            super()._reset(config)

            self._stencil_len = config.pop('stencil_len', 3)
            self._model = config.pop('model', 'ThreeLayerReLu')
            self._model_config = config.pop('model_config', {
                'input_size': self._stencil_len+2, 'first_hidden_size': 8, 'second_hidden_size': 4,
                'output_size': 2, 'activation_function': 'Softmax', 'activation_config': {'dim': 1}})
            self._model_state = config.pop(
                'model_state', 'Snakemake-Test/trained models/model__Adam.pt')

            if not hasattr(ANN_Model, self._model):
                raise ValueError('Invalid model: "%s"' % self._model)
            self._model = getattr(ANN_Model, self._model)(self._model_config)

            # Trained weights are read lazily on the first call of get_cells and then
            # reused, instead of re-reading the file on every call.
            self._model_state_loaded = False

        def _load_model_state(self):
            """Loads the trained model state (once) and puts the model in evaluation mode."""
            if not self._model_state_loaded:
                # map_location keeps loading working on CPU-only machines even if the
                # state was saved from GPU tensors; the model parameters live on the CPU.
                self._model.load_state_dict(
                    torch.load(self._model_state, map_location=torch.device('cpu')))
                self._model_state_loaded = True
            self._model.eval()

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Reset ghost cells to adjust for stencil length (periodic extension)
            num_ghost_cells = self._stencil_len//2
            projection = projection[:, 1:-1]
            num_cells = projection.shape[1]
            # Use an explicit start index: projection[:, -0:] would return the whole
            # array instead of an empty slice when num_ghost_cells is 0.
            projection = np.concatenate((projection[:, num_cells-num_ghost_cells:], projection,
                                         projection[:, :num_ghost_cells]),
                                        axis=1)

            # Calculate input data depending on stencil length
            input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
                projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
                for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))

            # Evaluate troubled cell probabilities (inference only, no autograd graph)
            self._load_model_state()
            with torch.no_grad():
                model_output = torch.round(self._model(input_data.float()))

            # Return troubled cells
            return [cell for cell, flag in enumerate(model_output[:, 0].tolist()) if flag == 1]
    
    
    class WaveletDetector(TroubledCellDetector):
        """Class for troubled-cell detection based on wavelet coefficients.

        The projection (without ghost cells) is split into pairs of neighbouring cells
        (2i, 2i+1); each pair corresponds to one cell of a coarse mesh with half as many
        cells. For every pair, multiwavelet coefficients are computed with the basis'
        multiwavelet projection matrices, and subclasses decide in ``_get_cells`` which
        coarse cells are troubled. Plotting additionally shows the solution on the
        coarse mesh and the coefficient details.

        Attributes
        ----------
        num_coarse_grid_cells : int
            Number of cells of the coarse mesh (half the number of fine cells).

        """
        def _check_colors(self):
            """Checks plot colors.

            Checks whether colors for plots were given and sets them if required.

            """
            self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.')
            self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.')
            self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-')
            self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y')

        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector.

            """
            super()._reset(config)

            # Set additional necessary parameter
            self._num_coarse_grid_cells = self._num_grid_cells//2
            self._wavelet_projection_left, self._wavelet_projection_right \
                = self._basis.get_multiwavelet_projections()

        def get_cells(self, projection):
            """Calculates troubled cells in a given projection.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1: -1])
            return self._get_cells(multiwavelet_coeffs, projection)

        def _project_cell_pairs(self, projection, left_matrix, right_matrix):
            """Combines each pair of neighbouring cells (2i, 2i+1) into one coarse column.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree, without ghost cells.
            left_matrix : np.array
                Matrix applied to the left cell of each pair.
            right_matrix : np.array
                Matrix applied to the right cell of each pair.

            Returns
            -------
            np.array
                Matrix with one column per coarse cell.

            """
            output_matrix = [0.5*(projection[:, 2*i] @ left_matrix
                                  + projection[:, 2*i+1] @ right_matrix)
                             for i in range(self._num_coarse_grid_cells)]
            return np.transpose(np.array(output_matrix))

        def _calculate_wavelet_coeffs(self, projection):
            """Calculates wavelet coefficients used for projection to coarser grid.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree, without ghost cells.

            Returns
            -------
            np.array
                Matrix of wavelet coefficients.

            """
            return self._project_cell_pairs(
                projection, self._wavelet_projection_left, self._wavelet_projection_right)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            The base implementation reports no troubled cells; subclasses override it.

            Parameters
            ----------
            multiwavelet_coeffs : np.array
                Matrix of multiwavelet coefficients.
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            return []

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plots results and troubled cells of a projection given its evaluation history.

            Plots results on coarse and fine grid, errors, troubled cells, and coefficient details.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.
            troubled_cell_history : list
                List of detected troubled cells for each time step.
            time_history:
                List of value of each time step.

            """
            # Strip the ghost cells before pairing cells, exactly as get_cells and
            # _calculate_coarse_projection do; otherwise the pairs are shifted by one
            # cell and the plotted coefficients do not match the coarse projection.
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1:-1])
            coarse_projection = self._calculate_coarse_projection(projection)
            plot_details(projection[:, 1:-1], self._mesh[2:-2], coarse_projection,
                         self._basis.get_basis_vector(), self._basis.get_wavelet_vector(),
                         multiwavelet_coeffs, self._num_coarse_grid_cells,
                         self._polynomial_degree)
            super().plot_results(projection, troubled_cell_history, time_history)

        def _calculate_coarse_projection(self, projection):
            """Calculates coarse projection.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree, including one ghost
                cell on each side.

            Returns
            -------
            np.array
                Matrix of projection on coarse grid for each polynomial degree.

            """
            basis_projection_left, basis_projection_right = self._basis.get_basis_projections()

            # Remove ghost cells, then calculate projection on coarse mesh
            return self._project_cell_pairs(
                projection[:, 1:-1], basis_projection_left, basis_projection_right)

        def _plot_mesh(self, projection):
            """Plots exact and approximate solution as well as errors.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            max_error
                Maximum error between exact and approximate solution.

            """

            grid, exact = calculate_exact_solution(
                self._mesh[2:-2], self._cell_len, self._wave_speed, self._final_time,
                self._interval_len, self._quadrature, self._init_cond)
            approx = calculate_approximate_solution(
                projection[:, 1:-1], self._quadrature.get_eval_points(), self._polynomial_degree,
                self._basis.get_basis_vector())

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            # Coarse curves are drawn first so the legend order below matches.
            self._plot_coarse_mesh(projection)
            plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'],
                                     self._colors['fine_approx'])
            plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', 'Approx (Fine)'])
            plot_semilog_error(grid, pointwise_error)
            plot_error(grid, exact, approx)

            return max_error

        def _plot_coarse_mesh(self, projection):
            """Plots exact and approximate solution as well as errors for a coarse projection.

            Parameters
            ----------
            projection : np.array
                Matrix of projection for each polynomial degree.

            """
            coarse_cell_len = 2*self._cell_len
            # Points left_bound + (k - 0.5)*coarse_cell_len for k = 0, ..., N_coarse + 1.
            # linspace fixes the number of points; np.arange with a float step may
            # produce one point too many due to rounding. Assumes an even number of
            # fine cells, so that the coarse cells exactly cover the interval.
            coarse_mesh = np.linspace(self._left_bound - (0.5*coarse_cell_len),
                                      self._right_bound + (0.5*coarse_cell_len),
                                      self._num_coarse_grid_cells + 2)

            coarse_projection = self._calculate_coarse_projection(projection)

            # Plot exact and approximate solutions for coarse mesh
            grid, exact = calculate_exact_solution(
                coarse_mesh[1:-1], coarse_cell_len, self._wave_speed, self._final_time,
                self._interval_len, self._quadrature, self._init_cond)
            approx = calculate_approximate_solution(
                coarse_projection, self._quadrature.get_eval_points(), self._polynomial_degree,
                self._basis.get_basis_vector())
            plot_solution_and_approx(
                grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])
    
    
    class Boxplot(WaveletDetector):
        """Class for troubled-cell detection based on Boxplots.

        The degree-0 multiwavelet coefficients of the coarse cells are split into
        consecutive folds of ``fold_len`` cells. Within each fold, cells whose
        coefficient lies below Q1 - whisker_len*IQR or above Q3 + whisker_len*IQR
        are marked as troubled.

        Attributes
        ----------
        fold_len : int
            Length of folds considered in one Boxplot.
        whisker_len : int
            Length of Boxplot whiskers.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector. Keys read here: 'fold_len'
                (default 16) and 'whisker_len' (default 3).

            """
            super()._reset(config)

            # Unpack necessary configurations
            self._fold_len = config.pop('fold_len', 16)
            self._whisker_len = config.pop('whisker_len', 3)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : np.array
                Matrix of multiwavelet coefficients.
            projection : np.array
                Matrix of projection for each polynomial degree.

            Returns
            -------
            list
                Sorted list of indices for all detected troubled cells.

            """
            # No coarse cells means nothing can be troubled (and avoids dividing by 0).
            if self._num_coarse_grid_cells == 0:
                return []

            indexed_coeffs = [(multiwavelet_coeffs[0, i], i)
                              for i in range(self._num_coarse_grid_cells)]

            # Use a local fold length so that detection does not permanently overwrite
            # the configured self._fold_len.
            fold_len = min(self._fold_len, self._num_coarse_grid_cells)
            num_folds = self._num_coarse_grid_cells//fold_len
            # NOTE(review): if num_coarse_grid_cells is not a multiple of fold_len, the
            # trailing cells outside the last full fold are never checked.

            # Quartiles are interpolated linearly between neighbouring sorted values.
            boundary_index = fold_len//4
            balance_factor = fold_len/4.0 - boundary_index
            # NOTE(review): for fold_len < 4, boundary_index is 0 and index -1 below wraps
            # around to the largest value of the fold, so the quartiles are not meaningful.

            troubled_cells = []
            for fold in range(num_folds):
                sorted_fold = sorted(indexed_coeffs[fold*fold_len:(fold+1)*fold_len])

                first_quartile = (1-balance_factor) * sorted_fold[boundary_index-1][0] \
                    + balance_factor * sorted_fold[boundary_index][0]
                third_quartile = (1-balance_factor) * sorted_fold[3*boundary_index-1][0] \
                    + balance_factor * sorted_fold[3*boundary_index][0]

                inter_quartile_range = third_quartile - first_quartile
                lower_bound = first_quartile - self._whisker_len * inter_quartile_range
                upper_bound = third_quartile + self._whisker_len * inter_quartile_range

                # Check for lower extreme outliers and add respective cells
                for coeff, cell in sorted_fold:
                    if coeff < lower_bound:
                        troubled_cells.append(cell)
                    else:
                        break

                # Check for upper extreme outliers and add respective cells
                for coeff, cell in reversed(sorted_fold):
                    if coeff > upper_bound:
                        troubled_cells.append(cell)
                    else:
                        break

            return sorted(troubled_cells)
    
    
    class Theoretical(WaveletDetector):
        """Class for troubled-cell detection based on the projection averages and a cutoff factor.

        A coarse cell is troubled if its largest multiwavelet coefficient (over all
        polynomial degrees), divided by the maximum cell average, exceeds a threshold
        derived from the cutoff factor.

        Attributes
        ----------
        cutoff_factor : float
            Cutoff factor above which a cell is considered troubled.

        """
        def _reset(self, config):
            """Resets instance variables.

            Parameters
            ----------
            config : dict
                Additional parameters for detector. Key read here: 'cutoff_factor'
                (default sqrt(2) * cell length).

            """
            super()._reset(config)

            # Unpack necessary configurations
            # NOTE: 2 or 3 were noted as alternative cutoff factors.
            self._cutoff_factor = config.pop('cutoff_factor', np.sqrt(2) * self._cell_len)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Calculates troubled cells using multiwavelet coefficients.

            Parameters
            ----------
            multiwavelet_coeffs : np.array
                Matrix of multiwavelet coefficients.
            projection : np.array
                Matrix of projection for each polynomial degree, including one ghost
                cell on each side.

            Returns
            -------
            list
                List of indices for all detected troubled cells.

            """
            # Largest absolute degree-0 coefficient over all fine cells (ghost cells at
            # both ends excluded), bounded below by 1. The factor sqrt(0.5) presumably
            # converts the coefficient into a cell average for the basis normalization
            # used here — TODO confirm.
            max_avg = np.sqrt(0.5) * max(1, max((abs(value) for value in projection[0][1:-1]),
                                                default=0))

            return [cell for cell in range(self._num_coarse_grid_cells)
                    if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg)]

        def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
            """Checks whether a cell is troubled.

            Parameters
            ----------
            multiwavelet_coeffs : np.array
                Matrix of multiwavelet coefficients.
            cell : int
                Index of cell.
            max_avg
                Maximum average of projection.

            Returns
            -------
            boolean
                Flag whether cell is troubled.

            """
            max_value = max(abs(multiwavelet_coeffs[degree][cell])
                            for degree in range(self._polynomial_degree+1))/max_avg
            # cell_len * num_coarse_grid_cells * 2 equals the interval length when the
            # number of fine cells is even.
            eps = self._cutoff_factor / (self._cell_len*self._num_coarse_grid_cells*2)

            return max_value > eps