From 878327a9f058fc05336c3d9cd62661065923f7d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Tue, 15 Feb 2022 21:55:18 +0100 Subject: [PATCH] Added documentation to 'ANN_Model'. --- ANN_Model.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/ANN_Model.py b/ANN_Model.py index 7347911..350eab2 100644 --- a/ANN_Model.py +++ b/ANN_Model.py @@ -8,11 +8,36 @@ INFO: /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules import torch -# Define Neural Network - -# Model with Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function class ThreeLayerReLu(torch.nn.Module): + """Class for a fully-connected, three-layered ANN with ReLu. + + Predicts score for each output class using the following structure: + + Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function + + Attributes + ---------- + name: str + String containing name of model. + input_linear: torch.nn.Module + Linear input layer. + middle_linear: torch.nn.Module + Linear middle layer. + output_linear: torch.nn.Module + Linear output layer. + output_layer: torch.nn.Module + Activation layer for output calculation. + + """ def __init__(self, config): + """Initializes ThreeLayerReLu. + + Parameters + ---------- + config : dict + Additional parameters for model. + + """ super().__init__() input_size = config.pop('input_size', 5) @@ -35,6 +60,19 @@ class ThreeLayerReLu(torch.nn.Module): **activation_config) def forward(self, input_data): + """Executes forward propagation. + + Parameters + ---------- + input_data: ndarray + 2D array containing input data. + + Returns + ------- + prediction: ndarray + Matrix containing predicted output data. + + """ prediction = self.input_linear(input_data).clamp(min=0) prediction = self.middle_linear(prediction).clamp(min=0) prediction = self.output_linear(prediction) @@ -42,4 +80,5 @@ class ThreeLayerReLu(torch.nn.Module): return prediction def get_name(self): + """Returns string of model name.""" return self._name -- GitLab