
"""MLP Module."""

from torch import nn
from torch.nn import functional as F

from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule


class MLPModule(MultiHeadedMLPModule):
    """MLP Model.

    A PyTorch module composed only of a multi-layer perceptron (MLP), which
    maps real-valued inputs to real-valued outputs.

    Args:
        input_dim (int): Dimension of the network input.
        output_dim (int): Dimension of the network output.
        hidden_sizes (list[int]): Output dimension of dense layer(s).
            For example, (32, 32) means this MLP consists of two hidden
            layers, each with 32 hidden units.
        hidden_nonlinearity (callable or torch.nn.Module): Activation
            function for intermediate dense layer(s). It should return a
            torch.Tensor. Set it to None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight of
            intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias of
            intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearity (callable or torch.nn.Module): Activation
            function for the output dense layer. It should return a
            torch.Tensor. Set it to None to maintain a linear activation.
        output_w_init (callable): Initializer function for the weight of
            the output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Initializer function for the bias of
            the output dense layer(s). The function should return a
            torch.Tensor.
        layer_normalization (bool): Whether to use layer normalization.

    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_sizes,
                 hidden_nonlinearity=F.relu,
                 hidden_w_init=nn.init.xavier_normal_,
                 hidden_b_init=nn.init.zeros_,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_normal_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False):
        # A single-headed MultiHeadedMLPModule (n_heads=1) is a plain MLP.
        super().__init__(1, input_dim, output_dim, hidden_sizes,
                         hidden_nonlinearity, hidden_w_init, hidden_b_init,
                         output_nonlinearity, output_w_init, output_b_init,
                         layer_normalization)

    # pylint: disable=arguments-differ
    def forward(self, input_value):
        """Forward method.

        Args:
            input_value (torch.Tensor): Input values with (N, *, input_dim)
                shape.

        Returns:
            torch.Tensor: Output value.

        """
        # The parent returns a list with one output tensor per head; unwrap
        # the single head.
        return super().forward(input_value)[0]
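

# A minimal usage sketch (an illustration added here, not part of the garage
# source). It assumes torch and garage are installed; the dimensions below
# are arbitrary. MLPModule maps a (N, *, input_dim) tensor to a
# (N, *, output_dim) tensor through the configured hidden layers.
if __name__ == '__main__':
    import torch

    mlp = MLPModule(input_dim=4, output_dim=2, hidden_sizes=(32, 32))
    x = torch.ones(8, 4)  # a batch of 8 inputs, each of dimension 4
    y = mlp(x)  # calls forward(), which returns a single tensor
    assert y.shape == (8, 2)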