garage.tf.distributions.categorical module

class Categorical(dim, name=None)[source]

Bases: garage.tf.distributions.distribution.Distribution

cross_entropy_sym(old_dist_info_vars, new_dist_info_vars, name='cross_entropy_sym')[source]
dim

Dimension of this distribution.

Type: int
dist_info_specs

Specification of the parameter of a distribution.

Type: list
entropy(info)[source]

Entropy of a distribution.

Parameters: info (dict) – Parameters of a distribution.
Returns: Entropy of the distribution.
Return type: float
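
As a rough illustration of the quantity this method returns, the entropy of a categorical distribution can be sketched in plain NumPy. This is not the library's implementation: the dictionary key "prob", the smoothing constant TINY, and the helper name categorical_entropy are all assumptions made for the example.

```python
import numpy as np

TINY = 1e-8  # assumed smoothing constant so log(0) is never evaluated


def categorical_entropy(dist_info):
    """Entropy -sum(p * log p) over the last axis (hypothetical helper)."""
    # Assumption: the probabilities are stored under the key "prob".
    probs = np.asarray(dist_info["prob"], dtype=float)
    return -np.sum(probs * np.log(probs + TINY), axis=-1)


# A batch containing one uniform distribution over 4 categories.
uniform = {"prob": np.full((1, 4), 0.25)}
h = categorical_entropy(uniform)
```

For a uniform distribution over 4 categories the result is close to log(4), the maximum entropy attainable with that many categories.
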
entropy_sym(dist_info_vars)[source]

Symbolic entropy of a distribution.

Parameters:
  • dist_info_vars (dict) – Symbolic parameters of a distribution.
  • name (str) – TensorFlow scope name.
Returns:

Symbolic entropy of the distribution.

Return type:

tf.Tensor

kl(old_dist_info, new_dist_info)[source]

Compute the KL divergence of two categorical distributions.
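
The arithmetic behind a categorical KL divergence can be sketched in NumPy as follows. The direction KL(old || new), the dictionary key "prob", the constant TINY, and the helper name categorical_kl are assumptions for illustration and are not confirmed by this page.

```python
import numpy as np

TINY = 1e-8  # assumed smoothing constant so log(0) is never evaluated


def categorical_kl(old_dist_info, new_dist_info):
    """KL(old || new) = sum(old * (log old - log new)) (hypothetical helper)."""
    # Assumption: the probabilities are stored under the key "prob".
    old = np.asarray(old_dist_info["prob"], dtype=float)
    new = np.asarray(new_dist_info["prob"], dtype=float)
    return np.sum(old * (np.log(old + TINY) - np.log(new + TINY)), axis=-1)


old_info = {"prob": np.array([[0.5, 0.5]])}
new_info = {"prob": np.array([[0.9, 0.1]])}

kl_same = categorical_kl(old_info, old_info)  # identical distributions
kl_diff = categorical_kl(old_info, new_info)  # different distributions
```

The divergence is zero when both arguments are the same distribution and positive otherwise. It is not symmetric, so swapping the arguments generally changes the value.
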

kl_sym(old_dist_info_vars, new_dist_info_vars, name='kl_sym')[source]

Compute the symbolic KL divergence of two categorical distributions.

likelihood_ratio_sym(x_var, old_dist_info_vars, new_dist_info_vars, name='likelihood_ratio_sym')[source]

Symbolic likelihood ratio.

Parameters:
  • x_var (tf.Tensor) – Input placeholder.
  • old_dist_info_vars (dict) – Old distribution tensors.
  • new_dist_info_vars (dict) – New distribution tensors.
  • name (str) – TensorFlow scope name.
Returns:

Symbolic likelihood ratio.

Return type:

tf.Tensor
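
The method above builds a TensorFlow expression. The underlying arithmetic can be sketched in NumPy, assuming x is one-hot encoded and the probabilities live under the key "prob". Both points are assumptions for the example, as are the constant TINY and the helper name likelihood_ratio.

```python
import numpy as np

TINY = 1e-8  # assumed smoothing constant to avoid division by zero


def likelihood_ratio(x, old_dist_info, new_dist_info):
    """Ratio p_new(x) / p_old(x) for one-hot samples x (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    # Assumption: the probabilities are stored under the key "prob".
    old = np.asarray(old_dist_info["prob"], dtype=float)
    new = np.asarray(new_dist_info["prob"], dtype=float)
    # Multiplying by the one-hot vector selects the sampled category.
    p_old = np.sum(old * x, axis=-1)
    p_new = np.sum(new * x, axis=-1)
    return (p_new + TINY) / (p_old + TINY)


x = np.array([[0.0, 1.0, 0.0]])  # category 1 was sampled
old_info = {"prob": np.array([[0.2, 0.5, 0.3]])}
new_info = {"prob": np.array([[0.1, 0.6, 0.3]])}
ratio = likelihood_ratio(x, old_info, new_info)
```

Here the sampled category has probability 0.6 under the new distribution and 0.5 under the old one, so the ratio is approximately 1.2.
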

log_likelihood(xs, dist_info)[source]

Log likelihood of a sample under a distribution.

Parameters:
  • xs (np.ndarray) – Input value.
  • dist_info (dict) – Parameters of a distribution.
Returns:

Log likelihood of a sample under the distribution.

Return type:

float
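
A minimal NumPy sketch of a categorical log likelihood, assuming xs is one-hot encoded and the probabilities are stored under the key "prob". These are assumptions for illustration, as are the constant TINY and the helper name categorical_log_likelihood.

```python
import numpy as np

TINY = 1e-8  # assumed smoothing constant so log(0) is never evaluated


def categorical_log_likelihood(xs, dist_info):
    """log p(x) for one-hot samples xs (hypothetical helper)."""
    xs = np.asarray(xs, dtype=float)
    # Assumption: the probabilities are stored under the key "prob".
    probs = np.asarray(dist_info["prob"], dtype=float)
    # The one-hot vector picks out the probability of the sampled category.
    return np.log(np.sum(probs * xs, axis=-1) + TINY)


xs = np.array([[0.0, 0.0, 1.0]])  # category 2 was sampled
info = {"prob": np.array([[0.2, 0.3, 0.5]])}
ll = categorical_log_likelihood(xs, info)
```

The sampled category has probability 0.5, so the result is close to log(0.5).
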

log_likelihood_sym(x_var, dist_info_vars, name='log_likelihood_sym')[source]

Symbolic log likelihood.

Parameters:
  • x_var (tf.Tensor) – Input placeholder.
  • dist_info_vars (dict) – Parameters of a distribution.
  • name (str) – TensorFlow scope name.
Returns:

Symbolic log likelihood.

Return type:

tf.Tensor

sample(dist_info)[source]
sample_sym(dist_info, name='sample_sym')[source]
from_onehot(x_var)[source]