Source code for garage.tf.baselines.gaussian_cnn_baseline

"""Gaussian CNN Baseline."""

import numpy as np

from garage.np.baselines import Baseline
from garage.tf.regressors import GaussianCNNRegressor


[docs]class GaussianCNNBaseline(Baseline): """ GaussianCNNBaseline With Model. It fits the input data to a gaussian distribution estimated by a CNN. Args: env_spec (garage.envs.env_spec.EnvSpec): Environment specification. subsample_factor (float): The factor to subsample the data. By default it is 1.0, which means using all the data. regressor_args (dict): Arguments for regressor. """ def __init__( self, env_spec, subsample_factor=1., regressor_args=None, name='GaussianCNNBaseline', ): super().__init__(env_spec) if regressor_args is None: regressor_args = dict() self._regressor = GaussianCNNRegressor( input_shape=(env_spec.observation_space.shape), output_dim=1, name=name, **regressor_args) self.name = name
[docs] def fit(self, paths): """Fit regressor based on paths.""" observations = np.concatenate([p['observations'] for p in paths]) returns = np.concatenate([p['returns'] for p in paths]) self._regressor.fit(observations, returns.reshape((-1, 1)))
[docs] def predict(self, path): """Predict value based on paths.""" return self._regressor.predict(path['observations']).flatten()
[docs] def get_param_values(self, **tags): """Get parameter values.""" return self._regressor.get_param_values(**tags)
[docs] def set_param_values(self, flattened_params, **tags): """Set parameter values to val.""" self._regressor.set_param_values(flattened_params, **tags)
[docs] def get_params_internal(self, **tags): """Get parameter values.""" return self._regressor.get_params_internal(**tags)