Source code for garage.np.policies.scripted_policy

"""Simulates a garage policy object."""
from garage.np.policies.policy import Policy


class ScriptedPolicy(Policy):
    """Simulates a garage policy object.

    Actions (and, optionally, agent infos) are not computed: they are
    looked up directly in the supplied data structures, using each
    observation as the index/key.

    Args:
        scripted_actions (list or dictionary): data structure indexed by
            observation, returns a corresponding action
        agent_env_infos (list or dictionary): data structure indexed by
            observation, returns a corresponding agent_env_info. When None
            (or empty), an empty dict is returned as the agent info.

    """

    def __init__(self, scripted_actions, agent_env_infos=None):
        self._scripted_actions = scripted_actions
        self._agent_env_infos = agent_env_infos

    def set_param_values(self, params):
        """Set param values.

        This is a no-op: a scripted policy has no parameters to set.

        Args:
            params (np.ndarray): A numpy array of parameter values.

        """

    # pylint: disable=no-self-use
    def get_param_values(self):
        """Get param values.

        Returns:
            list: Always an empty list, since a scripted policy has no
                parameters.

        """
        return []

    def get_action(self, observation):
        """Return the scripted action for a single observation.

        Args:
            observation (object): Index/key into ``scripted_actions`` (and
                into ``agent_env_infos``, if provided).

        Returns:
            object: Action given input observation.
            dict: Agent info for ``observation``, or an empty dict when no
                agent infos were supplied.

        """
        if self._agent_env_infos:
            a_info = self._agent_env_infos[observation]
        else:
            a_info = {}
        return self._scripted_actions[observation], a_info

    def get_actions(self, observations):
        """Return the scripted actions for a batch of observations.

        Args:
            observations (list or numpy.ndarray): Sequence of observations,
                each used as an index/key into ``scripted_actions``.

        Returns:
            list: Actions given input observations, in the same order.
            dict: Agent info for the *first* observation only, or an empty
                dict when no agent infos were supplied or ``observations``
                is empty.

        """
        # Use len() rather than truthiness: ``not observations`` is
        # ambiguous (raises ValueError) for multi-element numpy arrays.
        # The length guard avoids an IndexError on an empty batch.
        if self._agent_env_infos and len(observations) > 0:
            # Only the first observation's info is returned; this is kept
            # for backward compatibility with the original behavior.
            a_info = self._agent_env_infos[observations[0]]
        else:
            a_info = {}
        return [self._scripted_actions[obs] for obs in observations], a_info