"""Regressor base classes without Parameterized."""
import abc
import tensorflow as tf
from garage.misc.tensor_utils import flatten_tensors
from garage.misc.tensor_utils import unflatten_tensors
class Regressor(abc.ABC):
    """Regressor base class.

    Subclasses provide :meth:`fit`, :meth:`predict` and
    :meth:`get_params_internal`. This base class adds per-tag caching of
    the parameter list and helpers that read and write all parameter values
    as one flat array through the default TensorFlow session.

    Args:
        input_shape (tuple[int]): Input shape.
        output_dim (int): Output dimension.
        name (str): Name of the regressor.

    """

    def __init__(self, input_shape, output_dim, name):
        self._input_shape = input_shape
        self._output_dim = output_dim
        self._name = name
        self._variable_scope = None
        # Both caches are keyed by the normalized tag tuple from _tags_key.
        self._cached_params = {}
        self._cached_param_shapes = {}

    @staticmethod
    def _tags_key(tags):
        """Build a hashable, order-independent cache key from tag kwargs.

        Args:
            tags (dict): Tag keyword arguments, e.g. ``trainable=True``.

        Returns:
            tuple: ``(tag_name, value)`` pairs sorted by tag name.

        """
        # Dict keys are unique, so sorting by name never compares the values.
        return tuple(sorted(tags.items(), key=lambda item: item[0]))

    def fit(self, xs, ys):
        """Fit with input data xs and label ys.

        The base implementation has no body and returns None.

        Args:
            xs (numpy.ndarray): Input data.
            ys (numpy.ndarray): Label of input data.

        """

    def predict(self, xs):
        """Predict ys based on input xs.

        The base implementation has no body and returns None.

        Args:
            xs (numpy.ndarray): Input data.

        Returns:
            The predicted ys.

        """

    def get_params_internal(self, **tags):
        """Get the list of parameters.

        This internal method does not perform caching and should be
        implemented by subclasses. The base implementation returns None.

        Args:
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        Returns:
            list(tf.Variable): A list of trainable variables.

        """

    def get_params(self, **tags):
        """Get the list of parameters, filtered by the provided tags.

        The result of :meth:`get_params_internal` is cached per distinct
        tag set. Keyword order does not matter.

        Args:
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        Returns:
            list(tf.Variable): The (cached) parameters for these tags.

        """
        key = self._tags_key(tags)
        if key not in self._cached_params:
            self._cached_params[key] = self.get_params_internal(**tags)
        return self._cached_params[key]

    def get_param_shapes(self, **tags):
        """Get the list of shapes for the parameters.

        The shapes are read by evaluating the parameters in the default
        TensorFlow session on first use, then cached per tag set.

        Args:
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        Returns:
            List[tuple[int]]: A list of shapes of each parameter.

        """
        key = self._tags_key(tags)
        if key not in self._cached_param_shapes:
            params = self.get_params(**tags)
            param_values = tf.compat.v1.get_default_session().run(params)
            self._cached_param_shapes[key] = [
                val.shape for val in param_values
            ]
        return self._cached_param_shapes[key]

    def get_param_values(self, **tags):
        """Get the values of the parameters as one flat array.

        Args:
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        Returns:
            np.ndarray: The parameter values evaluated in the default
                session and flattened with ``flatten_tensors``. This is the
                format accepted by :meth:`set_param_values`.

        """
        params = self.get_params(**tags)
        param_values = tf.compat.v1.get_default_session().run(params)
        return flatten_tensors(param_values)

    def set_param_values(self, flattened_params, name=None, **tags):
        """Set the values for the parameters.

        Args:
            flattened_params (np.ndarray): Flat array of parameter values,
                e.g. as returned by :meth:`get_param_values`.
            name (str): Name scope to use. Falls back to
                ``'set_param_values'`` when None.
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        """
        # tf.compat.v1.name_scope keeps the (name, default_name, values)
        # signature; TF2's tf.name_scope only accepts a single name.
        with tf.compat.v1.name_scope(name, 'set_param_values',
                                     [flattened_params]):
            param_values = unflatten_tensors(flattened_params,
                                             self.get_param_shapes(**tags))
            for param, value in zip(self.get_params(**tags), param_values):
                param.load(value)

    def flat_to_params(self, flattened_params, **tags):
        """Unflatten tensors according to their respective shapes.

        Args:
            flattened_params (np.ndarray): A numpy array of flattened params.
            tags (dict): Some common tags include 'regularizable' and
                'trainable'.

        Returns:
            List[np.ndarray]: A list of parameters reshaped to the shapes
                returned by :meth:`get_param_shapes`.

        """
        return unflatten_tensors(flattened_params,
                                 self.get_param_shapes(**tags))

    def __getstate__(self):
        """Object.__getstate__.

        Drops the cached parameter handles. They are rebuilt lazily by
        :meth:`get_params` after unpickling.
        """
        new_dict = self.__dict__.copy()
        del new_dict['_cached_params']
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__."""
        self._cached_params = {}
        self.__dict__.update(state)
class StochasticRegressor(Regressor):
    """Base class for regressors that expose symbolic distribution graphs.

    Adds two hooks on top of :class:`Regressor`: a symbolic log-likelihood
    and a symbolic distribution-info graph.

    Args:
        input_shape (tuple[int]): Input shape.
        output_dim (int): Output dimension.
        name (str): Name of the regressor.

    """

    def __init__(self, input_shape, output_dim, name):
        super().__init__(input_shape, output_dim, name)

    def log_likelihood_sym(self, x_var, y_var, name=None):
        """Build the symbolic log-likelihood of y_var given x_var.

        The base implementation has no body and returns None.

        Args:
            x_var (tf.Tensor): Tensor holding the input data.
            y_var (tf.Tensor): Tensor holding the labels of the data.
            name (str): Name of the new graph.

        Returns:
            tf.Tensor: Output of the symbolic log likelihood.

        """

    def dist_info_sym(self, x_var, name=None):
        """Build the symbolic distribution graph for inputs x_var.

        The base implementation has no body and returns None.

        Args:
            x_var (tf.Tensor): Tensor holding the input data.
            name (str): Name of the new graph.

        Returns:
            tf.Tensor: Output of the symbolic distribution.

        """