Skip to content
Snippets Groups Projects
Commit 39666197 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Changed layer naming temporarily for consistency with older models.

parent 9441751e
Branches
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
TODO: Combine all ThreeLayerNet classes in one class -> Done
TODO: Change naming temporarily for consistency with older models -> Done (later)
INFO: /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules
......@@ -30,15 +31,15 @@ class ThreeLayerReLu(torch.nn.Module):
self._name = self.__class__.__name__ + '_' + str(first_hidden_size) + '_' + str(second_hidden_size) + '_'\
+ activation_function
self._input_layer = torch.nn.Linear(input_size, first_hidden_size)
self._first_hidden_layer = torch.nn.Linear(first_hidden_size, second_hidden_size)
self._second_hidden_layer = torch.nn.Linear(second_hidden_size, output_size)
self.input_linear = torch.nn.Linear(input_size, first_hidden_size)
self.middle_linear = torch.nn.Linear(first_hidden_size, second_hidden_size)
self.output_linear = torch.nn.Linear(second_hidden_size, output_size)
self._output_layer = getattr(torch.nn.modules.activation, activation_function)(**activation_config)
def forward(self, input_data):
prediction = self._input_layer(input_data).clamp(min=0)
prediction = self._first_hidden_layer(prediction).clamp(min=0)
prediction = self._second_hidden_layer(prediction)
prediction = self.input_linear(input_data).clamp(min=0)
prediction = self.middle_linear(prediction).clamp(min=0)
prediction = self.output_linear(prediction)
prediction = self._output_layer(prediction)
return prediction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment