"""Default Worker class."""
from collections import defaultdict
import numpy as np
from garage import EpisodeBatch, StepType
from garage.experiment import deterministic
from garage.sampler import _apply_env_update
from garage.sampler.worker import Worker
class DefaultWorker(Worker):
    """Initialize a worker.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_episode_length (int or float): The maximum length of episodes
            which will be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for each
            worker.

    Attributes:
        agent (Policy or None): The worker's agent.
        env (Environment or None): The worker's environment.

    """

    def __init__(
            self,
            *,  # Require passing by keyword, since everything's an int.
            seed,
            max_episode_length,
            worker_number):
        super().__init__(seed=seed,
                         max_episode_length=max_episode_length,
                         worker_number=worker_number)
        self.agent = None
        self.env = None
        # Per-episode sampling buffers; flushed by collect_episode().
        self._env_steps = []
        self._observations = []
        self._last_observations = []
        self._agent_infos = defaultdict(list)
        self._lengths = []
        self._prev_obs = None
        self._eps_length = 0
        self.worker_init()

    def worker_init(self):
        """Initialize a worker."""
        if self._seed is not None:
            # Offset the seed by the worker number so every parallel worker
            # draws a distinct random stream.
            deterministic.set_seed(self._seed + self._worker_number)

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements :class:`~Policy`.

        Args:
            agent_update (np.ndarray or dict or Policy): If a tuple, dict, or
                np.ndarray, these should be parameters to agent, which should
                have been generated by calling `Policy.get_param_values`.
                Alternatively, a policy itself. Note that other
                implementations of `Worker` may take different types for this
                parameter.

        """
        if isinstance(agent_update, (dict, tuple, np.ndarray)):
            # Parameter payload: load it into the existing agent.
            self.agent.set_param_values(agent_update)
        elif agent_update is not None:
            # A whole policy object: replace the agent outright.
            self.agent = agent_update

    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        A simple env update function. If env_update is not None, it should be
        the complete new environment.

        This allows changing environments by passing the new environment as
        `env_update` into `obtain_samples`.

        Args:
            env_update (Environment or EnvUpdate or None): The environment to
                replace the existing env with. Note that other implementations
                of `Worker` may take different types for this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.

        """
        self.env, _ = _apply_env_update(self.env, env_update)

    def start_episode(self):
        """Begin a new episode."""
        self._eps_length = 0
        self._prev_obs, _ = self.env.reset()
        self.agent.reset()

    def step_episode(self):
        """Take a single time-step in the current episode.

        Returns:
            bool: True iff the episode is done, either due to the environment
                indicating termination or due to reaching
                `max_episode_length`.

        """
        if self._eps_length < self._max_episode_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            es = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._env_steps.append(es)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            self._eps_length += 1
            if not es.terminal:
                self._prev_obs = es.observation
                return False
        # Reached on termination or when the length cap is hit: record the
        # episode length and the final observation.
        self._lengths.append(self._eps_length)
        self._last_observations.append(self._prev_obs)
        return True

    def collect_episode(self):
        """Collect the current episode, clearing the internal buffer.

        Returns:
            EpisodeBatch: A batch of the episodes completed since the last
                call to collect_episode().

        """
        # Swap each buffer out for a fresh one so the worker can immediately
        # start accumulating the next episode.
        observations = self._observations
        self._observations = []
        last_observations = self._last_observations
        self._last_observations = []
        actions = []
        rewards = []
        env_infos = defaultdict(list)
        step_types = []
        for es in self._env_steps:
            actions.append(es.action)
            rewards.append(es.reward)
            step_types.append(es.step_type)
            for k, v in es.env_info.items():
                env_infos[k].append(v)
        self._env_steps = []
        agent_infos = self._agent_infos
        self._agent_infos = defaultdict(list)
        # Stack per-step info lists into arrays for EpisodeBatch.
        for k, v in agent_infos.items():
            agent_infos[k] = np.asarray(v)
        for k, v in env_infos.items():
            env_infos[k] = np.asarray(v)
        lengths = self._lengths
        self._lengths = []
        return EpisodeBatch(env_spec=self.env.spec,
                            observations=np.asarray(observations),
                            last_observations=np.asarray(last_observations),
                            actions=np.asarray(actions),
                            rewards=np.asarray(rewards),
                            step_types=np.asarray(step_types,
                                                  dtype=StepType),
                            env_infos=dict(env_infos),
                            agent_infos=dict(agent_infos),
                            lengths=np.asarray(lengths, dtype='i'))

    def rollout(self):
        """Sample a single episode of the agent in the environment.

        Returns:
            EpisodeBatch: The collected episode.

        """
        self.start_episode()
        while not self.step_episode():
            pass
        return self.collect_episode()

    def shutdown(self):
        """Close the worker's environment."""
        self.env.close()