# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

INFO: Local source of the torch.nn modules (e.g. the activation functions
      referenced below): /home/laura/anaconda3/lib/python3.7/site-packages/torch/nn/modules

"""
import torch


class ThreeLayerReLu(torch.nn.Module):
    """Class for a fully-connected, three-layered ANN with ReLu.

    Predicts score for each output class using the following structure:

    Linear -> ReLu -> Linear -> ReLu -> Linear -> any activation function

    Attributes
    ----------
    _name : str
        String containing name of model (accessible via ``get_name()``).
    input_linear : torch.nn.Module
        Linear input layer.
    middle_linear : torch.nn.Module
        Linear middle layer.
    output_linear : torch.nn.Module
        Linear output layer.
    _output_layer : torch.nn.Module
        Activation layer for output calculation.

    Methods
    -------
    forward(input_data)
        Executes forward propagation.
    get_name()
        Returns string of model name.

    """
    def __init__(self, config=None):
        """Initializes ThreeLayerReLu.

        Parameters
        ----------
        config : dict, optional
            Additional parameters for model. Recognized keys (defaults in
            parentheses): 'input_size' (5), 'first_hidden_size' (8),
            'second_hidden_size' (4), 'output_size' (2),
            'activation_function' (name of an activation class in
            ``torch.nn.modules.activation``; 'Sigmoid') and
            'activation_config' (keyword arguments passed to the activation
            class; {}). The passed dictionary is not modified.

        Raises
        ------
        ValueError
            If 'activation_function' does not name a ``torch.nn.Module``
            subclass in ``torch.nn.modules.activation``.

        """
        super().__init__()

        # Copy so the caller's dict is left intact; popping from it directly
        # would make a reused config fall back to the defaults.
        config = {} if config is None else dict(config)

        input_size = config.pop('input_size', 5)
        first_hidden_size = config.pop('first_hidden_size', 8)
        second_hidden_size = config.pop('second_hidden_size', 4)
        output_size = config.pop('output_size', 2)
        activation_function = config.pop('activation_function', 'Sigmoid')
        activation_config = config.pop('activation_config', {})

        # The activation module can also contain non-activation names (e.g.
        # imported helpers), so require an actual Module subclass instead of
        # mere attribute existence.
        activation_class = getattr(torch.nn.modules.activation,
                                   activation_function, None)
        if not (isinstance(activation_class, type)
                and issubclass(activation_class, torch.nn.Module)
                and activation_class is not torch.nn.Module):
            raise ValueError('Invalid activation function: "%s"'
                             % activation_function)

        self._name = '_'.join([self.__class__.__name__,
                               str(first_hidden_size),
                               str(second_hidden_size),
                               activation_function])

        self.input_linear = torch.nn.Linear(input_size, first_hidden_size)
        self.middle_linear = torch.nn.Linear(first_hidden_size,
                                             second_hidden_size)
        self.output_linear = torch.nn.Linear(second_hidden_size, output_size)
        self._output_layer = activation_class(**activation_config)

    def forward(self, input_data):
        """Executes forward propagation.

        Parameters
        ----------
        input_data : torch.Tensor
            Input tensor whose last dimension equals 'input_size'
            (presumably 2D: batch x input_size).

        Returns
        -------
        prediction : torch.Tensor
            Tensor of predicted output data whose last dimension equals
            'output_size'.

        """
        # clamp(min=0) acts as the ReLu after each hidden linear layer.
        prediction = self.input_linear(input_data).clamp(min=0)
        prediction = self.middle_linear(prediction).clamp(min=0)
        prediction = self.output_linear(prediction)
        prediction = self._output_layer(prediction)
        return prediction

    def get_name(self):
        """Returns string of model name."""
        return self._name