ANN_Model.py
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    INFO: /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules
    
    """
    import torch
    
    
    class ThreeLayerReLu(torch.nn.Module):
        """Class for a fully-connected, three-layered ANN with ReLu.
    
        Predicts score for each output class using the following structure:
    
        Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function
    
        Attributes
        ----------
        _name : str
            String containing the name of the model.
        input_linear : torch.nn.Module
            Linear input layer.
        middle_linear : torch.nn.Module
            Linear middle layer.
        output_linear : torch.nn.Module
            Linear output layer.
        _output_layer : torch.nn.Module
            Activation layer applied to the output.
    
        Methods
        -------
        forward(input_data)
            Executes forward propagation.
        get_name()
            Returns string of model name.
    
        """
        def __init__(self, config):
            """Initializes ThreeLayerReLu.
    
            Parameters
            ----------
            config : dict
                Additional parameters for the model (recognized keys are
                removed from the dictionary via 'pop' as they are read).
    
            """
            super().__init__()
    
            input_size = config.pop('input_size', 5)
            first_hidden_size = config.pop('first_hidden_size', 8)
            second_hidden_size = config.pop('second_hidden_size', 4)
            output_size = config.pop('output_size', 2)
            activation_function = config.pop('activation_function', 'Sigmoid')
            activation_config = config.pop('activation_config', {})
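            # NOTE (illustrative): a typical config might look like
            # {'input_size': 5, 'first_hidden_size': 8, 'second_hidden_size': 4,
            #  'output_size': 2, 'activation_function': 'Softmax',
            #  'activation_config': {'dim': 1}}. Entries not popped above are
            # simply left in the dictionary.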
    
            if not hasattr(torch.nn.modules.activation, activation_function):
                raise ValueError('Invalid activation function: "%s"'
                                 % activation_function)
    
            self._name = self.__class__.__name__ + '_' + str(first_hidden_size) + \
                '_' + str(second_hidden_size) + '_' + activation_function
    
            self.input_linear = torch.nn.Linear(input_size, first_hidden_size)
            self.middle_linear = torch.nn.Linear(first_hidden_size,
                                                 second_hidden_size)
            self.output_linear = torch.nn.Linear(second_hidden_size, output_size)
            self._output_layer = getattr(torch.nn.modules.activation,
                                         activation_function)(**activation_config)
    
        def forward(self, input_data):
            """Executes forward propagation.
    
            Parameters
            ----------
            input_data : torch.Tensor
                2D tensor containing the input data.
    
            Returns
            -------
            prediction : torch.Tensor
                2D tensor containing the predicted output data.
    
            """
            prediction = self.input_linear(input_data).clamp(min=0)
            prediction = self.middle_linear(prediction).clamp(min=0)
            prediction = self.output_linear(prediction)
            prediction = self._output_layer(prediction)
            return prediction
    
        def get_name(self):
            """Returns string of model name."""
            return self._name
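

    if __name__ == '__main__':
        # Minimal usage sketch (illustrative addition, not part of the original
        # module): build a model from an example config and run one forward
        # pass on random data. The keys mirror the defaults read in __init__();
        # 'Softmax' with dim=1 is just one possible output activation.
        example_config = {'input_size': 5, 'first_hidden_size': 8,
                          'second_hidden_size': 4, 'output_size': 2,
                          'activation_function': 'Softmax',
                          'activation_config': {'dim': 1}}
        model = ThreeLayerReLu(example_config)
        print(model.get_name())    # -> 'ThreeLayerReLu_8_4_Softmax'
        batch = torch.rand(10, 5)  # 10 samples with 5 input features each
        scores = model(batch)      # calls forward() via torch.nn.Module
        print(scores.shape)        # -> torch.Size([10, 2])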