Skip to content
Snippets Groups Projects
Select Git revision
  • d78720387a0556f22ee2bbf82c29792744c9baa2
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

training.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    DG_Approximation.py 10.41 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle
    
    """
    import json
    import numpy as np
    from sympy import Symbol
    import math
    import seaborn as sns
    
    from . import Troubled_Cell_Detector
    from . import Initial_Condition
    from . import Limiter
    from . import Quadrature
    from . import Update_Scheme
    from .Basis_Function import OrthonormalLegendre
    from .encoding_utils import encode_ndarray
    from .projection_utils import Mesh
    
    # Free symbol in which the basis functions are expressed; it is replaced
    # by quadrature nodes via ``subs`` in do_initial_projection().
    x = Symbol('x')
    # Apply seaborn's default plotting theme (import-time side effect).
    sns.set()
    
    
    class DGScheme:
        """Class for Discontinuous Galerkin Method.

        Approximates linear advection equation using Discontinuous Galerkin Method
        with troubled-cell-based limiting.

        Attributes
        ----------
        basis : Basis object
            Basis for calculation.
        mesh : Mesh
            Mesh for calculation.

        Methods
        -------
        approximate(data_file)
            Approximates projection and saves the results in a JSON file.

        """
        def __init__(self, detector, **kwargs):
            """Initializes DGScheme.

            Parameters
            ----------
            detector : str
                Name of troubled cell detector class.

            Other Parameters
            ----------------
            wave_speed : float, optional
                Speed of wave in rightward direction. Default: 1.
            polynomial_degree : int, optional
                Polynomial degree. Default: 2.
            cfl_number : float, optional
                CFL number to ensure stability. Default: 0.2.
            num_mesh_cells : int, optional
                Number of cells in the mesh. Usually a power of 2. Default: 64.
            final_time : float, optional
                Final time for which approximation is calculated. Default: 1.
            left_bound : float, optional
                Left boundary of interval. Default: -1.
            right_bound : float, optional
                Right boundary of interval. Default: 1.
            verbose : bool, optional
                Flag whether commentary in console is wanted. Stored, but not
                used by this class. Default: False.
            history_threshold : int, optional
                Number of time steps between two recordings of the
                troubled-cell history. Default: math.ceil(0.2/cfl_number).
            detector_config : dict, optional
                Additional parameters for detector object. Default: {}.
            init_cond : str, optional
                Name of initial condition for evaluation. Default: 'Sine'
            init_config : dict, optional
                Additional parameters for initial condition object. Default: {}.
            limiter : str, optional
                Name of limiter for evaluation. Default: 'ModifiedMinMod'.
            limiter_config : dict, optional
                Additional parameters for limiter object. Default: {}.
                The dictionary is copied, so the caller's dict is not modified.
            quadrature : str, optional
                Name of quadrature for evaluation. Default: 'Gauss'.
            quadrature_config : dict, optional
                Additional parameters for quadrature object. Default: {}.
            update_scheme : str, optional
                Name of update scheme for evaluation. Default: 'SSPRK3'.

            Raises
            ------
            ValueError
                If unrecognized keyword arguments are given, or if a named
                detector, initial condition, limiter, quadrature or update
                scheme does not exist.

            """
            # Unpack keyword arguments
            self._wave_speed = kwargs.pop('wave_speed', 1)
            self._cfl_number = kwargs.pop('cfl_number', 0.2)
            self._final_time = kwargs.pop('final_time', 1)
            self._verbose = kwargs.pop('verbose', False)
            self._history_threshold = kwargs.pop('history_threshold',
                                                 math.ceil(0.2/self._cfl_number))
            self._detector = detector
            self._detector_config = kwargs.pop('detector_config', {})
            self._init_cond = kwargs.pop('init_cond', 'Sine')
            self._init_config = kwargs.pop('init_config', {})
            self._limiter = kwargs.pop('limiter', 'ModifiedMinMod')
            # Copy: _reset() adds entries, which must not leak into the
            # dictionary owned by the caller
            self._limiter_config = dict(kwargs.pop('limiter_config', {}))
            self._quadrature = kwargs.pop('quadrature', 'Gauss')
            self._quadrature_config = kwargs.pop('quadrature_config', {})
            self._update_scheme = kwargs.pop('update_scheme', 'SSPRK3')
            self._basis = OrthonormalLegendre(kwargs.pop('polynomial_degree', 2))

            # Initialize mesh with two ghost cells on each side
            self._mesh = Mesh(num_cells=kwargs.pop('num_mesh_cells', 64),
                              left_bound=kwargs.pop('left_bound', -1),
                              right_bound=kwargs.pop('right_bound', 1),
                              num_ghost_cells=2)

            # Throw an error if there are extra keyword arguments
            if kwargs:
                extra = ', '.join('"%s"' % k for k in kwargs)
                raise ValueError('Unrecognized arguments: %s' % extra)

            # Make sure all classes actually exist
            if not hasattr(Troubled_Cell_Detector, self._detector):
                raise ValueError('Invalid detector: "%s"' % self._detector)
            if not hasattr(Initial_Condition, self._init_cond):
                raise ValueError('Invalid initial condition: "%s"'
                                 % self._init_cond)
            if not hasattr(Limiter, self._limiter):
                raise ValueError('Invalid limiter: "%s"' % self._limiter)
            if not hasattr(Quadrature, self._quadrature):
                raise ValueError('Invalid quadrature: "%s"' % self._quadrature)
            if not hasattr(Update_Scheme, self._update_scheme):
                raise ValueError('Invalid update scheme: "%s"'
                                 % self._update_scheme)

            # Must run before the limiter is instantiated, since it completes
            # the limiter configuration
            self._reset()

            # Replace the string names with the actual class instances
            # (and add the instance variables for the quadrature)
            self._init_cond = getattr(Initial_Condition, self._init_cond)(
                config=self._init_config)
            self._limiter = getattr(Limiter, self._limiter)(
                config=self._limiter_config)
            self._quadrature = getattr(Quadrature, self._quadrature)(
                config=self._quadrature_config)
            self._detector = getattr(Troubled_Cell_Detector, self._detector)(
                config=self._detector_config, mesh=self._mesh, basis=self._basis)
            self._update_scheme = getattr(Update_Scheme, self._update_scheme)(
                polynomial_degree=self._basis.polynomial_degree,
                num_cells=self._mesh.num_cells, detector=self._detector,
                limiter=self._limiter, wave_speed=self._wave_speed)

        def approximate(self, data_file):
            """Approximates projection.

            Initializes projection and evolves it in time. Each time step is
            delegated to the update scheme, which returns the updated
            projection together with the troubled cells it detected. Every
            `history_threshold`-th step, the troubled cells and the time are
            recorded.

            At final time, results are saved in JSON file.

            Parameters
            ----------
            data_file : str
                Path to file in which data will be saved, without extension;
                '.json' is appended.

            """
            projection = do_initial_projection(
                init_cond=self._init_cond, mesh=self._mesh,
                basis=self._basis, quadrature=self._quadrature)

            # Regular step size; only the last step may be shorter
            full_time_step = abs(self._cfl_number * self._mesh.cell_len /
                                 self._wave_speed)

            current_time = 0
            iteration = 0
            troubled_cell_history = []
            time_history = []
            while current_time < self._final_time:
                time_step = full_time_step
                cfl_number = self._cfl_number
                # Shorten the last step so it ends exactly at final time and
                # scale the CFL number by the same factor, keeping its sign
                # consistent with all regular steps
                if current_time + time_step > self._final_time:
                    time_step = self._final_time - current_time
                    cfl_number = self._cfl_number * time_step / full_time_step

                # Update projection
                projection, troubled_cells = self._update_scheme.step(
                    projection, cfl_number)
                iteration += 1

                # NOTE(review): the recorded time is the start of the step just
                # taken; confirm which time the returned troubled cells refer to
                if iteration % self._history_threshold == 0:
                    troubled_cell_history.append(troubled_cells)
                    time_history.append(current_time)

                current_time += time_step

            # Save detector-specific data in dictionary
            approx_stats = self._detector.create_data_dict(projection)

            # Save approximation results in dictionary
            approx_stats['wave_speed'] = self._wave_speed
            approx_stats['final_time'] = self._final_time
            approx_stats['time_history'] = time_history
            approx_stats['troubled_cell_history'] = troubled_cell_history

            # Encode all ndarrays to fit JSON format
            approx_stats = {key: encode_ndarray(value)
                            for key, value in approx_stats.items()}

            # Save approximation results in JSON format
            with open(data_file + '.json', 'w') as json_file:
                json.dump(approx_stats, json_file)

        def _reset(self):
            """Resets instance variables."""
            # Set additional necessary config parameters
            self._limiter_config['cell_len'] = self._mesh.cell_len
    
    
    def do_initial_projection(init_cond, mesh, basis, quadrature,
                              x_shift=0):
        """Calculates initial projection.

        Calculates a projection at time step 0 and adds one ghost cell on each
        side of the array. The ghost cells are filled periodically: the left
        one copies the last cell, the right one copies the first cell.

        Parameters
        ----------
        init_cond : InitialCondition object
            Initial condition used for calculation.
        mesh : Mesh
            Mesh for calculation.
        basis: Basis object
            Basis used for calculation.
        quadrature: Quadrature object
            Quadrature used for evaluation.
        x_shift: float, optional
            Shift in x-direction for each evaluation point. Default: 0.

        Returns
        -------
        ndarray
            Matrix containing projection of size (p+1, N+2) with N being the
            number of mesh cells and p being the polynomial degree of the basis.
            Columns 0 and N+1 hold the ghost cells.

        """
        # The basis functions evaluated at the quadrature nodes do not depend on
        # the cell, so perform the (comparatively expensive) sympy substitution
        # once instead of once per cell.
        basis_at_nodes = [[basis.basis[degree].subs(x, quadrature.nodes[point])
                           for point in range(quadrature.num_nodes)]
                          for degree in range(basis.polynomial_degree + 1)]

        # Initialize matrix and set first entry to accommodate for ghost cell
        output_matrix = [0]

        for eval_point in mesh.non_ghost_cells:
            new_row = []
            for degree in range(basis.polynomial_degree + 1):
                # Quadrature of the initial condition (mapped from the reference
                # interval into the current cell) times the basis function
                new_row.append(np.float64(sum(init_cond.calculate(
                    x=eval_point + mesh.cell_len/2
                    * quadrature.nodes[point] - x_shift, mesh=mesh)
                    * basis_at_nodes[degree][point]
                    * quadrature.weights[point]
                    for point in range(quadrature.num_nodes))))

            new_row = np.array(new_row)
            output_matrix.append(basis.inverse_mass_matrix @ new_row)

        # Set ghost cells to respective value (periodic continuation); assumes
        # mesh.non_ghost_cells yields exactly mesh.num_cells points
        output_matrix[0] = output_matrix[mesh.num_cells]
        output_matrix.append(output_matrix[1])

        return np.transpose(np.array(output_matrix))