Source code for garage.torch.modules.multi_headed_mlp_module

"""MultiHeadedMLPModule."""
import copy

import torch
import torch.nn as nn


class MultiHeadedMLPModule(nn.Module):
    """MultiHeadedMLPModule Model.

    A PyTorch module composed only of a multi-layer perceptron (MLP) with
    multiple parallel output layers which maps real-valued inputs to
    real-valued outputs. The number of outputs is n_heads, and the shape of
    each output is determined by the corresponding entry of output_dims.

    Args:
        n_heads (int): Number of different output layers.
        input_dim (int): Dimension of the network input.
        output_dims (int or list or tuple): Dimension of the network output.
        hidden_sizes (list[int]): Output dimension of dense layer(s).
            For example, (32, 32) means this MLP consists of two
            hidden layers, each with 32 hidden units.
        hidden_nonlinearity (callable or torch.nn.Module or list or tuple):
            Activation function for intermediate dense layer(s).
            It should return a torch.Tensor. Set it to None to maintain a
            linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearities (callable or torch.nn.Module or list or tuple):
            Activation function for output dense layer(s). It should return a
            torch.Tensor. Set it to None to maintain a linear activation.
            Size of the parameter should be 1 or equal to n_heads.
        output_w_inits (callable or list or tuple): Initializer function for
            the weight of output dense layer(s). The function should return a
            torch.Tensor. Size of the parameter should be 1 or equal to
            n_heads.
        output_b_inits (callable or list or tuple): Initializer function for
            the bias of output dense layer(s). The function should return a
            torch.Tensor. Size of the parameter should be 1 or equal to
            n_heads.
        layer_normalization (bool): Bool for using layer normalization or not.

    """

    def __init__(self,
                 n_heads,
                 input_dim,
                 output_dims,
                 hidden_sizes,
                 hidden_nonlinearity=torch.relu,
                 hidden_w_init=nn.init.xavier_normal_,
                 hidden_b_init=nn.init.zeros_,
                 output_nonlinearities=None,
                 output_w_inits=nn.init.xavier_normal_,
                 output_b_inits=nn.init.zeros_,
                 layer_normalization=False):
        super().__init__()

        output_dims = self._check_parameter_for_output_layer(
            'output_dims', output_dims, n_heads)
        output_w_inits = self._check_parameter_for_output_layer(
            'output_w_inits', output_w_inits, n_heads)
        output_b_inits = self._check_parameter_for_output_layer(
            'output_b_inits', output_b_inits, n_heads)
        output_nonlinearities = self._check_parameter_for_output_layer(
            'output_nonlinearities', output_nonlinearities, n_heads)

        self._layers = nn.ModuleList()

        prev_size = input_dim
        for size in hidden_sizes:
            hidden_layers = nn.Sequential()
            if layer_normalization:
                hidden_layers.add_module('layer_normalization',
                                         nn.LayerNorm(prev_size))

            linear_layer = nn.Linear(prev_size, size)
            hidden_w_init(linear_layer.weight)
            hidden_b_init(linear_layer.bias)
            hidden_layers.add_module('linear', linear_layer)

            if hidden_nonlinearity:
                hidden_layers.add_module('non_linearity',
                                         _NonLinearity(hidden_nonlinearity))

            self._layers.append(hidden_layers)
            prev_size = size

        self._output_layers = nn.ModuleList()
        for i in range(n_heads):
            output_layer = nn.Sequential()
            linear_layer = nn.Linear(prev_size, output_dims[i])
            output_w_inits[i](linear_layer.weight)
            output_b_inits[i](linear_layer.bias)
            output_layer.add_module('linear', linear_layer)

            if output_nonlinearities[i]:
                output_layer.add_module(
                    'non_linearity', _NonLinearity(output_nonlinearities[i]))

            self._output_layers.append(output_layer)

    @classmethod
    def _check_parameter_for_output_layer(cls, var_name, var, n_heads):
        """Check that input parameters for the output layer are valid.

        Args:
            var_name (str): Variable name.
            var (any): Variable to be checked.
            n_heads (int): Number of heads.

        Returns:
            list: List of variables (length of n_heads).

        Raises:
            ValueError: If the variable is a list or tuple, but its length is
                neither 1 nor equal to n_heads.

        """
        if isinstance(var, (list, tuple)):
            if len(var) == 1:
                return list(var) * n_heads
            if len(var) == n_heads:
                return var
            msg = ('{} should be either a single value or a collection of '
                   'length n_heads ({}), but {} provided.')
            raise ValueError(msg.format(var_name, n_heads, var))
        return [copy.deepcopy(var) for _ in range(n_heads)]

    # pylint: disable=arguments-differ
    def forward(self, input_val):
        """Forward method.

        Args:
            input_val (torch.Tensor): Input values with (N, *, input_dim)
                shape.

        Returns:
            List[torch.Tensor]: Output values.

        """
        x = input_val
        for layer in self._layers:
            x = layer(x)

        return [output_layer(x) for output_layer in self._output_layers]
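
# The sketch below is not part of the original garage source; the function
# name `_broadcast_example` is hypothetical. It illustrates how
# `_check_parameter_for_output_layer` broadcasts a single value, or a
# length-1 collection, to one entry per head, and rejects collections of the
# wrong length.
def _broadcast_example():
    """Show the per-head broadcasting applied to output-layer parameters."""
    # A single value is copied once per head.
    dims = MultiHeadedMLPModule._check_parameter_for_output_layer(
        'output_dims', 2, n_heads=3)
    assert dims == [2, 2, 2]

    # A length-1 collection is repeated to match n_heads.
    inits = MultiHeadedMLPModule._check_parameter_for_output_layer(
        'output_w_inits', [nn.init.xavier_normal_], n_heads=3)
    assert len(inits) == 3

    # A collection whose length is neither 1 nor n_heads raises ValueError.
    try:
        MultiHeadedMLPModule._check_parameter_for_output_layer(
            'output_dims', [1, 2], n_heads=3)
    except ValueError:
        pass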
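
# A minimal usage sketch, not part of the original garage source; the function
# name `_forward_example` and all hyper-parameter values are illustrative. It
# builds a two-headed MLP on a shared trunk and shows that `forward` returns
# one tensor per head.
def _forward_example():
    """Build a two-headed MLP and run a forward pass."""
    module = MultiHeadedMLPModule(
        n_heads=2,
        input_dim=4,
        output_dims=[2, 3],                  # per-head output dimensions
        hidden_sizes=[32, 32],               # shared trunk: two hidden layers
        hidden_nonlinearity=torch.relu,
        output_nonlinearities=[None, torch.tanh])
    x = torch.randn(8, 4)                    # batch of 8 inputs
    outputs = module(x)                      # list with one tensor per head
    assert outputs[0].shape == (8, 2)
    assert outputs[1].shape == (8, 3)
    return outputs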

class _NonLinearity(nn.Module):
    """Wrapper class for a non-linear function or module.

    Args:
        non_linear (callable or type): Non-linear function or type to be
            wrapped.

    """

    def __init__(self, non_linear):
        super().__init__()

        if isinstance(non_linear, type):
            self.module = non_linear()
        elif callable(non_linear):
            self.module = copy.deepcopy(non_linear)
        else:
            raise ValueError(
                'Non-linear function {} is not supported'.format(non_linear))

    # pylint: disable=arguments-differ
    def forward(self, input_value):
        """Forward method.

        Args:
            input_value (torch.Tensor): Input values.

        Returns:
            torch.Tensor: Output value.

        """
        return self.module(input_value)

    # pylint: disable=missing-return-doc, missing-return-type-doc
    def __repr__(self):
        return repr(self.module)
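
# A minimal sketch, not part of the original garage source; the function name
# `_nonlinearity_example` is hypothetical. It shows the two kinds of argument
# `_NonLinearity` accepts: a module type, which is instantiated, and a plain
# callable, which is stored (deep-copied) and applied directly.
def _nonlinearity_example():
    """Wrap both a module class and a plain callable."""
    relu = _NonLinearity(nn.ReLU)       # a type: instantiated as nn.ReLU()
    tanh = _NonLinearity(torch.tanh)    # a callable: used as-is
    x = torch.tensor([-1.0, 0.0, 1.0])
    return relu(x), tanh(x)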