# Source code for garage.tf.policies.policy

"""Base class for policies in TensorFlow."""
import abc

from garage.tf.models import Module, StochasticModule


class Policy(Module):
    """Base class for policies in TensorFlow.

    Keeps a reference to the environment specification and exposes its
    observation and action spaces as read-only properties. Subclasses must
    implement ``get_action`` and ``get_actions``.

    Args:
        name (str): Policy name, also the variable scope.
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.

    """

    def __init__(self, name, env_spec):
        super().__init__(name)
        self._env_spec = env_spec

    @abc.abstractmethod
    def get_action(self, observation):
        """Get action sampled from the policy.

        Args:
            observation (np.ndarray): Observation from the environment.

        Returns:
            (np.ndarray): Action sampled from the policy.

        """

    @abc.abstractmethod
    def get_actions(self, observations):
        """Get actions sampled from the policy.

        Args:
            observations (list[np.ndarray]): Observations from the
                environment.

        Returns:
            (np.ndarray): Actions sampled from the policy.

        """

    @property
    def vectorized(self):
        """Boolean for vectorized.

        The base implementation always reports ``False``.

        Returns:
            bool: Indicates whether the policy is vectorized. If True, it
                should implement get_actions(), and support resetting with
                multiple simultaneous states.

        """
        return False

    @property
    def observation_space(self):
        """Observation space.

        Returns:
            akro.Space: The observation space of the environment.

        """
        return self._env_spec.observation_space

    @property
    def action_space(self):
        """Action space.

        Returns:
            akro.Space: The action space of the environment.

        """
        return self._env_spec.action_space

    @property
    def env_spec(self):
        """Policy environment specification.

        Returns:
            garage.EnvSpec: Environment specification.

        """
        return self._env_spec

    def log_diagnostics(self, paths):
        """Log extra information per iteration based on the collected paths.

        The base implementation does nothing; it is a hook that subclasses
        may override.

        Args:
            paths (dict[numpy.ndarray]): Sample paths.

        """
# get_action()/get_actions() are intentionally left abstract here, so the
# pylint abstract-method check is silenced for this class.
# pylint: disable=abstract-method
class StochasticPolicy(Policy, StochasticModule):
    """Stochastic Policy.

    Combines the ``Policy`` interface with ``StochasticModule``. It adds no
    members of its own.

    """