Source code for garage.np.policies.base

"""Base class for policies based on numpy."""
import abc


class Policy(abc.ABC):
    """Base class for policies based on numpy.

    Concrete policies must implement :meth:`get_action`; every other member
    has a default implementation that subclasses may override.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.

    """

    def __init__(self, env_spec):
        # Kept private; exposed read-only through the space properties below.
        self._env_spec = env_spec

    @abc.abstractmethod
    def get_action(self, observation):
        """Get action sampled from the policy.

        Args:
            observation (np.ndarray): Observation from the environment.

        Returns:
            (np.ndarray): Action sampled from the policy.

        """

    def reset(self, dones=None):
        """Reset the policy.

        The base implementation is a no-op. By convention, an implementation
        should treat ``dones=None`` as ``np.array([True])``, which implies
        the policy is not "vectorized", i.e. the number of parallel
        environments for training data sampling is 1.

        Args:
            dones (numpy.ndarray): Bool that indicates terminal state(s).

        """

    @property
    def observation_space(self):
        """akro.Space: The observation space of the environment."""
        return self._env_spec.observation_space

    @property
    def action_space(self):
        """akro.Space: The action space for the environment."""
        return self._env_spec.action_space

    @property
    def recurrent(self):
        """Indicate whether the policy is recurrent.

        Returns:
            bool: True if policy is recurrent, False otherwise. The base
                class always reports False.

        """
        return False

    def log_diagnostics(self, paths):
        """Log extra information per iteration based on the collected paths.

        The base implementation logs nothing.

        Args:
            paths (list[dict]): A list of collected paths.

        """

    @property
    def state_info_keys(self):
        """Get keys describing policy's state.

        Returns:
            List[str]: Keys for the information related to the policy's
                state when taking an action. The base class returns a new
                empty list on every access.

        """
        return []

    def terminate(self):
        """Clean up operation.

        The base implementation does nothing.

        """
class StochasticPolicy(Policy):
    """Base class for stochastic policies implemented in numpy.

    Subclasses must provide both the :attr:`distribution` property and the
    :meth:`dist_info` method.

    """

    # NOTE(review): the documented return type names garage.tf even though
    # this is the numpy base class -- confirm the intended type.
    @property
    @abc.abstractmethod
    def distribution(self):
        """Get the distribution of the policy.

        Returns:
            garage.tf.distribution: The distribution of the policy.

        """

    @abc.abstractmethod
    def dist_info(self, obs, state_infos):
        """Return the distribution information about the actions.

        The return value is not specified by this base class; presumably it
        is a mapping of distribution parameters -- confirm against concrete
        subclasses.

        Args:
            obs (np.ndarray): Observation values.
            state_infos (dict): A dictionary whose values should contain
                information about the state of the policy at the time it
                received the observation.

        """