# Source code for garage.torch.policies.stochastic_policy

"""Base Stochastic Policy."""
import abc

import akro
import numpy as np
import torch

from garage.torch import global_device
from garage.torch.policies.policy import Policy


class StochasticPolicy(Policy, abc.ABC):
    """Abstract base class for torch stochastic policies."""

    def get_action(self, observation):
        r"""Get a single action given an observation.

        Args:
            observation (np.ndarray): Observation from the environment.
                Shape is :math:`env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted action. Shape is
                    :math:`env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        # Non-array, non-tensor observations (e.g. raw env observations)
        # are flattened into the policy's expected vector form first.
        if not isinstance(observation, (np.ndarray, torch.Tensor)):
            observation = self._env_spec.observation_space.flatten(observation)
        with torch.no_grad():
            if not isinstance(observation, torch.Tensor):
                observation = torch.as_tensor(observation).float().to(
                    global_device())
            # Add a batch dimension, delegate to the batched API, then
            # strip the batch dimension off the results.
            batched_obs = observation.unsqueeze(0)
            actions, agent_infos = self.get_actions(batched_obs)
            return actions[0], {
                key: value[0]
                for key, value in agent_infos.items()
            }

    def get_actions(self, observations):
        r"""Get actions given observations.

        Args:
            observations (np.ndarray): Observations from the environment.
                Shape is :math:`batch_dim \bullet env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted actions.
                    :math:`batch_dim \bullet env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution.
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        obs_space = self._env_spec.observation_space
        # Unstructured (non-array) elements are flattened to vectors first.
        if not isinstance(observations[0], (np.ndarray, torch.Tensor)):
            observations = obs_space.flatten_n(observations)

        # Users frequently pass lists of torch tensors or lists of numpy
        # arrays; stack those into a single batched array/tensor.
        if isinstance(observations, list):
            first = observations[0]
            if isinstance(first, np.ndarray):
                observations = np.stack(observations)
            elif isinstance(first, torch.Tensor):
                observations = torch.stack(observations)

        # Flattened image batches are restored to their image shape so the
        # network receives :math:`(N, H, W, C)`-style input.
        is_flat_image = (isinstance(obs_space, akro.Image)
                         and len(observations.shape) < len(obs_space.shape))
        if is_flat_image:
            observations = obs_space.unflatten_n(observations)

        with torch.no_grad():
            if not isinstance(observations, torch.Tensor):
                observations = torch.as_tensor(observations).float().to(
                    global_device())
            dist, info = self.forward(observations)
            sampled = dist.sample().cpu().numpy()
            return sampled, {
                key: value.detach().cpu().numpy()
                for key, value in info.items()
            }

    # pylint: disable=arguments-differ
    @abc.abstractmethod
    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
                Do not need to be detached, and can be on any device.

        """