"""LSTM Model.
A model composed only of a long short-term memory (LSTM) network.
"""
import tensorflow as tf
from garage.tf.models.base import Model
from garage.tf.models.lstm import lstm
class LSTMModel(Model):
    """LSTM Model.

    Wraps a single ``tf.keras.layers.LSTMCell`` followed by a
    ``tf.keras.layers.Dense`` output layer and builds the network through
    the ``lstm`` helper.

    Args:
        output_dim (int): Dimension of the network output.
        hidden_dim (int): Hidden dimension for LSTM cell.
        name (str): Policy name, also the variable scope.
        hidden_nonlinearity (callable): Activation function of the LSTM
            cell (passed as the cell's ``activation``). It should return a
            tf.Tensor. Set it to None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the input
            weights of the LSTM cell (passed as ``kernel_initializer``).
            The function should return a tf.Tensor.
        hidden_b_init (callable): Initializer function for the bias of
            the LSTM cell (passed as ``bias_initializer``). The function
            should return a tf.Tensor.
        recurrent_nonlinearity (callable): Activation function for
            recurrent layers. It should return a tf.Tensor. Set it to None
            to maintain a linear activation.
        recurrent_w_init (callable): Initializer function for the weight
            of recurrent layer(s). The function should return a
            tf.Tensor.
        output_nonlinearity (callable): Activation function for output
            dense layer. It should return a tf.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            tf.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            tf.Tensor.
        hidden_state_init (callable): Initializer function for the
            initial hidden state. The function should return a tf.Tensor.
        hidden_state_init_trainable (bool): Bool for whether the initial
            hidden state is trainable.
        cell_state_init (callable): Initializer function for the
            initial cell state. The function should return a tf.Tensor.
        cell_state_init_trainable (bool): Bool for whether the initial
            cell state is trainable.
        forget_bias (bool): If True, add 1 to the bias of the forget gate
            at initialization. It's used to reduce the scale of forgetting
            at the beginning of the training.
        layer_normalization (bool): Bool for using layer normalization or
            not. NOTE(review): the value is stored on the instance but is
            not read anywhere else in this class -- confirm whether it is
            expected to have an effect.

    """

    def __init__(self,
                 output_dim,
                 hidden_dim,
                 name=None,
                 hidden_nonlinearity=tf.nn.tanh,
                 hidden_w_init=tf.glorot_uniform_initializer(),
                 hidden_b_init=tf.zeros_initializer(),
                 recurrent_nonlinearity=tf.nn.sigmoid,
                 recurrent_w_init=tf.glorot_uniform_initializer(),
                 output_nonlinearity=None,
                 output_w_init=tf.glorot_uniform_initializer(),
                 output_b_init=tf.zeros_initializer(),
                 hidden_state_init=tf.zeros_initializer(),
                 hidden_state_init_trainable=False,
                 cell_state_init=tf.zeros_initializer(),
                 cell_state_init_trainable=False,
                 forget_bias=True,
                 layer_normalization=False):
        super().__init__(name)
        self._output_dim = output_dim
        self._hidden_dim = hidden_dim
        self._forget_bias = forget_bias
        self._hidden_nonlinearity = hidden_nonlinearity
        self._hidden_w_init = hidden_w_init
        self._hidden_b_init = hidden_b_init
        self._recurrent_nonlinearity = recurrent_nonlinearity
        self._recurrent_w_init = recurrent_w_init
        self._hidden_state_init = hidden_state_init
        self._hidden_state_init_trainable = hidden_state_init_trainable
        self._cell_state_init = cell_state_init
        self._cell_state_init_trainable = cell_state_init_trainable
        self._output_nonlinearity = output_nonlinearity
        self._output_w_init = output_w_init
        self._output_b_init = output_b_init
        self._layer_normalization = layer_normalization
        # The keras layers are created in a separate step so that
        # __setstate__ can recreate them after unpickling.
        self._initialize()

    def _initialize(self):
        """Create the keras LSTM cell and the dense output layer.

        Reads only the configuration attributes set in ``__init__``; it is
        called from both ``__init__`` and ``__setstate__``.
        """
        self._lstm_cell = tf.keras.layers.LSTMCell(
            units=self._hidden_dim,
            activation=self._hidden_nonlinearity,
            kernel_initializer=self._hidden_w_init,
            bias_initializer=self._hidden_b_init,
            recurrent_activation=self._recurrent_nonlinearity,
            recurrent_initializer=self._recurrent_w_init,
            unit_forget_bias=self._forget_bias,
            name='lstm_layer')
        self._output_nonlinearity_layer = tf.keras.layers.Dense(
            units=self._output_dim,
            activation=self._output_nonlinearity,
            kernel_initializer=self._output_w_init,
            bias_initializer=self._output_b_init,
            name='output_layer')

    def network_output_spec(self):
        """Network output spec.

        Returns:
            list[str]: Names of the network outputs.

        """
        return [
            'all_output', 'step_output', 'step_hidden', 'step_cell',
            'init_hidden', 'init_cell'
        ]

    def _build(self,
               all_input_var,
               step_input_var,
               step_hidden_var,
               step_cell_var,
               name=None):
        """Build the network by delegating to the ``lstm`` helper.

        Args:
            all_input_var: Forwarded to ``lstm`` as ``all_input_var``.
                Presumably the full-sequence input tensor -- TODO confirm
                against ``garage.tf.models.lstm``.
            step_input_var: Forwarded to ``lstm`` as ``step_input_var``.
                Presumably the single-step input tensor -- TODO confirm.
            step_hidden_var: Forwarded to ``lstm`` as ``step_hidden_var``.
                Presumably the hidden state for a single step -- TODO
                confirm.
            step_cell_var: Forwarded to ``lstm`` as ``step_cell_var``.
                Presumably the cell state for a single step -- TODO
                confirm.
            name (str): Not used in this method body; the helper is always
                called with ``name='lstm'``.

        Returns:
            Whatever ``lstm`` returns. NOTE(review): presumably the
            tensors named by ``network_output_spec``, in that order --
            confirm against ``garage.tf.models.lstm``.

        """
        return lstm(
            name='lstm',
            lstm_cell=self._lstm_cell,
            all_input_var=all_input_var,
            step_input_var=step_input_var,
            step_hidden_var=step_hidden_var,
            step_cell_var=step_cell_var,
            hidden_state_init=self._hidden_state_init,
            hidden_state_init_trainable=self._hidden_state_init_trainable,
            cell_state_init=self._cell_state_init,
            cell_state_init_trainable=self._cell_state_init_trainable,
            output_nonlinearity_layer=self._output_nonlinearity_layer)

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: The parent's state with the keras layer objects removed.

        """
        new_dict = super().__getstate__()
        # The keras layers are left out of the saved state; they are
        # recreated from the stored configuration in __setstate__.
        del new_dict['_lstm_cell']
        del new_dict['_output_nonlinearity_layer']
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): State produced by ``__getstate__``.

        """
        super().__setstate__(state)
        # Recreate the layers that __getstate__ dropped.
        self._initialize()