"""Base Stochastic Policy."""
import abc
import akro
import numpy as np
import torch
from garage.torch import global_device
from garage.torch.policies.policy import Policy
class StochasticPolicy(Policy, abc.ABC):
    """Abstract base class for torch stochastic policies.

    Concrete subclasses implement :meth:`forward`, which maps a batch of
    observations to a ``torch.distributions.Distribution``; the two
    ``get_*`` helpers below handle numpy/torch/list input conversion and
    sampling so subclasses don't have to.
    """

    def get_action(self, observation):
        r"""Get a single action given an observation.

        Args:
            observation (np.ndarray): Observation from the environment.
                Shape is :math:`env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted action. Shape is
                    :math:`env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        # Non-array observations (e.g. dict observations) are flattened to
        # a 1-D array via the observation space before tensor conversion.
        if not isinstance(observation, (np.ndarray, torch.Tensor)):
            observation = self._env_spec.observation_space.flatten(observation)
        with torch.no_grad():
            if not isinstance(observation, torch.Tensor):
                observation = torch.as_tensor(observation).float().to(
                    global_device())
            # Add a batch dimension and delegate to the batched API, then
            # strip the batch dimension from the result.
            observation = observation.unsqueeze(0)
            action, agent_infos = self.get_actions(observation)
            return action[0], {k: v[0] for k, v in agent_infos.items()}

    def get_actions(self, observations):
        r"""Get actions given observations.

        Args:
            observations (np.ndarray): Observations from the environment.
                Shape is :math:`batch_dim \bullet env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted actions.
                    :math:`batch_dim \bullet env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution.
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        # A batch whose elements are neither arrays nor tensors (e.g. dict
        # observations) is flattened element-wise by the observation space.
        if not isinstance(observations[0], (np.ndarray, torch.Tensor)):
            observations = self._env_spec.observation_space.flatten_n(
                observations)
        # Frequently users like to pass lists of torch tensors or lists of
        # numpy arrays. This handles those conversions.
        if isinstance(observations, list):
            if isinstance(observations[0], np.ndarray):
                observations = np.stack(observations)
            elif isinstance(observations[0], torch.Tensor):
                observations = torch.stack(observations)
        # Image observations that arrive flattened are restored to their
        # full (e.g. HWC) shape expected by convolutional policies.
        if isinstance(self._env_spec.observation_space, akro.Image) and \
                len(observations.shape) < \
                len(self._env_spec.observation_space.shape):
            observations = self._env_spec.observation_space.unflatten_n(
                observations)
        with torch.no_grad():
            if not isinstance(observations, torch.Tensor):
                observations = torch.as_tensor(observations).float().to(
                    global_device())
            dist, info = self.forward(observations)
            # Sample (not rsample) — no gradient flows through the action,
            # and everything returned to the caller is numpy on CPU.
            return dist.sample().cpu().numpy(), {
                k: v.detach().cpu().numpy()
                for (k, v) in info.items()
            }

    # pylint: disable=arguments-differ
    @abc.abstractmethod
    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
                Do not need to be detached, and can be on any device.

        """