garage.tf.baselines.gaussian_cnn_baseline module

Gaussian CNN Baseline.

class GaussianCNNBaseline(env_spec, subsample_factor=1.0, regressor_args=None, name='GaussianCNNBaseline')

Bases: garage.np.baselines.base.Baseline

GaussianCNNBaseline with model.

It fits the input data to a Gaussian distribution whose parameters are estimated by a CNN.
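To make "fits the input data to a Gaussian distribution" concrete, here is a minimal sketch of the objective, assuming the regressor is trained by maximum likelihood. The CNN is replaced by fixed per-sample means and log standard deviations, and `gaussian_nll` is a hypothetical helper, not part of garage.

```python
import math


def gaussian_nll(targets, means, log_stds):
    """Average negative log-likelihood of targets under N(mean, exp(log_std)^2).

    The means and log-stds are plain numbers standing in for CNN outputs;
    minimizing this loss is maximum-likelihood fitting (an assumption here).
    """
    total = 0.0
    for y, mu, log_std in zip(targets, means, log_stds):
        z = (y - mu) / math.exp(log_std)
        # -log N(y; mu, sigma) = 0.5 * z^2 + log(sigma) + 0.5 * log(2 * pi)
        total += 0.5 * z * z + log_std + 0.5 * math.log(2 * math.pi)
    return total / len(targets)


# An exact prediction with unit std costs only the constant 0.5 * log(2 * pi).
print(gaussian_nll([1.0], [1.0], [0.0]))
```

Predictions far from the targets, or needlessly wide distributions, raise this loss, which is what pushes the estimated mean toward the observed values.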

Parameters:
  • env_spec (garage.envs.env_spec.EnvSpec) – Environment specification.
  • subsample_factor (float) – Fraction of the data used to fit the regressor. The default, 1.0, uses all the data.
  • regressor_args (dict) – Arguments for the regressor. Defaults to None, meaning no extra arguments.
  • name (str) – Name of the baseline.
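A minimal sketch of what subsample_factor could mean in practice, assuming the regressor keeps a random fraction of its (input, target) pairs and samples without replacement. The helper `subsample` is hypothetical; garage's own sampling scheme may differ.

```python
import random


def subsample(xs, ys, subsample_factor, rng=None):
    """Keep a random fraction of aligned (x, y) pairs (hypothetical helper).

    A factor of 1.0 or more returns all the data unchanged.
    """
    if subsample_factor >= 1.0:
        return list(xs), list(ys)
    if rng is None:
        rng = random.Random()
    n_keep = int(len(xs) * subsample_factor)
    # Sample indices without replacement so each x stays paired with its y.
    idx = rng.sample(range(len(xs)), n_keep)
    return [xs[i] for i in idx], [ys[i] for i in idx]


xs = list(range(10))
ys = [10 * x for x in xs]
sub_xs, sub_ys = subsample(xs, ys, 0.5, random.Random(0))
print(len(sub_xs))  # 5
```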
fit(paths)

Fit the regressor to the data in paths.
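A minimal sketch of the data preparation such a fit might perform, assuming each path is a dict holding 'observations' and 'returns' and that the baseline regresses returns on observations. The helper `fit_data_from_paths` is hypothetical.

```python
def fit_data_from_paths(paths):
    """Stack observations and returns from all paths into training pairs.

    Assumes each path is a dict with equal-length 'observations' and
    'returns' lists; both keys are assumptions for this sketch.
    """
    xs, ys = [], []
    for path in paths:
        xs.extend(path['observations'])
        # Each target is a one-element row, as for a regressor with one output.
        ys.extend([r] for r in path['returns'])
    return xs, ys


paths = [
    {'observations': [[0.0], [1.0]], 'returns': [2.0, 1.0]},
    {'observations': [[5.0]], 'returns': [3.0]},
]
xs, ys = fit_data_from_paths(paths)
print(xs)  # [[0.0], [1.0], [5.0]]
print(ys)  # [[2.0], [1.0], [3.0]]
```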

get_param_values(**tags)

Get parameter values.
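Assuming get_param_values returns all parameter values concatenated into one flat vector (consistent with set_param_values accepting flattened_params), the flattening step can be sketched as follows. Nested lists stand in for tensors, and `flatten_params` is a hypothetical helper, not garage's API.

```python
def flatten_params(params):
    """Concatenate nested parameter lists into one flat list of floats.

    Each parameter is a (possibly nested) list standing in for a tensor.
    """
    flat = []

    def _walk(value):
        # Recurse into lists and tuples; anything else is a scalar entry.
        if isinstance(value, (list, tuple)):
            for item in value:
                _walk(item)
        else:
            flat.append(float(value))

    for p in params:
        _walk(p)
    return flat


weights = [[1, 2], [3, 4]]  # a 2x2 "kernel"
bias = [5]
print(flatten_params([weights, bias]))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

The order of the flat vector follows the order of the parameter list, so a matching setter must walk the parameters in the same order.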

get_params_internal(**tags)

Get the parameters themselves (rather than their values), filtered by tags.
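The **tags argument suggests that parameters are selected by tag, for example only trainable ones. A minimal sketch of such filtering, assuming boolean tags where an unset tag counts as False; the `Param` class and `get_params_by_tags` are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Param:
    """A named parameter with boolean tags such as trainable (hypothetical)."""
    name: str
    tags: dict = field(default_factory=dict)


def get_params_by_tags(params, **tags):
    """Return params whose tags match every requested tag value.

    Tags not set on a parameter are treated as False; no tags selects all.
    """
    return [
        p for p in params
        if all(p.tags.get(key, False) == value for key, value in tags.items())
    ]


params = [
    Param('conv/kernel', {'trainable': True}),
    Param('conv/bias', {'trainable': True}),
    Param('normalizer/mean', {'trainable': False}),
]
print([p.name for p in get_params_by_tags(params, trainable=True)])
# ['conv/kernel', 'conv/bias']
```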

predict(path)

Predict the value of each observation in path.
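A minimal sketch of prediction, assuming the regressor maps a batch of N observations to an (N, 1) output that is flattened to one value per observation. Both `predict_values` and `toy_regressor` are hypothetical stand-ins for the CNN regressor.

```python
def predict_values(observations, regressor_predict):
    """Return one predicted value per observation as a flat list.

    `regressor_predict` maps a batch of observations to a list of
    one-element rows, i.e. an (N, 1) output.
    """
    rows = regressor_predict(observations)
    # Drop the trailing dimension of size 1: [[v0], [v1]] -> [v0, v1].
    return [row[0] for row in rows]


def toy_regressor(obs_batch):
    # Hypothetical model: the value is the sum of the observation entries.
    return [[float(sum(obs))] for obs in obs_batch]


path = {'observations': [[1, 2], [3, 4]]}
print(predict_values(path['observations'], toy_regressor))  # [3.0, 7.0]
```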

set_param_values(flattened_params, **tags)

Set the parameter values from flattened_params.
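Setting values from flattened_params implies splitting one flat vector back into per-parameter shapes, in a fixed parameter order. A minimal sketch of that step, assuming the shapes are known in advance; `unflatten_params` is a hypothetical helper, with nested lists standing in for tensors.

```python
from math import prod


def unflatten_params(flattened_params, shapes):
    """Split a flat list into nested lists with the given shapes.

    `shapes` holds one shape tuple per parameter, e.g. (2, 2) for a kernel,
    (1,) for a bias, or () for a scalar.
    """
    expected = sum(prod(shape) for shape in shapes)
    if len(flattened_params) != expected:
        raise ValueError(
            f'expected {expected} values, got {len(flattened_params)}')

    def _build(values, shape):
        if len(shape) == 0:
            return values[0]  # scalar parameter
        if len(shape) == 1:
            return list(values)
        step = prod(shape[1:])
        return [_build(values[i * step:(i + 1) * step], shape[1:])
                for i in range(shape[0])]

    params, offset = [], 0
    for shape in shapes:
        size = prod(shape)
        params.append(_build(flattened_params[offset:offset + size], shape))
        offset += size
    return params


print(unflatten_params([1.0, 2.0, 3.0, 4.0, 5.0], [(2, 2), (1,)]))
# [[[1.0, 2.0], [3.0, 4.0]], [5.0]]
```

Checking the length up front turns a mismatched vector into a clear error instead of silently truncated parameters.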