"""Base Stochastic Policy."""
import abc
import torch
from garage.torch import global_device
from garage.torch.policies.policy import Policy
class StochasticPolicy(Policy, abc.ABC):
    """Abstract base class for torch stochastic policies.

    Subclasses must implement :meth:`forward`, which maps a batch of
    observations to an action distribution plus auxiliary info; the
    ``get_action``/``get_actions`` convenience wrappers handle numpy
    conversion, device placement, and (un)batching around it.
    """

    def get_action(self, observation):
        r"""Get a single action given an observation.

        Args:
            observation (np.ndarray): Observation from the environment.
                Shape is :math:`env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted action. Shape is
                    :math:`env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        with torch.no_grad():
            # NOTE(review): a tensor passed in is assumed to already be on
            # the right device; only non-tensors are converted and moved.
            if not isinstance(observation, torch.Tensor):
                observation = torch.as_tensor(observation).float().to(
                    global_device())
            # Add a batch dimension of 1 so forward() sees a batch, then
            # squeeze it back off the sampled action and each info entry.
            observation = observation.unsqueeze(0)
            dist, info = self.forward(observation)
            return dist.sample().squeeze(0).cpu().numpy(), {
                k: v.squeeze(0).detach().cpu().numpy()
                for (k, v) in info.items()
            }

    def get_actions(self, observations):
        r"""Get actions given observations.

        Args:
            observations (np.ndarray): Observations from the environment.
                Shape is :math:`batch_dim \bullet env_spec.observation_space`.

        Returns:
            tuple:
                * np.ndarray: Predicted actions.
                    :math:`batch_dim \bullet env_spec.action_space`.
                * dict:
                    * np.ndarray[float]: Mean of the distribution.
                    * np.ndarray[float]: Standard deviation of logarithmic
                        values of the distribution.

        """
        with torch.no_grad():
            # Input is already batched; only convert/move non-tensors.
            if not isinstance(observations, torch.Tensor):
                observations = torch.as_tensor(observations).float().to(
                    global_device())
            dist, info = self.forward(observations)
            return dist.sample().cpu().numpy(), {
                k: v.detach().cpu().numpy()
                for (k, v) in info.items()
            }

    # pylint: disable=arguments-differ
    @abc.abstractmethod
    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
                Do not need to be detached, and can be on any device.

        """