Source code for garage.torch.policies.stochastic_policy

"""Base Stochastic Policy."""
import abc

import torch

from garage.torch import global_device
from garage.torch.policies.policy import Policy


class StochasticPolicy(Policy, abc.ABC):
    """Abstract base class for torch stochastic policies.

    Subclasses supply :meth:`forward`, which maps a batch of observations to
    an action distribution plus extra agent info. The sampling helpers
    defined here are built on top of it.
    """

    def get_action(self, observation):
        r"""Sample one action for a single observation.

        A leading batch axis of size one is added to the observation before
        calling :meth:`forward`, and squeezed out of the results afterwards.

        Args:
            observation (np.ndarray or torch.Tensor): Observation from the
                environment. Shape is :math:`env_spec.observation_space`.
                Anything that is not already a ``torch.Tensor`` is converted
                to a float tensor on the global device.

        Returns:
            tuple:
                * np.ndarray: Sampled action. Shape is
                    :math:`env_spec.action_space`.
                * dict[str, np.ndarray]: The agent info produced by
                    :meth:`forward`, with the batch axis squeezed out and
                    every value converted to a numpy array.

        """
        with torch.no_grad():
            # Tensors are used as given: no dtype cast and no device move.
            if not isinstance(observation, torch.Tensor):
                converted = torch.as_tensor(observation)
                observation = converted.float().to(global_device())
            batched = observation.unsqueeze(0)
            dist, info = self.forward(batched)
            # Sample first, then convert the info entries.
            action = dist.sample().squeeze(0).cpu().numpy()
            agent_info = {
                key: value.squeeze(0).detach().cpu().numpy()
                for key, value in info.items()
            }
        return action, agent_info

    def get_actions(self, observations):
        r"""Sample one action per observation in a batch.

        Args:
            observations (np.ndarray or torch.Tensor): Observations from the
                environment. Shape is
                :math:`batch_dim \bullet env_spec.observation_space`.
                Anything that is not already a ``torch.Tensor`` is converted
                to a float tensor on the global device.

        Returns:
            tuple:
                * np.ndarray: Sampled actions. Shape is
                    :math:`batch_dim \bullet env_spec.action_space`.
                * dict[str, np.ndarray]: The agent info produced by
                    :meth:`forward`, with every value converted to a numpy
                    array.

        """
        with torch.no_grad():
            # Tensors are used as given: no dtype cast and no device move.
            if not isinstance(observations, torch.Tensor):
                converted = torch.as_tensor(observations)
                observations = converted.float().to(global_device())
            dist, info = self.forward(observations)
            # Sample first, then convert the info entries.
            actions = dist.sample().cpu().numpy()
            agent_infos = {
                key: value.detach().cpu().numpy()
                for key, value in info.items()
            }
        return actions, agent_infos

    # pylint: disable=arguments-differ
    @abc.abstractmethod
    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
                Do not need to be detached, and can be on any device.

        """