# Source code for garage.torch.policies.context_conditioned_policy

"""A policy used in training meta reinforcement learning algorithms.

It is used in PEARL (Probabilistic Embeddings for Actor-Critic Reinforcement
Learning). The paper on PEARL can be found at https://arxiv.org/abs/1903.08254.
Code is adapted from https://github.com/katerakelly/oyster.
"""

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from garage.torch import global_device, product_of_gaussians


# pylint: disable=attribute-defined-outside-init
# pylint does not recognize attributes initialized as buffers in constructor
class ContextConditionedPolicy(nn.Module):
    """A policy that outputs actions based on observation and latent context.

    In PEARL, policies are conditioned on current state and a latent context
    (adaptation data) variable Z. This inference network estimates the
    posterior probability of z given past transitions. It uses context
    information stored in the encoder to infer the probabilistic value of z and
    samples from a policy conditioned on z.

    Args:
        latent_dim (int): Latent context variable dimension.
        context_encoder (garage.torch.embeddings.ContextEncoder): Recurrent or
            permutation-invariant context encoder.
        policy (garage.torch.policies.Policy): Policy used to train the
            network.
        use_information_bottleneck (bool): True if latent context is not
            deterministic; false otherwise.
        use_next_obs (bool): True if next observation is used in context
            for distinguishing tasks; false otherwise.

    """

    def __init__(self, latent_dim, context_encoder, policy,
                 use_information_bottleneck, use_next_obs):
        super().__init__()
        self._latent_dim = latent_dim
        self._context_encoder = context_encoder
        self._policy = policy
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs = use_next_obs
        # Declared here so the attribute exists from construction onwards;
        # reset_belief() below sets it to None again.
        self._context = None

        # initialize buffers for z distribution and z
        # use buffers so latent context can be saved along with model weights
        # z_means and z_vars are the params for the gaussian distribution
        # over latent task belief maintained in the policy; z is a sample from
        # this distribution that the policy is conditioned on
        self.register_buffer('z', torch.zeros(1, latent_dim))
        self.register_buffer('z_means', torch.zeros(1, latent_dim))
        self.register_buffer('z_vars', torch.zeros(1, latent_dim))

        self.reset_belief()

    def reset_belief(self, num_tasks=1):
        r"""Reset :math:`q(z \| c)` to the prior and sample a new z from the prior.

        Besides resetting the distribution parameters, this clears the
        accumulated context and calls ``reset()`` on the context encoder.

        Args:
            num_tasks (int): Number of tasks.

        """
        # reset distribution over z to the prior
        mu = torch.zeros(num_tasks, self._latent_dim).to(global_device())
        if self._use_information_bottleneck:
            var = torch.ones(num_tasks, self._latent_dim).to(global_device())
        else:
            var = torch.zeros(num_tasks, self._latent_dim).to(global_device())
        self.z_means = mu
        self.z_vars = var
        # sample a new z from the prior
        self.sample_from_belief()
        # reset the context collected so far
        self._context = None
        # reset any hidden state in the encoder network (relevant for RNN)
        self._context_encoder.reset()

    def sample_from_belief(self):
        """Sample z using distributions from current means and variances.

        With the information bottleneck enabled, one reparameterized sample is
        drawn per row of ``z_means`` / ``z_vars`` and the samples are stacked
        into ``z``. Otherwise ``z`` is set to ``z_means`` directly.
        """
        if not self._use_information_bottleneck:
            self.z = self.z_means
            return

        # Sampling is done row by row (not as one batched draw) on purpose:
        # this keeps the order in which random numbers are consumed unchanged.
        samples = [
            torch.distributions.Normal(mean, torch.sqrt(var)).rsample()
            for mean, var in zip(torch.unbind(self.z_means),
                                 torch.unbind(self.z_vars))
        ]
        self.z = torch.stack(samples)

    def update_context(self, timestep):
        """Append single transition to the current context.

        Each field is converted to a float tensor with two leading singleton
        dimensions, the fields are concatenated along dimension 2, and the
        result is appended to the stored context along dimension 1.

        Args:
            timestep (garage._dtypes.TimeStep): Timestep containing transition
                information to be added to context.

        """
        device = global_device()

        def _as_entry(value):
            # Prepend two singleton dimensions so entries can be concatenated
            # along dim 2 (features) and dim 1 (transitions).
            return torch.as_tensor(value[None, None, ...],
                                   device=device).float()

        fields = [
            _as_entry(timestep.observation),
            _as_entry(timestep.action),
            _as_entry(np.array([timestep.reward])),
        ]
        if self._use_next_obs:
            # Only read the next observation when it is part of the context.
            fields.append(_as_entry(timestep.next_observation))

        data = torch.cat(fields, dim=2)
        if self._context is None:
            self._context = data
        else:
            self._context = torch.cat([self._context, data], dim=1)

    def infer_posterior(self, context):
        r"""Compute :math:`q(z \| c)` as a function of input context and sample new z.

        Updates ``z_means`` (and ``z_vars`` when the information bottleneck is
        enabled) and then resamples ``z``.

        Args:
            context (torch.Tensor): Context values, with shape
                :math:`(X, N, C)`. X is the number of tasks. N is batch size. C
                is the combined size of observation, action, reward, and next
                observation if next observation is used in context. Otherwise,
                C is the combined size of observation, action, and reward.

        """
        params = self._context_encoder.forward(context)
        # Regroup the encoder output per task: (X, -1, output_dim).
        params = params.view(context.size(0), -1,
                             self._context_encoder.output_dim)
        if self._use_information_bottleneck:
            # with probabilistic z, predict mean and variance of q(z | c)
            mu = params[..., :self._latent_dim]
            # softplus keeps the predicted variances positive
            sigma_squared = F.softplus(params[..., self._latent_dim:])
            z_params = [
                product_of_gaussians(m, s)
                for m, s in zip(torch.unbind(mu), torch.unbind(sigma_squared))
            ]
            self.z_means = torch.stack([p[0] for p in z_params])
            self.z_vars = torch.stack([p[1] for p in z_params])
        else:
            # deterministic z: average the encoder output within each task
            self.z_means = torch.mean(params, dim=1)
        self.sample_from_belief()

    # pylint: disable=arguments-differ
    def forward(self, obs, context):
        """Given observations and context, get actions and probs from policy.

        Args:
            obs (torch.Tensor): Observation values, with shape
                :math:`(X, N, O)`. X is the number of tasks. N is batch size. O
                 is the size of the flattened observation space.
            context (torch.Tensor): Context values, with shape
                :math:`(X, N, C)`. X is the number of tasks. N is batch size. C
                is the combined size of observation, action, reward, and next
                observation if next observation is used in context. Otherwise,
                C is the combined size of observation, action, and reward.

        Returns:
            tuple:
                * torch.Tensor: Predicted action values.
                * np.ndarray: Mean of distribution.
                * np.ndarray: Log std of distribution.
                * torch.Tensor: Log likelihood of distribution, with a
                    trailing singleton dimension.
                * torch.Tensor: Sampled values from distribution before
                    applying tanh transformation.
            torch.Tensor: z values, with shape :math:`(X * N, L)`. Each
                task's z is repeated N times so rows line up with the
                flattened observations. L is the latent dimension.

        """
        self.infer_posterior(context)
        # infer_posterior() already resamples z; this second draw is kept so
        # the sequence of random draws is unchanged.
        self.sample_from_belief()
        task_z = self.z

        # task, batch
        t, b, _ = obs.size()
        # reshape (not view) so non-contiguous observations are accepted
        obs = obs.reshape(t * b, -1)
        # repeat every task's z once per observation of that task
        task_z = torch.repeat_interleave(task_z, b, dim=0)

        # run policy, get log probs and new actions
        # z is detached so the policy input carries no gradient into z
        obs_z = torch.cat([obs, task_z.detach()], dim=1)
        dist = self._policy(obs_z)[0]
        pre_tanh, actions = dist.rsample_with_pre_tanh_value()
        log_pi = dist.log_prob(value=actions, pre_tanh_value=pre_tanh)
        log_pi = log_pi.unsqueeze(1)
        mean = dist.mean.to('cpu').detach().numpy()
        log_std = (dist.variance**.5).log().to('cpu').detach().numpy()

        return (actions, mean, log_std, log_pi, pre_tanh), task_z

    def get_action(self, obs):
        """Sample action from the policy, conditioned on the task embedding.

        The observation is concatenated with the current ``z`` and passed to
        the wrapped policy's ``get_action``; the leading batch axis is then
        squeezed from the action and from ``info['mean']``.

        Args:
            obs (torch.Tensor): Observation values, with shape :math:`(1, O)`.
                O is the size of the flattened observation space.

        Returns:
            torch.Tensor: Output action value, with shape :math:`(1, A)`.
                A is the size of the flattened action space.
                NOTE(review): the value is passed through ``np.squeeze``, so
                it is presumably an array without the batch axis -- confirm
                against the wrapped policy.
            dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Standard deviation of logarithmic values
                    of the distribution.

        """
        z = self.z
        obs = torch.as_tensor(obs[None], device=global_device()).float()
        obs_in = torch.cat([obs, z], dim=1)
        action, info = self._policy.get_action(obs_in)
        action = np.squeeze(action, axis=0)
        info['mean'] = np.squeeze(info['mean'], axis=0)
        return action, info

    def compute_kl_div(self):
        r"""Compute :math:`KL(q(z|c) \| p(z))`.

        The prior is a unit Gaussian over the latent dimension; the per-task
        divergences are summed over all tasks and latent dimensions.

        Returns:
            torch.Tensor: Scalar tensor holding :math:`KL(q(z|c) \| p(z))`.

        """
        prior = torch.distributions.Normal(
            torch.zeros(self._latent_dim).to(global_device()),
            torch.ones(self._latent_dim).to(global_device()))
        posteriors = [
            torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(
                torch.unbind(self.z_means), torch.unbind(self.z_vars))
        ]
        kl_divs = [
            torch.distributions.kl.kl_divergence(post, prior)
            for post in posteriors
        ]
        kl_div_sum = torch.sum(torch.stack(kl_divs))
        return kl_div_sum

    @property
    def networks(self):
        """Return context_encoder and policy.

        Returns:
            list: Encoder and policy networks.

        """
        return [self._context_encoder, self._policy]

    @property
    def context(self):
        """Return context.

        Returns:
            torch.Tensor: Context values, with shape :math:`(X, N, C)`.
                X is the number of tasks. N is batch size. C is the combined
                size of observation, action, reward, and next observation if
                next observation is used in context. Otherwise, C is the
                combined size of observation, action, and reward. ``None``
                when no transition has been added since the last reset.

        """
        return self._context