Source code for garage.torch.distributions.tanh_normal

"""A Gaussian distribution with tanh transformation."""
import torch
from torch.distributions import Normal
from torch.distributions.independent import Independent


class TanhNormal(torch.distributions.Distribution):
    r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.

    Algorithms like SAC and PEARL use this transformed distribution.
    It can be thought of as the distribution of X where

        :math:`Y \sim \mathcal{N}(\mu, \sigma)`

        :math:`X = \tanh(Y)`

    Args:
        loc (torch.Tensor): The mean of this distribution.
        scale (torch.Tensor): The stdev of this distribution.

    """  # noqa: E501

    def __init__(self, loc, scale):
        self._normal = Independent(Normal(loc, scale), 1)
        super().__init__()
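    # A minimal usage sketch (illustrative, not part of the original module):
    # constructing the distribution from a batch of means and standard
    # deviations. The trailing dimension is treated as the event dimension
    # because of the Independent(Normal(...), 1) wrapper above.
    #
    #     dist = TanhNormal(loc=torch.zeros(32, 4), scale=torch.ones(32, 4))
    #     actions = dist.rsample()  # shape (32, 4), all values in (-1, 1)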
    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample on this tanh-normal distribution.

        Args:
            value (torch.Tensor): The sample whose loglikelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it, but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
            When pre_tanh_value is None, an estimate is made of what the
            value is. This leads to a worse estimate of the log_prob.
            If the value being used was collected from functions like `sample`
            and `rsample`, one can instead use `rsample_with_pre_tanh_value`
            to obtain both the sample and its pre-tanh value.

        Returns:
            torch.Tensor: The log likelihood of value on the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Invert the transform: atanh(x) = log((1 + x) / (1 - x)) / 2.
            pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient(1. - value**2) + epsilon),
            dim=-1))
        return ret
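    # Why the correction term above: by the change-of-variables formula, for
    # X = tanh(Y) the density satisfies
    #     log p_X(x) = log p_Y(atanh(x)) - log(1 - x^2)
    # summed over the event dimension; `epsilon` and the clipping keep the
    # argument of the log positive when x is numerically at +/-1. A
    # deterministic sanity check (illustrative only, not part of the
    # original module):
    #
    #     z = torch.full((1, 2), 0.5)
    #     x = torch.tanh(z)
    #     dist = TanhNormal(torch.zeros(1, 2), torch.ones(1, 2))
    #     manual = dist._normal.log_prob(z) - torch.log(1 - x**2).sum(dim=-1)
    #     assert torch.allclose(dist.log_prob(x, pre_tanh_value=z), manual,
    #                           atol=1e-4)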
    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Args:
            sample_shape (torch.Size): Shape of the returned value.

        Note:
            Gradients do `not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)
    def rsample(self, sample_shape=torch.Size()):
        """Return a reparameterized sample from this TanhNormal distribution.

        Args:
            sample_shape (torch.Size): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)
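    # Illustrative only: because `rsample` is reparameterized, gradients flow
    # back to the distribution parameters, which `sample` above deliberately
    # prevents:
    #
    #     loc = torch.zeros(4, requires_grad=True)
    #     dist = TanhNormal(loc.unsqueeze(0), torch.ones(1, 4))
    #     dist.rsample().sum().backward()  # loc.grad is now populated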
    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value both before and after the tanh transform
        is applied to it.

        Args:
            sample_shape (torch.Size): Shape of the returned values.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from the underlying
                :obj:`torch.distributions.Normal` distribution, prior to being
                transformed with `tanh`.
            torch.Tensor: Samples from this distribution.

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)
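    # Illustrative only: this is the sampling path SAC-style algorithms rely
    # on, since keeping the pre-tanh value avoids the lossy atanh inversion
    # in `log_prob`:
    #
    #     dist = TanhNormal(torch.zeros(1, 4), torch.ones(1, 4))
    #     pre_tanh, action = dist.rsample_with_pre_tanh_value()
    #     log_pi = dist.log_prob(action, pre_tanh_value=pre_tanh)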
    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative distribution function evaluated at `value` on
        the underlying normal distribution.

        Args:
            value (torch.Tensor): The point at which the cdf is evaluated.

        Returns:
            torch.Tensor: The cdf evaluated at `value`.

        """
        return self._normal.cdf(value)
    def icdf(self, value):
        """Returns the inverse CDF evaluated at `value`.

        Returns the inverse cumulative distribution function evaluated at
        `value` on the underlying normal distribution.

        Args:
            value (torch.Tensor): The point at which the inverse cdf is
                evaluated.

        Returns:
            torch.Tensor: The inverse cdf evaluated at `value`.

        """
        return self._normal.icdf(value)
    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist is
                new_normal.

        """
        # pylint: disable=protected-access
        # Build a throwaway placeholder (a zero scale would be invalid for
        # Normal), then swap in the real underlying distribution.
        new = cls(torch.zeros(1), torch.ones(1))
        new._normal = new_normal
        return new
    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting done in `__init__` when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance (instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
                to `batch_shape`.

        """
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new
    def enumerate_support(self, expand=True):
        """Returns tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape of the
        result will be `(cardinality,) + batch_shape + event_shape` (where
        `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being singleton
        dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the batch dims
                to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution. Since the normal distribution is continuous, this
            raises `NotImplementedError`.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)
    @property
    def mean(self):
        """torch.Tensor: tanh of the underlying normal distribution's mean.

        Note that this is not the true mean of the transformed distribution,
        since the tanh of the mean is not the mean of the tanh.
        """
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance
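    # Illustrative only: a Monte Carlo estimate shows the gap between `mean`
    # above and the actual expectation of tanh(Y):
    #
    #     dist = TanhNormal(torch.ones(1, 1), torch.ones(1, 1))
    #     mc_mean = dist.sample(torch.Size([100000])).mean()
    #     # mc_mean comes out clearly below dist.mean (tanh(1) ~= 0.76)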
    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()
    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        # Indicator masks for out-of-range entries; the correction is computed
        # under no_grad, so the result has the clipped forward value but the
        # identity function's gradient.
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        with torch.no_grad():
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The name of the distribution's class.

        """
        return self.__class__.__name__
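# A minimal, self-contained sketch (not part of the original module) showing
# how the pieces fit together the way an SAC-style actor would use them. All
# names below are illustrative.
if __name__ == '__main__':
    loc = torch.zeros(8, 2, requires_grad=True)
    scale = torch.ones(8, 2)
    dist = TanhNormal(loc, scale)

    # Sample actions together with their pre-tanh values so that log_prob
    # does not have to invert tanh numerically.
    pre_tanh, actions = dist.rsample_with_pre_tanh_value()
    log_pi = dist.log_prob(actions, pre_tanh_value=pre_tanh)

    # Gradients flow back to `loc` through the reparameterized sample.
    log_pi.sum().backward()
    print('actions in (-1, 1):', bool((actions.abs() < 1).all()))
    print('log_pi shape:', tuple(log_pi.shape))
    print('loc.grad populated:', loc.grad is not None)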