"""Interface for primitives which build on top of models."""
import abc
import tensorflow as tf
from garage.misc.tensor_utils import flatten_tensors, unflatten_tensors
class Module(abc.ABC):
    """A module that builds on top of model.

    Args:
        name (str): Module name, also the variable scope.

    """

    def __init__(self, name):
        self._name = name
        self._variable_scope = None
        # Lazy caches for get_params()/get_param_shapes(); see __getstate__
        # for why _cached_params must never be pickled.
        self._cached_params = None
        self._cached_param_shapes = None

    @property
    def name(self):
        """str: Name of this module."""
        return self._name

    @property
    @abc.abstractmethod
    def vectorized(self):
        """bool: If this module supports vectorization input."""

    def reset(self, do_resets=None):
        """Reset the module.

        This is effective only to recurrent modules. do_resets is effective
        only to vectorized modules.

        For a vectorized module, do_resets is an array of boolean indicating
        which internal states to be reset. The length of do_resets should be
        equal to the length of inputs.

        Args:
            do_resets (numpy.ndarray): Bool array indicating which states
                to be reset.

        """

    @property
    def state_info_specs(self):
        """State info specification.

        Returns:
            List[str]: keys and shapes for the information related to the
                module's state when taking an action.

        """
        return list()

    @property
    def state_info_keys(self):
        """State info keys.

        Returns:
            List[str]: keys for the information related to the module's state
                when taking an input.

        """
        return [k for k, _ in self.state_info_specs]

    def terminate(self):
        """Clean up operation."""

    def get_trainable_vars(self):
        """Get trainable variables.

        Returns:
            List[tf.Variable]: A list of trainable variables in the current
                variable scope.

        """
        return self._variable_scope.trainable_variables()

    def get_global_vars(self):
        """Get global variables.

        Returns:
            List[tf.Variable]: A list of global variables in the current
                variable scope.

        """
        return self._variable_scope.global_variables()

    def get_params(self):
        """Get the trainable variables.

        The result is computed once and cached for subsequent calls.

        Returns:
            List[tf.Variable]: A list of trainable variables in the current
                variable scope.

        """
        if self._cached_params is None:
            self._cached_params = self.get_trainable_vars()
        return self._cached_params

    def get_param_shapes(self):
        """Get parameter shapes.

        Requires a default TensorFlow session to evaluate the variables;
        shapes are cached after the first call.

        Returns:
            List[tuple]: A list of variable shapes.

        """
        if self._cached_param_shapes is None:
            params = self.get_params()
            param_values = tf.compat.v1.get_default_session().run(params)
            self._cached_param_shapes = [val.shape for val in param_values]
        return self._cached_param_shapes

    def get_param_values(self):
        """Get param values.

        Returns:
            np.ndarray: Values of the parameters evaluated in
                the current session, flattened into a single vector.

        """
        params = self.get_params()
        param_values = tf.compat.v1.get_default_session().run(params)
        return flatten_tensors(param_values)

    def set_param_values(self, param_values):
        """Set param values.

        Args:
            param_values (np.ndarray): A numpy array of parameter values,
                flattened in the same order as get_param_values().

        """
        param_values = unflatten_tensors(param_values, self.get_param_shapes())
        for param, value in zip(self.get_params(), param_values):
            param.load(value)

    def flat_to_params(self, flattened_params):
        """Unflatten tensors according to their respective shapes.

        Args:
            flattened_params (np.ndarray): A numpy array of flattened params.

        Returns:
            List[np.ndarray]: A list of parameters reshaped to the
                shapes specified.

        """
        return unflatten_tensors(flattened_params, self.get_param_shapes())

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: The state to be pickled for the instance.

        """
        new_dict = self.__dict__.copy()
        # tf.Variable handles are session-bound and cannot be pickled;
        # _cached_param_shapes holds plain tuples and is safe to keep.
        del new_dict['_cached_params']
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): Unpickled state.

        """
        # Restore the cache slot dropped by __getstate__; the variables will
        # be re-fetched lazily by get_params().
        self._cached_params = None
        self.__dict__.update(state)
class StochasticModule(Module):
    """Stochastic Module.

    A Module whose output is characterized by a probability distribution
    exposed through the ``distribution`` property.
    """

    @property
    @abc.abstractmethod
    def distribution(self):
        """Distribution."""