Source code for garage.tf.models.categorical_lstm_model

"""Categorical LSTM Model.

A model represented by a Categorical distribution
which is parameterized by a Long short-term memory (LSTM).
"""
import tensorflow as tf
import tensorflow_probability as tfp

from garage.tf.models.lstm_model import LSTMModel


class CategoricalLSTMModel(LSTMModel):
    """Categorical LSTM Model.

    A model represented by a Categorical distribution
    which is parameterized by a Long short-term memory (LSTM).

    Args:
        output_dim (int): Dimension of the network output.
        hidden_dim (int): Hidden dimension for LSTM cell.
        name (str): Policy name, also the variable scope.
        hidden_nonlinearity (callable): Activation function for intermediate
            dense layer(s). It should return a tf.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            tf.Tensor.
        recurrent_nonlinearity (callable): Activation function for recurrent
            layers. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        recurrent_w_init (callable): Initializer function for the weight
            of recurrent layer(s). The function should return a
            tf.Tensor.
        output_nonlinearity (callable): Optional function applied to the
            outputs of the parent LSTM model before they are used as the
            probabilities of the distribution. Note that `tf.nn.softmax` is
            always passed to the parent model as its output nonlinearity;
            this argument is applied on top of that result, it does not
            replace the softmax. It should return a tf.Tensor. Set it to
            None to use the parent's outputs unchanged.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            tf.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            tf.Tensor.
        hidden_state_init (callable): Initializer function for the
            initial hidden state. The function should return a tf.Tensor.
        hidden_state_init_trainable (bool): Bool for whether the initial
            hidden state is trainable.
        cell_state_init (callable): Initializer function for the
            initial cell state. The function should return a tf.Tensor.
        cell_state_init_trainable (bool): Bool for whether the initial
            cell state is trainable.
        forget_bias (bool): If True, add 1 to the bias of the forget gate at
            initialization. It's used to reduce the scale of forgetting at the
            beginning of the training.
        layer_normalization (bool): Bool for using layer normalization or not.

    """

    def __init__(self,
                 output_dim,
                 hidden_dim,
                 name=None,
                 hidden_nonlinearity=tf.nn.tanh,
                 hidden_w_init=tf.initializers.glorot_uniform(),
                 hidden_b_init=tf.zeros_initializer(),
                 recurrent_nonlinearity=tf.nn.sigmoid,
                 recurrent_w_init=tf.initializers.glorot_uniform(),
                 output_nonlinearity=None,
                 output_w_init=tf.initializers.glorot_uniform(),
                 output_b_init=tf.zeros_initializer(),
                 hidden_state_init=tf.zeros_initializer(),
                 hidden_state_init_trainable=False,
                 cell_state_init=tf.zeros_initializer(),
                 cell_state_init_trainable=False,
                 forget_bias=True,
                 layer_normalization=False):
        # The parent always receives softmax as its output nonlinearity; the
        # caller-supplied `output_nonlinearity` is kept separately and
        # applied afterwards in `_build`.
        super().__init__(
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            name=name,
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            recurrent_nonlinearity=recurrent_nonlinearity,
            recurrent_w_init=recurrent_w_init,
            output_nonlinearity=tf.nn.softmax,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            hidden_state_init=hidden_state_init,
            hidden_state_init_trainable=hidden_state_init_trainable,
            cell_state_init=cell_state_init,
            cell_state_init_trainable=cell_state_init_trainable,
            forget_bias=forget_bias,
            layer_normalization=layer_normalization)
        self._output_normalization_fn = output_nonlinearity

    def network_output_spec(self):
        """Network output spec.

        Returns:
            list[str]: Name of the model outputs, in order.

        """
        return [
            'dist', 'step_output', 'step_hidden', 'step_cell', 'init_hidden',
            'init_cell'
        ]

    # pylint: disable=arguments-differ
    def _build(self,
               state_input,
               step_input,
               step_hidden,
               step_cell,
               name=None):
        """Build model.

        Args:
            state_input (tf.Tensor): Entire time-series observation input,
                with shape :math:`(N, T, S^*)`.
            step_input (tf.Tensor): Single timestep observation input,
                with shape :math:`(N, S^*)`.
            step_hidden (tf.Tensor): Hidden state for step, with shape
                :math:`(N, S^*)`.
            step_cell (tf.Tensor): Cell state for step, with shape
                :math:`(N, S^*)`.
            name (str): Inner model name, also the variable scope of the
                inner model, if exist. One example is
                garage.tf.models.Sequential.

        Returns:
            tfp.distributions.OneHotCategorical: Policy distribution.
            tf.Tensor: Step output, with shape :math:`(N, S^*)`
            tf.Tensor: Step hidden state, with shape :math:`(N, S^*)`
            tf.Tensor: Step cell state, with shape :math:`(N, S^*)`
            tf.Tensor: Initial hidden state, used to reset the hidden state
                when policy resets. Shape: :math:`(S^*)`
            tf.Tensor: Initial cell state, used to reset the cell state
                when policy resets. Shape: :math:`(S^*)`

        """
        (outputs, step_output, step_hidden, step_cell, init_hidden,
         init_cell) = super()._build(state_input,
                                     step_input,
                                     step_hidden,
                                     step_cell,
                                     name=name)
        # Compare against None explicitly: None is the documented "disabled"
        # value, and a callable object that happens to evaluate as falsy
        # must still be applied.
        if self._output_normalization_fn is not None:
            outputs = self._output_normalization_fn(outputs)
        # The (optionally post-processed) outputs are used directly as the
        # class probabilities of the distribution.
        dist = tfp.distributions.OneHotCategorical(probs=outputs)
        return (dist, step_output, step_hidden, step_cell, init_hidden,
                init_cell)