Select Git revision
Basis_Function.py
-
Laura Christine Kühle authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ANN_Model.py 2.90 KiB
# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
INFO: /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules
"""
import torch
class ThreeLayerReLu(torch.nn.Module):
    """Class for a fully-connected, three-layered ANN with ReLu.

    Predicts score for each output class using the following structure:

    Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function

    Attributes
    ----------
    _name : str
        String containing name of model.
    input_linear : torch.nn.Module
        Linear input layer.
    middle_linear : torch.nn.Module
        Linear middle layer.
    output_linear : torch.nn.Module
        Linear output layer.
    _output_layer : torch.nn.Module
        Activation layer for output calculation.

    Methods
    -------
    forward(input_data)
        Executes forward propagation.
    get_name()
        Returns string of model name.

    """

    def __init__(self, config=None):
        """Initializes ThreeLayerReLu.

        Parameters
        ----------
        config : dict, optional
            Additional parameters for model. Recognized keys (defaults in
            parentheses): 'input_size' (5), 'first_hidden_size' (8),
            'second_hidden_size' (4), 'output_size' (2),
            'activation_function' (name of a class in
            torch.nn.modules.activation, 'Sigmoid') and 'activation_config'
            (keyword arguments passed to that class, {}). Other keys are
            ignored. The dictionary is read only, never modified, so it can
            be reused to build further models. None is treated as {}.

        Raises
        ------
        ValueError
            If 'activation_function' does not name a torch.nn.Module
            subclass available in torch.nn.modules.activation.

        """
        super().__init__()

        if config is None:
            config = {}

        # Read with .get() instead of .pop(): popping emptied the caller's
        # dict, so a reused config silently fell back to the defaults.
        input_size = config.get('input_size', 5)
        first_hidden_size = config.get('first_hidden_size', 8)
        second_hidden_size = config.get('second_hidden_size', 4)
        output_size = config.get('output_size', 2)
        activation_function = config.get('activation_function', 'Sigmoid')
        activation_config = config.get('activation_config', {})

        # torch.nn.modules.activation also exposes names it merely imports
        # (e.g. 'Module', 'Tensor', 'F'); a plain hasattr() check accepted
        # those. Only accept concrete torch.nn.Module subclasses.
        activation_class = getattr(torch.nn.modules.activation,
                                   activation_function, None)
        if not (isinstance(activation_class, type)
                and issubclass(activation_class, torch.nn.Module)
                and activation_class is not torch.nn.Module):
            raise ValueError('Invalid activation function: "%s"'
                             % activation_function)

        self._name = '_'.join([self.__class__.__name__,
                               str(first_hidden_size),
                               str(second_hidden_size),
                               activation_function])

        # Registration order is kept as before: it determines the order of
        # parameters() and of the state_dict entries.
        self.input_linear = torch.nn.Linear(input_size, first_hidden_size)
        self.middle_linear = torch.nn.Linear(first_hidden_size,
                                             second_hidden_size)
        self.output_linear = torch.nn.Linear(second_hidden_size, output_size)
        self._output_layer = activation_class(**activation_config)

    def forward(self, input_data):
        """Executes forward propagation.

        Parameters
        ----------
        input_data : torch.Tensor
            Input data whose last dimension has size 'input_size'
            (e.g. a 2D tensor of shape (batch, input_size)).

        Returns
        -------
        prediction : torch.Tensor
            Output of the activation layer applied to the final linear
            layer. Element-wise activations such as 'Sigmoid' keep the
            last dimension at size 'output_size'.

        """
        # clamp(min=0) applies the ReLu non-linearity after each hidden layer.
        prediction = self.input_linear(input_data).clamp(min=0)
        prediction = self.middle_linear(prediction).clamp(min=0)
        prediction = self.output_linear(prediction)
        prediction = self._output_layer(prediction)
        return prediction

    def get_name(self):
        """Returns string of model name."""
        return self._name