Source code for garage.np.exploration_policies.exploration_policy

"""Exploration Policy API used by off-policy algorithms."""
import abc


# This should be an ABC inheriting from garage.Policy, but that doesn't exist
# yet.
class ExplorationPolicy(abc.ABC):
    """Abstract base for policies that add noise to a wrapped policy's actions.

    Subclasses must implement `get_action` and `get_actions`. Resetting and
    parameter access are forwarded to the wrapped policy.

    Args:
        policy (garage.Policy): The policy being wrapped.

    """

    def __init__(self, policy):
        self.policy = policy

    @abc.abstractmethod
    def get_action(self, observation):
        """Produce a single noisy action.

        Args:
            observation (np.ndarray): Observation from the environment.

        Returns:
            np.ndarray: An action with noise.
            dict: Arbitrary policy state information (agent_info).

        """

    @abc.abstractmethod
    def get_actions(self, observations):
        """Produce a batch of noisy actions.

        Args:
            observations (np.ndarray): Observation from the environment.

        Returns:
            np.ndarray: Actions with noise.
            List[dict]: Arbitrary policy state information (agent_info).

        """

    def reset(self, dones=None):
        """Reset the exploration state by resetting the wrapped policy.

        Args:
            dones (List[bool] or numpy.ndarray or None): Which vectorization
                states to reset.

        """
        # The wrapped policy's return value is intentionally not propagated.
        wrapped = self.policy
        wrapped.reset(dones)

    def update(self, episode_batch):
        """Update the exploration policy using a batch of trajectories.

        The base implementation does nothing.

        Args:
            episode_batch (EpisodeBatch): A batch of trajectories which
                were sampled with this policy active.

        """

    def get_param_values(self):
        """Fetch the parameter values of the wrapped policy.

        Returns:
            list or dict: Values of each parameter.

        """
        wrapped = self.policy
        values = wrapped.get_param_values()
        return values

    def set_param_values(self, params):
        """Forward new parameter values to the wrapped policy.

        Args:
            params (np.ndarray): A numpy array of parameter values.

        """
        wrapped = self.policy
        wrapped.set_param_values(params)