Skip to content
Snippets Groups Projects
Select Git revision
  • cf02ef68fde70a82e41d5223a86eb15b4619e7e2
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

self_bleu.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Troubled_Cell_Detector.py 19.11 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Fix cell averages and reconstructions to create data with an x-point stencil
    TODO: Add comments to get_cells() for ArtificialNeuralNetwork -> Done
    
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import torch
    from sympy import Symbol
    
    import ANN_Model
    
    # Free symbolic variables used throughout this module: the polynomial basis
    # is evaluated via ``subs(x, ...)`` and the wavelets via ``subs(z, ...)``.
    x, z = Symbol('x'), Symbol('z')
    
    
    class TroubledCellDetector(object):
        """Base class for troubled-cell detectors.

        Stores the discretization parameters, evaluates exact and approximate
        solutions for plotting, and saves the resulting figures. Subclasses
        implement ``get_cells`` and may extend ``_reset`` / ``_check_colors``.
        """

        def __init__(self, config, mesh, wave_speed, polynomial_degree, num_grid_cells, final_time, left_bound, right_bound,
                     basis, init_cond, quadrature):
            """Store the discretization and read the common options.

            Parameters
            ----------
            config : dict
                Detector options. ``'plot_dir'`` (default ``'fig'``) and
                ``'colors'`` (default ``{}``) are read here; the remaining entries
                are passed on to ``_reset``. The dict and its ``'colors'`` entry
                are copied, so the caller's configuration is never modified.
            mesh : sequence of float
                Mesh positions; plotting uses ``mesh[2:-2]`` (presumably cell
                centers without ghost cells -- TODO confirm).
            wave_speed : float
                Speed by which the initial condition is shifted for the exact solution.
            polynomial_degree : int
                Degree of the polynomial basis.
            num_grid_cells : int
                Number of cells of the (fine) mesh.
            final_time : float
                Time at which the exact solution is evaluated.
            left_bound, right_bound : float
                Boundaries of the domain.
            basis : object
                Provides ``get_basis_vector()`` (symbolic in ``x``); wavelet
                detectors also use further projection helpers from it.
            init_cond : object
                Provides ``calculate(point)``.
            quadrature : object
                Provides ``get_eval_points()``, scaled by half a cell length per cell.
            """
            # Work on copies: options are popped here and in the subclass ``_reset``
            # hooks, which would otherwise strip settings from a config dict that
            # the caller reuses for further detectors.
            config = dict(config)

            self._mesh = mesh
            self._wave_speed = wave_speed
            self._polynomial_degree = polynomial_degree
            self._num_grid_cells = num_grid_cells
            self._final_time = final_time
            self._left_bound = left_bound
            self._right_bound = right_bound
            self._interval_len = right_bound - left_bound
            self._cell_len = self._interval_len / num_grid_cells
            self._basis = basis
            self._init_cond = init_cond
            self._quadrature = quadrature

            # Set parameters from config if existing
            self._plot_dir = config.pop('plot_dir', 'fig')
            # Copied because _check_colors() fills in defaults in place.
            self._colors = dict(config.pop('colors', {}))

            self._check_colors()
            self._reset(config)

        def _check_colors(self):
            """Fill in default line styles for the exact and approximate solution."""
            self._colors.setdefault('exact', 'k-')
            self._colors.setdefault('approx', 'y')

        def _reset(self, config):
            """Hook for subclass-specific options; the base only sets the seaborn style."""
            sns.set()

        def get_name(self):
            """Return the name of the concrete detector class."""
            return self.__class__.__name__

        def get_cells(self, projection):
            """Return the troubled cells; overridden by subclasses (the base returns None)."""
            pass

        def calculate_cell_average_and_reconstructions(self, projection):
            """Build the network input features for one stencil of cells.

            Parameters
            ----------
            projection : numpy.ndarray
                Coefficients of the stencil cells, indexed ``projection[degree][cell]``.
                The stencil is expected to have an odd number of cells centered on
                the cell of interest.

            Returns
            -------
            numpy.ndarray
                Float64 array of shape ``(1, num_cells + 2)``: the averages of all
                stencil cells, with the left and right reconstructions of the
                middle cell inserted directly before and after its average. For a
                3-cell stencil this is ``[avg_0, left_1, avg_1, right_1, avg_2]``.
            """
            cell_averages = self._calculate_approximate_solution(projection, [0], 0)
            left_reconstructions = self._calculate_approximate_solution(projection, [-1], self._polynomial_degree)
            right_reconstructions = self._calculate_approximate_solution(projection, [1], self._polynomial_degree)

            # Each array has shape (1, num_cells) because one point is evaluated per cell.
            middle = len(projection[0]) // 2
            features = np.concatenate((cell_averages[:, :middle],
                                       left_reconstructions[:, middle:middle + 1],
                                       cell_averages[:, middle:middle + 1],
                                       right_reconstructions[:, middle:middle + 1],
                                       cell_averages[:, middle + 1:]), axis=1)
            # Entries may be symbolic numbers (basis values come from ``subs``), so
            # convert explicitly to float64.
            return features.astype(np.float64)

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Create all result plots and print the maximum pointwise error."""
            self._plot_shock_tube(troubled_cell_history, time_history)
            max_error = self._plot_mesh(projection)

            print('p =', self._polynomial_degree)
            print('N =', self._num_grid_cells)
            print('maximum error =', max_error)

        def _plot_shock_tube(self, troubled_cell_history, time_history):
            """Plot the troubled cells of every recorded time step (figure 6)."""
            plt.figure(6)
            for pos, time in enumerate(time_history):
                for cell in troubled_cell_history[pos]:
                    plt.plot(cell, time, 'k.')
            # NOTE(review): the x-range spans num_grid_cells//2 cells, which matches
            # coarse-cell indices of the wavelet detectors; larger fine-cell indices
            # (e.g. from ArtificialNeuralNetwork) fall outside the plot -- confirm.
            plt.xlim((0, self._num_grid_cells//2))
            plt.xlabel('Cell')
            plt.ylabel('Time')
            plt.title('Shock Tubes')

        def _plot_mesh(self, projection):
            """Plot solution, approximation and errors (figures 1-3); return the max pointwise error."""
            grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
            approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            self._plot_solution_and_approx(grid, exact, approx, self._colors['exact'], self._colors['approx'])
            plt.legend(['Exact', 'Approx'])
            self._plot_semilog_error(grid, pointwise_error)
            self._plot_error(grid, exact, approx)

            return max_error

        @staticmethod
        def _plot_solution_and_approx(grid, exact, approx, color_exact, color_approx):
            """Plot exact and approximate solution on figure 1."""
            plt.figure(1)
            plt.plot(grid[0], exact[0], color_exact)
            plt.plot(grid[0], approx[0], color_approx)
            plt.xlabel('x')
            plt.ylabel('u(x,t)')
            plt.title('Solution and Approximation')

        @staticmethod
        def _plot_semilog_error(grid, pointwise_error):
            """Plot the pointwise error on a logarithmic scale (figure 2)."""
            plt.figure(2)
            plt.semilogy(grid[0], pointwise_error[0])
            plt.xlabel('x')
            plt.ylabel('|u(x,t)-uh(x,t)|')
            plt.title('Semilog Error plotted at Evaluation points')

        @staticmethod
        def _plot_error(grid, exact, approx):
            """Plot the signed error exact - approx (figure 3)."""
            plt.figure(3)
            plt.plot(grid[0], exact[0]-approx[0])
            plt.xlabel('X')
            plt.ylabel('u(x,t)-uh(x,t)')
            plt.title('Errors')

        def _calculate_exact_solution(self, mesh, cell_len):
            """Evaluate the exact solution at the quadrature points of every cell.

            The initial condition is shifted by ``wave_speed * final_time`` and
            moved back by whole domain lengths.

            Returns
            -------
            tuple of numpy.ndarray
                ``(grid, exact)``, each of shape ``(1, len(mesh) * num_eval_points)``
                in cell-major order (``(1, 0)`` for an empty mesh).
            """
            grid = []
            exact = []
            num_periods = np.floor(self._wave_speed * self._final_time / self._interval_len)
            reference_points = self._quadrature.get_eval_points()

            for cell_position in mesh:
                eval_points = cell_position + cell_len/2 * reference_points
                eval_values = [self._init_cond.calculate(point - self._wave_speed*self._final_time
                                                         + num_periods*self._interval_len)
                               for point in eval_points]

                grid.append(eval_points)
                exact.append(eval_values)

            # reshape(1, -1) also handles an empty mesh (exact[0] would not exist).
            exact = np.reshape(np.array(exact), (1, -1))
            grid = np.reshape(np.array(grid), (1, -1))

            return grid, exact

        def _calculate_approximate_solution(self, projection, points, polynomial_degree):
            """Evaluate the polynomial approximation at reference ``points`` of every cell.

            ``projection[degree][cell]`` holds the coefficients; only degrees up to
            ``polynomial_degree`` are used. Returns an array of shape
            ``(1, num_cells * len(points))`` in cell-major order.
            """
            num_points = len(points)
            basis = self._basis.get_basis_vector()

            basis_matrix = [[basis[degree].subs(x, points[point]) for point in range(num_points)]
                            for degree in range(polynomial_degree+1)]

            approx = [[sum(projection[degree][cell] * basis_matrix[degree][point]
                           for degree in range(polynomial_degree+1))
                       for point in range(num_points)]
                      for cell in range(len(projection[0]))]

            return np.reshape(np.array(approx), (1, len(approx) * num_points))

        def save_plots(self, name):
            """Save figures 1, 2, 3 and 6 as ``<plot_dir>/<sub_dir>/<name>.pdf``."""
            targets = ((1, 'exact_and_approx'), (2, 'semilog_error'), (3, 'error'), (6, 'shock_tube'))

            # Set paths for plot files if not existing already (parents are created too)
            for _, sub_dir in targets:
                os.makedirs(self._plot_dir + '/' + sub_dir, exist_ok=True)

            # Save plots
            for figure_number, sub_dir in targets:
                plt.figure(figure_number)
                plt.savefig(self._plot_dir + '/' + sub_dir + '/' + name + '.pdf')
    
    
    class NoDetection(TroubledCellDetector):
        """Detector that never reports any troubled cell."""

        def get_cells(self, projection):
            # The projection is not inspected at all; the result is always empty.
            return list()
    
    
    class ArtificialNeuralNetwork(TroubledCellDetector):
        """Troubled-cell detector using a trained network class from ``ANN_Model``."""

        def _reset(self, config):
            """Read the network options from ``config`` and instantiate the model.

            Recognized keys (all optional): ``'stencil_len'`` (default 3),
            ``'model'`` (name of a class in ``ANN_Model``, default
            ``'ThreeLayerReLu'``), ``'model_config'`` (passed to the model
            constructor) and ``'model_state'`` (file loaded with ``torch.load``
            and applied via ``load_state_dict``).

            Raises
            ------
            ValueError
                If ``ANN_Model`` has no attribute with the requested model name.
            """
            super()._reset(config)

            self._stencil_len = config.pop('stencil_len', 3)
            self._model = config.pop('model', 'ThreeLayerReLu')
            self._model_config = config.pop('model_config', {'input_size': self._stencil_len+2, 'first_hidden_size': 8,
                                                             'second_hidden_size': 4, 'output_size': 2,
                                                             'activation_function': 'Softmax',
                                                             'activation_config': {'dim': 1}})
            self._model_state = config.pop('model_state', 'Train24k24k_Valid8k8k_Norm12ReLU8+4nodesSM1Adamlr1e-2MSE.pt')
            # The trained weights are loaded on the first call of get_cells().
            self._model_state_loaded = False

            if not hasattr(ANN_Model, self._model):
                raise ValueError('Invalid model: "%s"' % self._model)
            self._model = getattr(ANN_Model, self._model)(self._model_config)

        def get_cells(self, projection):
            """Return the indices of the cells the network classifies as troubled.

            Parameters
            ----------
            projection : numpy.ndarray
                Coefficients ``projection[degree][cell]`` with one ghost cell on each side.

            Returns
            -------
            list of int
                0-based indices (ghost cells excluded) of the cells whose rounded
                first network output equals 1.
            """
            # Replace the single ghost cell on each side by as many periodically
            # wrapped ghost cells as the stencil needs.
            num_ghost_cells = self._stencil_len//2
            projection = projection[:, 1:-1]
            num_cells = projection.shape[1]
            # Explicit start index: ``projection[:, -0:]`` would select the whole
            # array instead of nothing when num_ghost_cells is 0.
            left_ghosts = projection[:, max(num_cells - num_ghost_cells, 0):]
            right_ghosts = projection[:, :num_ghost_cells]
            projection = np.concatenate((left_ghosts, projection, right_ghosts), axis=1)

            # One feature row per cell, computed from the stencil centered on it.
            input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
                projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1])
                for cell in range(num_ghost_cells, projection.shape[1]-num_ghost_cells)]))

            # Load the trained weights once instead of re-reading the file on every call.
            # NOTE(review): torch.load unpickles the file; only use trusted model files.
            if not self._model_state_loaded:
                self._model.load_state_dict(torch.load(self._model_state))
                self._model_state_loaded = True
            self._model.eval()

            # Pure inference: no gradient bookkeeping needed.
            with torch.no_grad():
                model_output = torch.round(self._model(input_data.float()))

            # Return troubled cells
            is_troubled = (model_output[:, 0] == 1).tolist()
            return [cell for cell, flag in enumerate(is_troubled) if flag]
    
    
    class WaveletDetector(TroubledCellDetector):
        """Base class for detectors working on multiwavelet (detail) coefficients.

        Pairs of neighboring fine cells are combined into one coarse cell. The
        resulting multiwavelet coefficients are handed to ``_get_cells``, which
        subclasses override. Coarse-mesh and detail-coefficient plots are added.
        """

        def _check_colors(self):
            """Fill in default line styles for fine and coarse exact/approximate plots."""
            self._colors.setdefault('fine_exact', 'k-.')
            self._colors.setdefault('fine_approx', 'b-.')
            self._colors.setdefault('coarse_exact', 'k-')
            self._colors.setdefault('coarse_approx', 'y')

        def _reset(self, config):
            """Set the number of coarse cells and fetch the multiwavelet projection matrices."""
            super()._reset(config)

            # Set additional necessary parameter
            self._num_coarse_grid_cells = self._num_grid_cells//2
            self._wavelet_projection_left, self._wavelet_projection_right = self._basis.get_multiwavelet_projections()

        def get_cells(self, projection):
            """Return troubled cells (coarse indices) based on the multiwavelet coefficients."""
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(projection[:, 1: -1])
            return self._get_cells(multiwavelet_coeffs, projection)

        def _calculate_wavelet_coeffs(self, projection):
            """Compute the multiwavelet coefficients of every coarse cell.

            ``projection`` must not contain ghost cells: fine cells ``2*i`` and
            ``2*i + 1`` form coarse cell ``i``. Returns the coefficients transposed
            so that they are indexed ``[degree][coarse_cell]`` (assuming square
            projection matrices -- TODO confirm).
            """
            output_matrix = []
            for i in range(self._num_coarse_grid_cells):
                new_entry = 0.5*(projection[:, 2*i] @ self._wavelet_projection_left
                                 + projection[:, 2*i+1] @ self._wavelet_projection_right)
                output_matrix.append(new_entry)
            return np.transpose(np.array(output_matrix))

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Detection rule; the base implementation reports no cells."""
            return []

        def plot_results(self, projection, troubled_cell_history, time_history):
            """Plot the detail coefficients in addition to the common result plots."""
            self._plot_details(projection)
            super().plot_results(projection, troubled_cell_history, time_history)

        def _plot_details(self, projection):
            """Compare fine-minus-coarse values with the wavelet contribution (figure 4)."""
            fine_mesh = self._mesh[2:-2]

            fine_projection = projection[:, 1:-1]
            coarse_projection = self._calculate_coarse_projection(projection)
            # Use the ghost-free projection, exactly as get_cells() does; including
            # the ghost cells would pair the wrong fine cells into each coarse cell.
            multiwavelet_coeffs = self._calculate_wavelet_coeffs(fine_projection)
            basis = self._basis.get_basis_vector()
            wavelet = self._basis.get_wavelet_vector()

            averaged_projection = []
            wavelet_projection = []
            for degree in range(self._polynomial_degree + 1):
                # Symbolic evaluations are identical for every cell, so do them once per degree.
                left_basis_value = basis[degree].subs(x, -0.5)
                right_basis_value = basis[degree].subs(x, 0.5)
                averaged_projection.append([coarse_projection[degree][cell] * basis_value
                                            for cell in range(self._num_coarse_grid_cells)
                                            for basis_value in (left_basis_value, right_basis_value)])

                wavelet_value = wavelet[degree].subs(z, 0.5)
                left_sign = (-1) ** (self._polynomial_degree + degree + 1)
                wavelet_projection.append([multiwavelet_coeffs[degree][cell] * wavelet_value * sign
                                           for cell in range(self._num_coarse_grid_cells)
                                           for sign in (left_sign, 1)])

            projected_coarse = np.sum(averaged_projection, axis=0)
            projected_fine = np.sum([fine_projection[degree] * basis[degree].subs(x, 0)
                                     for degree in range(self._polynomial_degree + 1)], axis=0)
            projected_wavelet_coeffs = np.sum(wavelet_projection, axis=0)

            plt.figure(4)
            plt.plot(fine_mesh, projected_fine - projected_coarse, 'm-.')
            plt.plot(fine_mesh, projected_wavelet_coeffs, 'y')
            plt.legend(['Fine-Coarse', 'Wavelet Coeff'])
            plt.xlabel('X')
            plt.ylabel('Detail Coefficients')
            plt.title('Wavelet Coefficients')

        def _calculate_coarse_projection(self, projection):
            """Project a fine-mesh projection (with ghost cells) onto the coarse mesh.

            Returns the coarse coefficients indexed ``[degree][coarse_cell]``.
            """
            basis_projection_left, basis_projection_right = self._basis.get_basis_projections()

            # Remove ghost cells
            projection = projection[:, 1:-1]

            # Calculate projection on coarse mesh
            output_matrix = []
            for i in range(self._num_coarse_grid_cells):
                new_entry = 0.5 * (projection[:, 2 * i] @ basis_projection_left
                                   + projection[:, 2 * i + 1] @ basis_projection_right)
                output_matrix.append(new_entry)
            coarse_projection = np.transpose(np.array(output_matrix))

            return coarse_projection

        def _plot_mesh(self, projection):
            """Plot coarse and fine solutions plus errors; return the max fine-mesh error."""
            grid, exact = self._calculate_exact_solution(self._mesh[2:-2], self._cell_len)
            approx = self._calculate_approximate_solution(projection[:, 1:-1], self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)

            pointwise_error = np.abs(exact-approx)
            max_error = np.max(pointwise_error)

            self._plot_coarse_mesh(projection)
            self._plot_solution_and_approx(grid, exact, approx, self._colors['fine_exact'], self._colors['fine_approx'])
            plt.legend(['Exact (Coarse)', 'Approx (Coarse)', 'Exact (Fine)', 'Approx (Fine)'])
            self._plot_semilog_error(grid, pointwise_error)
            self._plot_error(grid, exact, approx)

            return max_error

        def save_plots(self, name):
            """Save the common plots plus the detail-coefficient plot (figure 4)."""
            super().save_plots(name)

            # Set path for details plot files if not existing already
            os.makedirs(self._plot_dir + '/coeff_details', exist_ok=True)

            plt.figure(4)
            plt.savefig(self._plot_dir + '/coeff_details/' + name + '.pdf')

        def _plot_coarse_mesh(self, projection):
            """Plot exact and approximate solution on the coarse mesh (figure 1)."""
            coarse_cell_len = 2*self._cell_len
            coarse_mesh = np.arange(self._left_bound - (0.5*coarse_cell_len), self._right_bound + (1.5*coarse_cell_len),
                                    coarse_cell_len)

            coarse_projection = self._calculate_coarse_projection(projection)

            # Plot exact and approximate solutions for coarse mesh
            grid, exact = self._calculate_exact_solution(coarse_mesh[1:-1], coarse_cell_len)
            approx = self._calculate_approximate_solution(coarse_projection, self._quadrature.get_eval_points(),
                                                          self._polynomial_degree)
            self._plot_solution_and_approx(grid, exact, approx, self._colors['coarse_exact'], self._colors['coarse_approx'])
    
    
    class Boxplot(WaveletDetector):
        """Wavelet detector flagging extreme outliers of the degree-0 wavelet coefficients."""

        def _reset(self, config):
            """Read ``'fold_len'`` (default 16) and ``'whisker_len'`` (default 3) from ``config``."""
            super()._reset(config)

            # Unpack necessary configurations
            self._fold_len = config.pop('fold_len', 16)
            self._whisker_len = config.pop('whisker_len', 3)

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Return the coarse cells whose coefficient lies outside the boxplot whiskers.

            The coarse cells are split into consecutive folds of ``fold_len`` cells
            (at most the number of coarse cells). Within each fold, the first and
            third quartiles are interpolated from the sorted coefficients, and every
            coefficient below ``Q1 - whisker_len*(Q3-Q1)`` or above
            ``Q3 + whisker_len*(Q3-Q1)`` is reported.

            Returns
            -------
            list of int
                Sorted coarse-cell indices.
            """
            if self._num_coarse_grid_cells == 0:
                # Nothing to inspect; a fold length of 0 would divide by zero below.
                return []

            indexed_coeffs = [[multiwavelet_coeffs[0, i], i] for i in range(self._num_coarse_grid_cells)]

            # A fold cannot be longer than the mesh (local copy: keep the configured value).
            fold_len = min(self._fold_len, self._num_coarse_grid_cells)
            # NOTE(review): cells after the last complete fold are never checked when the
            # number of coarse cells is not a multiple of fold_len, and for fold_len < 4
            # the quartile indices below become -1 and wrap to the end of the fold.
            num_folds = self._num_coarse_grid_cells // fold_len

            # Interpolation position of the quartiles within a sorted fold.
            boundary_index = fold_len // 4
            balance_factor = fold_len / 4.0 - boundary_index

            troubled_cells = []
            for fold in range(num_folds):
                sorted_fold = sorted(indexed_coeffs[fold * fold_len:(fold+1) * fold_len])

                first_quartile = (1-balance_factor) * sorted_fold[boundary_index-1][0] \
                    + balance_factor * sorted_fold[boundary_index][0]
                third_quartile = (1-balance_factor) * sorted_fold[3*boundary_index-1][0] \
                    + balance_factor * sorted_fold[3*boundary_index][0]

                whisker = self._whisker_len * (third_quartile-first_quartile)
                lower_bound = first_quartile - whisker
                upper_bound = third_quartile + whisker

                # Check for lower extreme outliers (smallest values first) and add respective cells
                for coeff, cell in sorted_fold:
                    if coeff < lower_bound:
                        troubled_cells.append(cell)
                    else:
                        break

                # Check for upper extreme outliers (largest values first) and add respective cells
                for coeff, cell in reversed(sorted_fold):
                    if coeff > upper_bound:
                        troubled_cells.append(cell)
                    else:
                        break

            return sorted(troubled_cells)
    
    
    class Theoretical(WaveletDetector):
        """Wavelet detector comparing normalized wavelet coefficients with a fixed threshold."""

        def _reset(self, config):
            """Read ``'cutoff_factor'`` (default ``sqrt(2) * cell_len``) from ``config``."""
            super()._reset(config)

            # Unpack necessary configurations
            self._cutoff_factor = config.pop('cutoff_factor', np.sqrt(2) * self._cell_len)
            # Alternative cutoff factors noted by the authors: 2 or 3.

        def _get_cells(self, multiwavelet_coeffs, projection):
            """Return the coarse cells flagged by ``_is_troubled_cell``.

            Parameters
            ----------
            multiwavelet_coeffs : array-like
                Wavelet coefficients indexed ``[degree][coarse_cell]``.
            projection : numpy.ndarray
                Fine-mesh coefficients including one ghost cell on each side.
            """
            # Normalization: largest absolute degree-0 coefficient, but at least 1.
            # NOTE(review): only the first num_coarse_grid_cells fine cells (after the
            # left ghost cell) are scanned, not the whole fine mesh -- confirm intent.
            # default=0 avoids a ValueError from max() when there are no coarse cells.
            largest_coefficient = max((abs(projection[0][cell+1]) for cell in range(self._num_coarse_grid_cells)),
                                      default=0)
            max_avg = np.sqrt(0.5) * max(1, largest_coefficient)

            return [cell for cell in range(self._num_coarse_grid_cells)
                    if self._is_troubled_cell(multiwavelet_coeffs, cell, max_avg)]

        def _is_troubled_cell(self, multiwavelet_coeffs, cell, max_avg):
            """Return True if the scaled largest coefficient of ``cell`` exceeds the threshold.

            The largest absolute coefficient over all degrees, divided by
            ``max_avg``, is compared with
            ``cutoff_factor / (cell_len * num_coarse_grid_cells * 2)``.
            """
            max_value = max(abs(multiwavelet_coeffs[degree][cell])
                            for degree in range(self._polynomial_degree+1))/max_avg
            eps = self._cutoff_factor / (self._cell_len*self._num_coarse_grid_cells*2)

            return max_value > eps