"""Base Policy."""
import abc
import torch
class Policy(torch.nn.Module, abc.ABC):
    """Policy base class.

    Abstract base for PyTorch policies. Concrete subclasses must implement
    :meth:`get_action` and :meth:`get_actions`; the remaining members are
    thin accessors over the stored environment spec, the policy name and the
    module's state dictionary.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        name (str): Name of policy.

    """

    def __init__(self, env_spec, name):
        # torch.nn.Module must be initialised before any attribute is set on
        # the instance.
        super().__init__()
        self._env_spec = env_spec
        self._name = name

    @abc.abstractmethod
    def get_action(self, observation):
        """Get a single action given an observation.

        Args:
            observation (torch.Tensor): Observation from the environment.

        Returns:
            tuple:
                * torch.Tensor: Predicted action.
                * dict:
                    * list[float]: Mean of the distribution
                    * list[float]: Log of standard deviation of the
                        distribution

        """

    @abc.abstractmethod
    def get_actions(self, observations):
        """Get actions given observations.

        Args:
            observations (torch.Tensor): Observations from the environment.

        Returns:
            tuple:
                * torch.Tensor: Predicted actions.
                * dict:
                    * list[float]: Mean of the distribution
                    * list[float]: Log of standard deviation of the
                        distribution

        """

    @property
    def observation_space(self):
        """The observation space for the environment.

        Read from the environment spec passed to the constructor.

        Returns:
            akro.Space: Observation space.

        """
        return self._env_spec.observation_space

    @property
    def action_space(self):
        """The action space for the environment.

        Read from the environment spec passed to the constructor.

        Returns:
            akro.Space: Action space.

        """
        return self._env_spec.action_space

    def get_param_values(self):
        """Get the parameters to the policy.

        This method is included to ensure consistency with TF policies.

        Returns:
            dict: The parameters (in the form of the state dictionary, as
                returned by ``state_dict()``).

        """
        return self.state_dict()

    def set_param_values(self, state_dict):
        """Set the parameters to the policy.

        This method is included to ensure consistency with TF policies.
        The dictionary is passed to ``load_state_dict()`` unchanged.

        Args:
            state_dict (dict): State dictionary.

        """
        self.load_state_dict(state_dict)

    def reset(self, dones=None):
        """Reset the policy.

        The base implementation does nothing and returns ``None``;
        subclasses may override it.

        Args:
            dones (numpy.ndarray): Reset values. Defaults to ``None`` and is
                ignored by this base implementation.

        """

    @property
    def name(self):
        """Name of policy.

        Returns:
            str: Name of policy

        """
        return self._name