Skip to content
Snippets Groups Projects
Commit 878327a9 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Added documentation to 'ANN_Model'.

parent a34d6f4f
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,36 @@ INFO: /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules
import torch
# Define Neural Network
# Model with Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function
class ThreeLayerReLu(torch.nn.Module):
"""Class for a fully-connected, three-layered ANN with ReLu.
Predicts score for each output class using the following structure:
Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function
Attributes
----------
name: str
String containing name of model.
input_linear: torch.nn.Module
Linear input layer.
middle_linear: torch.nn.Module
Linear middle layer.
output_linear: torch.nn.Module
Linear output layer.
output_layer: torch.nn.Module
Activation layer for output calculation.
"""
def __init__(self, config):
"""Initializes ThreeLayerReLu.
Parameters
----------
config : dict
Additional parameters for model.
"""
super().__init__()
input_size = config.pop('input_size', 5)
......@@ -35,6 +60,19 @@ class ThreeLayerReLu(torch.nn.Module):
**activation_config)
def forward(self, input_data):
"""Executes forward propagation.
Parameters
----------
input_data: ndarray
2D array containing input data.
Returns
-------
prediction: ndarray
Matrix containing predicted output data.
"""
prediction = self.input_linear(input_data).clamp(min=0)
prediction = self.middle_linear(prediction).clamp(min=0)
prediction = self.output_linear(prediction)
......@@ -42,4 +80,5 @@ class ThreeLayerReLu(torch.nn.Module):
return prediction
def get_name(self):
"""Returns string of model name."""
return self._name
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment