DG_Approximation.py
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle
    
    Discussion:
    TODO: Contemplate saving 5-CV split and evaluating models separately
    TODO: Contemplate separating cell average and reconstruction calculations
        completely
    TODO: Contemplate removing Methods section from class docstring
    TODO: Contemplate containing the quadrature application for plots in Mesh
    TODO: Contemplate containing coarse mesh generation in Mesh
    TODO: Contemplate extracting boundary condition from InitialCondition
    TODO: Contemplate containing boundary condition in Mesh
    TODO: Ask whether all quadratures depend on freely chosen num_nodes
    TODO: Contemplate saving training data for each IC separately
    TODO: Contemplate removing TrainingDataGenerator class
    TODO: Ask why no fine mesh is used for calculate_coarse_projection()
    TODO: Discuss adding kwargs to attributes in documentation
    TODO: Discuss descriptions (matrices, cfl number, right-hand side,
        limiting slope, basis, wavelet, etc.)
    TODO: Discuss referencing info on SSPRK3
    
    Urgent:
    TODO: Fix indexing issue in wavelet coefficient calculation
    TODO: Enforce mesh with 2^n cells
    TODO: Enforce even number of ghost cells on each side on fine mesh (?)
    TODO: Change heaviside random to uniform(-100, 100)
    TODO: Adjust Heaviside to have non-symmetric values (left_value and
        right_value)
    TODO: Rename 'adjustment' to 'shift'
    TODO: Induce shift in IC class
    TODO: Give option to set discontinuity to cell boundary
    TODO: Fix typo in LinearAbsolute
    TODO: Introduce middle_factor for two-sided Heaviside
    TODO: Move plot_approximation_results() into plotting script
    TODO: Improve file naming (e.g. use '.' instead of '__')
    TODO: Check whether ghost cells are handled/set correctly
    TODO: Unify use of 'length' and 'len' in naming
    TODO: Unify use of 'initial_condition' and 'init_cond' in naming
    TODO: Unify use of 'mesh' and 'grid' in naming
    TODO: Check sign change in stiffness matrix to accommodate negative wave speed
    TODO: Check whether 'projection' is always a ndarray
    TODO: Force input_size for each ANN model to be stencil length
    
    Critical, but not urgent:
    TODO: Introduce env files for each SM rule
    TODO: Add images to report
    TODO: Add verbose output
    TODO: Outsource scripts into separate directory
    TODO: Check whether all instance variables are sensible
    TODO: Combine ANN workflows if feasible
    TODO: Investigate profiling for speed up
    TODO: Rework Theoretical TC for efficiency
    TODO: Extract object initialization from DGScheme
    TODO: Replace loops with list comprehension if feasible
    TODO: Replace loops/list comprehension with vectorization if feasible
    
    Currently not critical:
    TODO: Give option to select plotting color
    TODO: Add an environment file for Snakemake
    TODO: Build package (module?) for DG scheme
    TODO: Investigate g-mesh(?)
    TODO: Create g-mesh with Mesh class
    TODO: Rename files according to standard
    TODO: Allow comparison between ANN training datasets
    TODO: Use full path for ANN model state
    TODO: Add a default model state
    TODO: Look into validators for variable checks
    
    Not feasible yet or doc-related:
    TODO: Add README for ANN training
    TODO: Give detailed description of wavelet detection
    TODO: Use cfl_number for updating, not just time (equation-related?)
    TODO: Adjust code to allow classes for all equations
        (Burgers, linear advection, 1D Euler)
    TODO: Add ThresholdDetector
    TODO: Double-check everything! (also with pylint, pytype, pydoc,
        pycodestyle, pydocstyle)
    TODO: Check whether documentation style is correct
    TODO: Check whether all types in doc are correct
    TODO: Add type annotations to function heads
    TODO: Clean up docstrings
    
    """
    import json
    import numpy as np
    from sympy import Symbol
    import math
    import seaborn as sns
    
    import Troubled_Cell_Detector
    import Initial_Condition
    import Limiter
    import Quadrature
    import Update_Scheme
    from Basis_Function import OrthonormalLegendre
    from encoding_utils import encode_ndarray
    from projection_utils import Mesh
    
    x = Symbol('x')
    sns.set()
    
    
    class DGScheme:
        """Class for Discontinuous Galerkin Method.
    
        Approximates the linear advection equation using the Discontinuous
        Galerkin Method with troubled-cell-based limiting.
    
        Attributes
        ----------
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.
    
        Methods
        -------
        approximate()
            Approximates projection.
        save_plots()
            Saves plots generated during approximation process.
        build_training_data(adjustment, stencil_length, initial_condition=None)
            Builds training data set.
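
        Examples
        --------
        Hypothetical usage sketch (the detector name 'Boxplot' is a
        placeholder for any class defined in Troubled_Cell_Detector):

        >>> scheme = DGScheme('Boxplot', init_cond='Sine', num_grid_cells=64,
        ...                   final_time=1)
        >>> scheme.approximate('example_run')  # writes 'example_run.json'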
    
        """
        def __init__(self, detector, **kwargs):
            """Initializes DGScheme.
    
            Parameters
            ----------
            detector : str
                Name of troubled cell detector class.
    
            Other Parameters
            ----------------
            wave_speed : float, optional
                Speed of wave in rightward direction. Default: 1.
            polynomial_degree : int, optional
                Polynomial degree. Default: 2.
            cfl_number : float, optional
                CFL number to ensure stability. Default: 0.2.
            num_grid_cells : int, optional
                Number of cells in the mesh. Usually a power of 2. Default: 64.
            final_time : float, optional
                Final time for which approximation is calculated. Default: 1.
            left_bound : float, optional
                Left boundary of interval. Default: -1.
            right_bound : float, optional
                Right boundary of interval. Default: 1.
            verbose : bool, optional
                Flag whether console output is wanted. Default: False.
            history_threshold : float, optional
                Number of iterations between recordings of the
                troubled-cell history. Default: math.ceil(0.2/cfl_number).
            detector_config : dict, optional
                Additional parameters for detector object. Default: {}.
            init_cond : str, optional
                Name of initial condition for evaluation. Default: 'Sine'.
            init_config : dict, optional
                Additional parameters for initial condition object. Default: {}.
            limiter : str, optional
                Name of limiter for evaluation. Default: 'ModifiedMinMod'.
            limiter_config : dict, optional
                Additional parameters for limiter object. Default: {}.
            quadrature : str, optional
                Name of quadrature for evaluation. Default: 'Gauss'.
            quadrature_config : dict, optional
                Additional parameters for quadrature object. Default: {}.
            update_scheme : str, optional
                Name of update scheme for evaluation. Default: 'SSPRK3'.
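
            Notes
            -----
            Unrecognized keyword arguments raise a ValueError. The config
            dictionaries are passed through to the respective classes, so
            their keys depend on the chosen class. Hypothetical sketch (the
            detector name and the config key are placeholders):

            >>> scheme = DGScheme('Boxplot', quadrature='Gauss',
            ...                   quadrature_config={'num_nodes': 6})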
    
            """
            # Unpack keyword arguments
            self._wave_speed = kwargs.pop('wave_speed', 1)
            self._cfl_number = kwargs.pop('cfl_number', 0.2)
            self._final_time = kwargs.pop('final_time', 1)
            self._verbose = kwargs.pop('verbose', False)
            self._history_threshold = kwargs.pop('history_threshold',
                                                 math.ceil(0.2/self._cfl_number))
            self._detector = detector
            self._detector_config = kwargs.pop('detector_config', {})
            self._init_cond = kwargs.pop('init_cond', 'Sine')
            self._init_config = kwargs.pop('init_config', {})
            self._limiter = kwargs.pop('limiter', 'ModifiedMinMod')
            self._limiter_config = kwargs.pop('limiter_config', {})
            self._quadrature = kwargs.pop('quadrature', 'Gauss')
            self._quadrature_config = kwargs.pop('quadrature_config', {})
            self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3')
            self._basis = OrthonormalLegendre(kwargs.pop('polynomial_degree', 2))
    
            # Initialize mesh with two ghost cells on each side
            self._mesh = Mesh(num_grid_cells=kwargs.pop('num_grid_cells', 64),
                              left_bound=kwargs.pop('left_bound', -1),
                              right_bound=kwargs.pop('right_bound', 1),
                              num_ghost_cells=2)
            # print(len(self._mesh.cells))
            # print(type(self._mesh.cells))
    
            # Throw an error if there are extra keyword arguments
            if kwargs:
                extra = ', '.join('"%s"' % k for k in kwargs)
                raise ValueError('Unrecognized arguments: %s' % extra)
    
            # Make sure all classes actually exist
            if not hasattr(Troubled_Cell_Detector, self._detector):
                raise ValueError('Invalid detector: "%s"' % self._detector)
            if not hasattr(Initial_Condition, self._init_cond):
                raise ValueError('Invalid initial condition: "%s"'
                                 % self._init_cond)
            if not hasattr(Limiter, self._limiter):
                raise ValueError('Invalid limiter: "%s"' % self._limiter)
            if not hasattr(Quadrature, self._quadrature):
                raise ValueError('Invalid quadrature: "%s"' % self._quadrature)
            if not hasattr(Update_Scheme, self._update_scheme):
                raise ValueError('Invalid update scheme: "%s"'
                                 % self._update_scheme)
    
            self._reset()
    
        # Replace the string names with the actual class instances,
        # constructed with their respective configurations
            self._init_cond = getattr(Initial_Condition, self._init_cond)(
                config=self._init_config)
            self._limiter = getattr(Limiter, self._limiter)(
                config=self._limiter_config)
            self._quadrature = getattr(Quadrature, self._quadrature)(
                config=self._quadrature_config)
            self._detector = getattr(Troubled_Cell_Detector, self._detector)(
                config=self._detector_config, mesh=self._mesh, basis=self._basis)
            self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
                polynomial_degree=self._basis.polynomial_degree,
                num_grid_cells=self._mesh.num_grid_cells, detector=self._detector,
                limiter=self._limiter)
    
        def approximate(self, data_file):
            """Approximates projection.
    
            Initializes projection and evolves it in time. Each time step consists
            of three parts: A projection update, a troubled-cell detection,
            and limiting based on the detected cells.
    
            At the final time, the results are saved to a JSON file.

            Parameters
            ----------
            data_file : str
                Path to the file in which the data will be saved ('.json' is
                appended).
    
            """
            projection = do_initial_projection(
                initial_condition=self._init_cond, mesh=self._mesh,
                basis=self._basis, quadrature=self._quadrature)
    
            # Time step from the CFL condition for linear advection:
            # time_step = cfl_number * cell_len / |wave_speed|
            time_step = abs(self._cfl_number * self._mesh.cell_len /
                            self._wave_speed)
    
            current_time = 0
            iteration = 0
            troubled_cell_history = []
            time_history = []
            while current_time < self._final_time:
                # Shorten the last time step to end exactly at final_time
                cfl_number = self._cfl_number
                if current_time+time_step > self._final_time:
                    time_step = self._final_time-current_time
                    cfl_number = self._wave_speed * time_step / self._mesh.cell_len
    
                # Update projection
                projection, troubled_cells = self._update_scheme.step(projection,
                                                                      cfl_number)
                iteration += 1
    
                # Record troubled cells every history_threshold-th iteration
                if (iteration % self._history_threshold) == 0:
                    troubled_cell_history.append(troubled_cells)
                    time_history.append(current_time)
    
                current_time += time_step
    
            # Save detector-specific data in dictionary
            approx_stats = self._detector.create_data_dict(projection)
    
            # Save approximation results in dictionary
            approx_stats['wave_speed'] = self._wave_speed
            approx_stats['final_time'] = self._final_time
            approx_stats['time_history'] = time_history
            approx_stats['troubled_cell_history'] = troubled_cell_history
    
            # Encode all ndarrays to fit JSON format
            approx_stats = {key: encode_ndarray(approx_stats[key])
                            for key in approx_stats.keys()}
    
            # Save approximation results in JSON format
            with open(data_file + '.json', 'w') as json_file:
                json_file.write(json.dumps(approx_stats))
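
            # The file written above can be read back with the standard json
            # module; ndarray-valued entries remain in the encoded form
            # produced by encode_ndarray. Hypothetical sketch (the file name
            # is a placeholder):
            #
            #     with open('example_run.json') as json_file:
            #         approx_stats = json.load(json_file)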
    
        def _reset(self):
            """Resets instance variables."""
            # Set additional necessary config parameters
            self._limiter_config['cell_len'] = self._mesh.cell_len
    
    
    def do_initial_projection(initial_condition, mesh, basis, quadrature,
                              adjustment=0):
        """Calculates initial projection.
    
        Calculates a projection at time step 0 and adds ghost cells on both
        sides of the array.
    
        Parameters
        ----------
        initial_condition : InitialCondition object
            Initial condition used for calculation.
        mesh : Mesh
            Mesh for calculation.
        basis : Basis object
            Basis used for calculation.
        quadrature : Quadrature object
            Quadrature used for evaluation.
        adjustment : float, optional
            Extent of adjustment of each evaluation point in x-direction.
            Default: 0.
    
        Returns
        -------
        ndarray
            Matrix containing the projection, of shape (p+1, N+2), with N
            being the number of grid cells and p being the polynomial degree
            of the basis (the two extra columns are the ghost cells).
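
        Notes
        -----
        Schematically, the coefficient of basis function phi_k on the cell
        with center x_j is approximated by quadrature on the reference
        element:

            coeff(j, k) = sum_i  w_i * u_0(x_j + cell_len/2 * xi_i
                                           - adjustment) * phi_k(xi_i)

        where xi_i and w_i are the quadrature nodes and weights. The
        resulting coefficient vector is multiplied by the inverse mass
        matrix of the basis. Ghost cells are filled by copying the last and
        first interior cells (periodic boundary).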
    
        """
        # Initialize matrix with a placeholder for the left ghost cell
        output_matrix = [0]
    
        for eval_point in mesh.non_ghost_cells:
            new_row = []
            for degree in range(basis.polynomial_degree + 1):
                new_row.append(np.float64(sum(initial_condition.calculate(
                    x=eval_point + mesh.cell_len/2
                    * quadrature.nodes[point] - adjustment)
                    * basis.basis[degree].subs(
                        x, quadrature.nodes[point])
                    * quadrature.weights[point]
                    for point in range(quadrature.num_nodes))))
    
            new_row = np.array(new_row)
            output_matrix.append(basis.inverse_mass_matrix @ new_row)
    
        # Set ghost cells according to periodic boundary conditions
        output_matrix[0] = output_matrix[mesh.num_grid_cells]
        output_matrix.append(output_matrix[1])
    
        # print(np.array(output_matrix).shape)
        return np.transpose(np.array(output_matrix))