Skip to content
Snippets Groups Projects
Select Git revision
  • a323f373d6f2e88b47d026f7c1d99ef0bc09d7b2
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

README.md

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Troubled_Cell_Detector.py 19.40 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Move plotting to separate file (try to adjust for different equations)
    
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import torch
    from sympy import Symbol
    
    import ANN_Model
    
    # Sympy symbols: `x` is the variable substituted in the basis polynomials, `z` the one substituted in the wavelets
    x, z = Symbol('x'), Symbol('z')
    
    
    class TroubledCellDetector(object):
        """Base class for troubled cell detectors.

        Stores the discretization parameters shared by all detectors and provides helpers to evaluate the exact and
        approximate solutions as well as to plot and save the results. Subclasses implement `get_cells` and may
        override `_reset` (to read additional configuration) and `_check_colors` (to set additional plot colors).
        """

        def __init__(self, config, mesh, wave_speed, polynomial_degree, num_grid_cells, final_time, left_bound, right_bound,
                     basis, init_cond, quadrature):
            """Initialize the detector.

            Parameters
            ----------
            config : dict
                Detector configuration. Recognized keys ('plot_dir', 'colors' and any keys read by `_reset`) are
                popped, i.e. removed from the passed dict.
            mesh : array-like
                Mesh positions; plotting uses `mesh[2:-2]` and evaluates the solution at
                `mesh[cell] + cell_len/2 * eval_points`.
            wave_speed : float
                Speed by which the initial condition is shifted when calculating the exact solution.
            polynomial_degree : int
                Polynomial degree of the approximation.
            num_grid_cells : int
                Number of grid cells.
            final_time : float
                Time at which the exact solution is evaluated.
            left_bound, right_bound : float
                Boundaries of the domain.
            basis
                Basis object providing `get_basis_vector` (wavelet detectors use further projections of it).
            init_cond
                Initial condition providing `calculate(point)`.
            quadrature
                Quadrature providing `get_eval_points()`.
            """
            self._mesh = mesh
            self._wave_speed = wave_speed
            self._polynomial_degree = polynomial_degree
            self._num_grid_cells = num_grid_cells
            self._final_time = final_time
            self._left_bound = left_bound
            self._right_bound = right_bound
            self._interval_len = right_bound - left_bound
            self._cell_len = self._interval_len / num_grid_cells
            self._basis = basis
            self._init_cond = init_cond
            self._quadrature = quadrature

            # Set parameters from config if existing
            self._plot_dir = config.pop('plot_dir', 'fig')
            self._colors = config.pop('colors', {})

            self._check_colors()
            self._reset(config)

        def _check_colors(self):
            """Fill in default plot colors for the exact and approximate solution if not configured."""
            self._colors['exact'] = self._colors.get('exact', 'k-')
            self._colors['approx'] = self._colors.get('approx', 'y')

        def _reset(self, config):
            """Set the seaborn plot style; subclasses extend this to pop their own keys from `config`."""
            sns.set()

        def get_name(self):
            """Return the name of the detector's class."""
            return self.__class__.__name__

        def get_cells(self, projection):
            """Return the troubled cells of `projection`; overridden by subclasses (the base version returns None)."""
            pass

        def calculate_cell_average_and_reconstructions(self, projection, stencil_length):
            """Calculate the input features of a stencil.

            Cell averages (the approximation of degree 0 evaluated at x=0) are calculated for all cells of the
            stencil. The left and right reconstructions (the full approximation evaluated at x=-1 and x=1) are only
            used for the middle cell and are inserted left and right of its average, respectively.

            Parameters
            ----------
            projection : ndarray
                Projection of the stencil cells (rows: basis degrees, columns: cells).
            stencil_length : int
                Number of cells in the stencil; the middle cell has index `stencil_length//2`.

            Returns
            -------
            ndarray
                Float64 array of shape (1, stencil_length + 2).
            """
            cell_averages = self._calculate_approximate_solution(projection, [0], 0)
            left_reconstructions = self._calculate_approximate_solution(projection, [-1], self._polynomial_degree)
            right_reconstructions = self._calculate_approximate_solution(projection, [1], self._polynomial_degree)
            return self._assemble_stencil_features(cell_averages, left_reconstructions, right_reconstructions,
                                                   stencil_length)

        @staticmethod
        def _assemble_stencil_features(cell_averages, left_reconstructions, right_reconstructions, stencil_length):
            """Interleave the middle cell's reconstructions with the cell averages of a stencil.

            All inputs have shape (1, num_cells). The result is ordered as [averages left of the middle cell, left
            reconstruction, middle average, right reconstruction, averages right of the middle cell].
            """
            middle_idx = stencil_length//2
            middle = slice(middle_idx, middle_idx+1)
            # Concatenating 2D slices keeps a homogeneous (1, stencil_length + 2) layout for any stencil length
            features = np.concatenate((cell_averages[:, :middle_idx], left_reconstructions[:, middle],
                                       cell_averages[:, middle], right_reconstructions[:, middle],
                                       cell_averages[:, middle_idx+1:]), axis=1)
            # Entries may be sympy numbers (basis values come from `subs`), so convert explicitly
            return features.astype(np.float64)

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plot the troubled cells over time and the solution/error plots; print the maximum error."""
            self._plot_shock_tube(troubled_cell_history, time_history)
            max_error = self._plot_mesh(projection)

            print('p =', self._polynomial_degree)
            print('N =', self._num_grid_cells)
            print('maximum error =', max_error)

        def _plot_shock_tube(self, troubled_cell_history, time_history):
            """Plot each troubled cell against the time it was detected at (figure 6)."""
            plt.figure(6)
            for time, current_cells in zip(time_history, troubled_cell_history):
                for cell in current_cells:
                    plt.plot(cell, time, 'k.')
            plt.xlim((0, self._num_grid_cells//2))
            plt.xlabel('Cell')
            plt.ylabel('Time')
            plt.title('Shock Tubes')

        def _plot_mesh(self, projection):
            """Plot solution, semilog error and error (figures 1-3) and return the maximum pointwise error.

            The ghost cells of `projection` (first and last column) are excluded.
            """
            grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
            approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            self._plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
            plt.legend(['Exact', 'Approx'])
            self._plot_semilog_error(grid, pointwise_error)
            self._plot_error(grid, exact, approx)

            return max_error

        @staticmethod
        def _plot_solution_and_approx(grid, exact, approx, color_exact, color_approx):
            """Plot the exact and the approximate solution into figure 1."""
            plt.figure(1)
            plt.plot(grid[0], exact[0], color_exact)
            plt.plot(grid[0], approx[0], color_approx)
            plt.xlabel('x')
            plt.ylabel('u(x,t)')
            plt.title('Solution and Approximation')

        @staticmethod
        def _plot_semilog_error(grid, pointwise_error):
            """Plot the pointwise error on a logarithmic scale into figure 2."""
            plt.figure(2)
            plt.semilogy(grid[0], pointwise_error[0])
            plt.xlabel('x')
            plt.ylabel('|u(x,t)-uh(x,t)|')
            plt.title('Semilog Error plotted at Evaluation points')

        @staticmethod
        def _plot_error(grid, exact, approx):
            """Plot the signed error into figure 3."""
            plt.figure(3)
            plt.plot(grid[0], exact[0]-approx[0])
            plt.xlabel('X')
            plt.ylabel('u(x,t)-uh(x,t)')
            plt.title('Errors')

        def _calculate_exact_solution(self, mesh, cell_len):
            """Evaluate the exact solution at the quadrature evaluation points of every cell.

            The initial condition is shifted by `wave_speed * final_time` and moved back by whole domain lengths.

            Returns
            -------
            tuple of ndarray
                `(grid, exact)`, each of shape (1, len(mesh) * num_eval_points), ordered cell by cell.
            """
            reference_points = self._quadrature.get_eval_points()
            num_periods = np.floor(self._wave_speed * self._final_time / self._interval_len)
            shift = num_periods*self._interval_len - self._wave_speed*self._final_time

            grid = []
            exact = []
            for mesh_point in mesh:
                eval_points = mesh_point + cell_len/2 * reference_points
                grid.append(eval_points)
                exact.append([self._init_cond.calculate(point + shift) for point in eval_points])

            exact = np.reshape(np.array(exact), (1, -1))
            grid = np.reshape(np.array(grid), (1, -1))

            return grid, exact

        def _calculate_approximate_solution(self, projection, points, polynomial_degree):
            """Evaluate the approximation of every cell at the given reference points.

            For each cell, sums `projection[degree][cell] * basis[degree](point)` over degrees
            0..`polynomial_degree`.

            Returns
            -------
            ndarray
                Array of shape (1, num_cells * len(points)), ordered cell by cell.
            """
            num_points = len(points)
            basis = self._basis.get_basis_vector()

            basis_matrix = [[basis[degree].subs(x, points[point]) for point in range(num_points)]
                            for degree in range(polynomial_degree+1)]

            approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
                           for degree in range(polynomial_degree+1))
                       for point in range(num_points)]
                      for cell in range(len(projection[0]))]

            return np.reshape(np.array(approx), (1, len(approx) * num_points))

        def save_plots(self, name):
            """Save figures 1, 2, 3 and 6 as PDF files named `name` into sub-directories of the plot directory."""
            figures = [(1, 'exact_and_approx'), (2, 'semilog_error'), (3, 'error'), (6, 'shock_tube')]

            # Set paths for plot files if not existing already (exist_ok avoids a check-then-create race)
            for _, sub_dir in figures:
                os.makedirs(self._plot_dir + '/' + sub_dir, exist_ok=True)

            # Save plots
            for figure_number, sub_dir in figures:
                plt.figure(figure_number)
                plt.savefig(self._plot_dir + '/' + sub_dir + '/' + name + '.pdf')
    
    
    class NoDetection(TroubledCellDetector):
        """Detector that never flags any cell as troubled."""

        def get_cells(self, projection):
            """Ignore `projection` and return an empty list of troubled cells."""
            no_troubled_cells = list()
            return no_troubled_cells
    
    
    class ArtificialNeuralNetwork(TroubledCellDetector):
        """Troubled cell detector based on a trained artificial neural network (ANN).

        Each cell is classified from the cell averages of the stencil centered on it together with the left and right
        reconstructions of that stencil's middle cell.
        """

        def _reset(self, config):
            """Read the ANN-specific configuration and instantiate the model.

            Recognized (and popped) config keys: 'stencil_len', 'model' (name of a class in `ANN_Model`),
            'model_config' (passed to the model's constructor) and 'model_state' (file passed to `torch.load`).

            Raises
            ------
            ValueError
                If `ANN_Model` has no attribute with the configured model name.
            """
            super()._reset(config)

            self._stencil_len = config.pop('stencil_len', 3)
            self._model = config.pop('model', 'ThreeLayerReLu')
            self._model_config = config.pop('model_config', {'input_size': self._stencil_len+2, 'first_hidden_size': 8,
                                                             'second_hidden_size': 4, 'output_size': 2,
                                                             'activation_function': 'Softmax',
                                                             'activation_config': {'dim': 1}})
            self._model_state = config.pop('model_state', 'Train24k24k_Valid8k8k_Norm12ReLU8+4nodesSM1Adamlr1e-2MSE.pt')

            if not hasattr(ANN_Model, self._model):
                raise ValueError('Invalid model: "%s"' % self._model)
            self._model = getattr(ANN_Model, self._model)(self._model_config)

            # The trained weights are loaded lazily on the first call of `get_cells` (and only once)
            self._model_state_loaded = False

        def get_cells(self, projection):
            """Return the indices of the cells classified as troubled.

            Parameters
            ----------
            projection : ndarray
                Projection (rows: basis degrees, columns: cells); its first and last column are treated as ghost
                cells and replaced by periodic ones.

            Returns
            -------
            list of int
                Indices (relative to the non-ghost cells) whose rounded first model output equals 1.
            """
            # Reset ghost cells to adjust for stencil length (periodic wrap-around; also correct if no ghost
            # cells are needed, where slicing with `-0:` would select the whole array)
            num_ghost_cells = self._stencil_len//2
            projection = np.pad(projection[:, 1:-1], ((0, 0), (num_ghost_cells, num_ghost_cells)), mode='wrap')

            # Calculate input data depending on stencil length: one row per non-ghost cell
            input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
                projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1], self._stencil_len)
                for cell in range(num_ghost_cells, projection.shape[1]-num_ghost_cells)]))

            # Load the trained weights once instead of re-reading the file on every call
            if not getattr(self, '_model_state_loaded', False):
                self._model.load_state_dict(torch.load(self._model_state))
                self._model_state_loaded = True
            self._model.eval()

            # Evaluate troubled cell probabilities; gradients are not needed for inference
            with torch.no_grad():
                model_output = torch.round(self._model(input_data.float()))

            # Return troubled cells
            first_output = model_output[:, 0].tolist()
            return [cell for cell, value in enumerate(first_output) if value == 1]
    
    
    class WaveletDetector(TroubledCellDetector):
        """Base class for troubled cell detectors based on multiwavelet coefficients.

        Each pair of neighboring fine cells (2*i, 2*i+1) forms coarse cell i. Subclasses implement `_get_cells` to
        select troubled coarse cells from the multiwavelet coefficients.
        """

        def _check_colors(self):
            """Fill in default colors for the fine/coarse exact and approximate solution if not configured."""
            self._colors['fine_exact'] = self._colors.get('fine_exact', 'k-.')
            self._colors['fine_approx'] = self._colors.get('fine_approx', 'b-.')
            self._colors['coarse_exact'] = self._colors.get('coarse_exact', 'k-')
            self._colors['coarse_approx'] = self._colors.get('coarse_approx', 'y')

        def _reset(self, config):
            """Set the number of coarse cells and fetch the multiwavelet projection matrices from the basis."""
            super()._reset(config)

            # Set additional necessary parameter
            self._num_coarse_grid_cells = self._num_grid_cells//2
            self._wavelet_projection_left, self._wavelet_projection_right = self._basis.get_multiwavelet_projections()

        def get_cells(self, projection):
            """Return the troubled cells.

            The multiwavelet coefficients are calculated without the ghost cells (first and last column), while
            `_get_cells` receives the full `projection`.
            """
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1:-1])
            return self._get_cells(multiwavelet_coeffs, projection)

        def _calculate_wavelet_coeffs(self, projection):
            """Calculate the multiwavelet coefficients of every coarse cell.

            `projection` must not contain ghost cells. Returns an array with one row per degree and one column per
            coarse cell.
            """
            return self._combine_fine_cell_pairs(projection, self._wavelet_projection_left,
                                                 self._wavelet_projection_right)

        def _combine_fine_cell_pairs(self, projection, left_matrix, right_matrix):
            """Apply `left_matrix`/`right_matrix` to fine cells 2*i/2*i+1 and average them for every coarse cell i."""
            output_matrix = [0.5*(projection[:, 2*i] @ left_matrix + projection[:, 2*i+1] @ right_matrix)
                             for i in range(self._num_coarse_grid_cells)]
            return np.transpose(np.array(output_matrix))

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Select troubled cells from the multiwavelet coefficients; the base version flags none."""
            return []

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plot the detail coefficients in addition to the plots of the base class."""
            self._plot_details(projection)
            super().plot_results(projection, troubled_cell_history, time_history)

        def _plot_details(self, projection):
            """Plot the difference of fine and coarse approximation next to the multiwavelet details (figure 4)."""
            fine_mesh = self._mesh[2:-2]

            fine_projection = projection[:, 1:-1]
            coarse_projection = self._calculate_coarse_projection(projection)
            # Use the projection without ghost cells (as in `get_cells`); otherwise every coarse cell would pair a
            # ghost/neighbor cell and the coefficients would be shifted by one fine cell
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(fine_projection)
            basis = self._basis.get_basis_vector()
            wavelet = self._basis.get_wavelet_vector()

            # Evaluate the coarse approximation at x = -1/2 and x = 1/2 of every coarse cell
            # (presumably the centers of its two fine sub-cells on the reference interval [-1, 1])
            averaged_projection = [[coarse_projection[degree][cell] * basis[degree].subs(x, value)
                                    for cell in range(self._num_coarse_grid_cells)
                                    for value in [-0.5, 0.5]]
                                   for degree in range(self._polynomial_degree + 1)]

            # Evaluate the wavelet details at the same positions; the left value carries the sign
            # (-1)^(polynomial_degree + degree + 1)
            wavelet_projection = [[multiwavelet_coeffs[degree][cell] * wavelet[degree].subs(z, 0.5) * value
                                   for cell in range(self._num_coarse_grid_cells)
                                   for value in [(-1) ** (self._polynomial_degree + degree + 1), 1]]
                                  for degree in range(self._polynomial_degree + 1)]

            projected_coarse = np.sum(averaged_projection, axis=0)
            projected_fine = np.sum([fine_projection[degree] * basis[degree].subs(x, 0)
                                     for degree in range(self._polynomial_degree + 1)], axis=0)
            projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)

            plt.figure(4)
            plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
            plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
            plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
            plt.xlabel('X')
            plt.ylabel('Detail Coefficients')
            plt.title('Wavelet Coefficients')

        def _calculate_coarse_projection(self, projection):
            """Project the fine approximation onto the coarse mesh (ghost cells of `projection` are removed first).

            Returns an array with one row per degree and one column per coarse cell.
            """
            basis_projection_left, basis_projection_right = self._basis.get_basis_projections()

            # Remove ghost cells and calculate projection on coarse mesh
            return self._combine_fine_cell_pairs(projection[:, 1:-1], basis_projection_left, basis_projection_right)

        def _plot_mesh(self, projection):
            """Plot coarse and fine solutions plus error plots (figures 1-3); return the maximum fine error."""
            grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
            approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            self._plot_coarse_mesh(projection)
            self._plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'], self._colors['fine_approx'])
            plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', 'Approx (Fine)'])
            self._plot_semilog_error(grid, pointwise_error)
            self._plot_error(grid, exact, approx)

            return max_error

        def save_plots(self, name):
            """Save the base class plots and the detail coefficient plot (figure 4) as PDF files."""
            super().save_plots(name)

            # Set path for details plot files if not existing already
            os.makedirs(self._plot_dir + '/coeff_details', exist_ok=True)

            plt.figure(4)
            plt.savefig(self._plot_dir + '/coeff_details/' + name + '.pdf')

        def _plot_coarse_mesh(self, projection):
            """Plot the exact and approximate solution on the coarse mesh into figure 1."""
            coarse_cell_len = 2*self._cell_len
            # Centers of the coarse cells; computed directly instead of via a float-step `np.arange`, whose
            # number of points is sensitive to rounding of the end point
            coarse_mesh = self._left_bound + coarse_cell_len * (np.arange(self._num_coarse_grid_cells) + 0.5)

            coarse_projection = self._calculate_coarse_projection(projection)

            # Plot exact and approximate solutions for coarse mesh
            grid, exact = self._calculate_exact_solution(coarse_mesh, coarse_cell_len)
            approx = self._calculate_approximate_solution(coarse_projection, self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)
            self._plot_solution_and_approx(grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])
    
    
    class Boxplot(WaveletDetector):
        """Wavelet detector flagging cells whose degree-0 multiwavelet coefficient is an extreme outlier.

        The coarse cells are split into folds of `fold_len` cells. Within each fold, coefficients below
        Q1 - whisker_len*(Q3-Q1) or above Q3 + whisker_len*(Q3-Q1) mark troubled cells.
        """

        def _reset(self, config):
            """Pop 'fold_len' (default 16) and 'whisker_len' (default 3) from `config`."""
            super()._reset(config)

            # Unpack necessary configurations
            self._fold_len = config.pop('fold_len', 16)
            self._whisker_len = config.pop('whisker_len', 3)

        @staticmethod
        def _quartile(sorted_values, quarter):
            """Return the `quarter`-th quartile (1 or 3) of the ascending `sorted_values`.

            The value at the 0-based fractional position quarter*n/4 - 1 is linearly interpolated between its
            neighbors. The position is clamped to the valid index range so short folds cannot wrap around to the end
            of the list via negative indices.
            """
            last_index = len(sorted_values) - 1
            position = min(max(quarter*len(sorted_values)/4.0 - 1, 0.0), last_index)
            lower_index = int(position)
            upper_index = min(lower_index+1, last_index)
            balance_factor = position - lower_index
            return (1-balance_factor) * sorted_values[lower_index] + balance_factor * sorted_values[upper_index]

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Return the sorted indices of the coarse cells that are extreme outliers within their fold."""
            num_cells = self._num_coarse_grid_cells
            if num_cells == 0:
                return []

            # Pair each degree-0 coefficient with its cell index so the index survives sorting
            indexed_coeffs = [[multiwavelet_coeffs[0, i], i] for i in range(num_cells)]

            # Use a single fold if there are fewer cells than the configured fold length
            # (local variable, so the configuration is not altered)
            fold_len = min(self._fold_len, num_cells)
            num_folds = num_cells//fold_len
            troubled_cells = []

            for fold in range(num_folds):
                # The last fold absorbs the remaining cells so that no cell is left unchecked
                fold_end = num_cells if fold == num_folds-1 else (fold+1) * fold_len
                sorted_fold = sorted(indexed_coeffs[fold * fold_len:fold_end])
                sorted_values = [coeff for coeff, _ in sorted_fold]

                first_quartile = self._quartile(sorted_values, 1)
                third_quartile = self._quartile(sorted_values, 3)

                lower_bound = first_quartile - self._whisker_len * (third_quartile-first_quartile)
                upper_bound = third_quartile + self._whisker_len * (third_quartile-first_quartile)

                # Check for lower extreme outliers and add respective cells
                for coeff, cell in sorted_fold:
                    if coeff < lower_bound:
                        troubled_cells.append(cell)
                    else:
                        break

                # Check for upper extreme outliers and add respective cells
                for coeff, cell in reversed(sorted_fold):
                    if coeff > upper_bound:
                        troubled_cells.append(cell)
                    else:
                        break

            return sorted(troubled_cells)
    
    
    class Theoretical(WaveletDetector):
        """Wavelet detector flagging cells whose normalized multiwavelet coefficients exceed a cutoff."""

        def _reset(self, config):
            """Pop 'cutoff_factor' (default sqrt(2) * cell_len) from `config`."""
            super()._reset(config)

            # Unpack necessary configurations
            # NOTE(review): the original comment lists 2 or 3 as alternative values for the cutoff factor
            self._cutoff_factor = config.pop('cutoff_factor', np.sqrt(2) * self._cell_len)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Return the indices of the troubled coarse cells.

            `projection` is the full projection including one ghost cell on each side. The normalization uses
            sqrt(1/2) times the largest absolute degree-0 entry over all non-ghost cells, but at least sqrt(1/2).
            """
            # Consider every non-ghost cell; iterating only `num_coarse_grid_cells` fine cells would cover just the
            # left half of the domain
            interior_coeffs = projection[0][1:-1]
            max_avg = np.sqrt(0.5) * max(1, max((abs(value) for value in interior_coeffs), default=0))

            return [cell for cell in range(self._num_coarse_grid_cells)
                    if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg)]

        def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
            """Return whether the coarse cell `cell` is troubled.

            The largest absolute multiwavelet coefficient of the cell (over all degrees), divided by `max_avg`, is
            compared against cutoff_factor / (cell_len * num_coarse_grid_cells * 2).
            """
            max_value = max(abs(multiwavelet_coeffs[degree][cell])
                            for degree in range(self._polynomial_degree+1))/max_avg
            eps = self._cutoff_factor / (self._cell_len*self._num_coarse_grid_cells*2)

            return max_value > eps