"""Worker that "vectorizes" environments."""
import collections
import copy
import gym
import numpy as np
from garage import TrajectoryBatch
from garage.sampler.default_worker import DefaultWorker
from garage.sampler.env_update import EnvUpdate
class VecWorker(DefaultWorker):
    """Worker with a single policy and multiple environments.

    Alternates between taking a single step in all environments and asking the
    policy for an action for every environment. This allows computing a batch
    of actions, which is generally much more efficient than computing a single
    action when using neural networks.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_path_length (int or float): The maximum length paths which will
            be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker this update is
            occurring in. This argument is used to set a different seed for
            each worker.
        n_envs (int): Number of environment copies to use.

    """

    DEFAULT_N_ENVS = 8

    def __init__(self,
                 *,
                 seed,
                 max_path_length,
                 worker_number,
                 n_envs=DEFAULT_N_ENVS):
        super().__init__(seed=seed,
                         max_path_length=max_path_length,
                         worker_number=worker_number)
        self._n_envs = n_envs
        # Trajectories finished since the last call to collect_rollout().
        self._completed_rollouts = []
        # Both flags start True so the first start_rollout() resets the
        # agent and every environment.
        self._needs_agent_reset = True
        self._needs_env_reset = True
        # Environment slots stay None until update_env() fills them.
        self._envs = [None] * n_envs
        self._agents = [None] * n_envs
        self._path_lengths = [0] * self._n_envs

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements garage.Policy.

        Also flags the agent so that the next call to `start_rollout` resets
        it for every environment.

        Args:
            agent_update (np.ndarray or dict or garage.Policy): If a
                tuple, dict, or np.ndarray, these should be parameters to
                agent, which should have been generated by calling
                `policy.get_param_values`. Alternatively, a policy itself.
                Note that other implementations of `Worker` may take
                different types for this parameter.

        """
        super().update_agent(agent_update)
        self._needs_agent_reset = True

    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        If `env_update` is a list, it must contain exactly `n_envs` entries
        and entry `i` is applied to environment `i`. Any other non-None value
        is deep-copied once per environment, so that each environment slot
        receives its own independent copy. `None` leaves the environments
        untouched.

        This allows changing environments by passing the new environment as
        `env_update` into `obtain_samples`.

        Args:
            env_update (gym.Env or EnvUpdate or list or None): The
                environment to replace the existing env with, or a list
                with one update per environment. Note that other
                implementations of `Worker` may take different types for
                this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.
            ValueError: If the wrong number of updates is passed.

        """
        if isinstance(env_update, list):
            if len(env_update) != self._n_envs:
                raise ValueError('If separate environments are passed for '
                                 'each worker, there must be exactly n_envs '
                                 '({}) environments, but received {} '
                                 'environments.'.format(
                                     self._n_envs, len(env_update)))
            for env_index, env_up in enumerate(env_update):
                self._update_env_inner(env_up, env_index)
        elif env_update is not None:
            for env_index in range(self._n_envs):
                self._update_env_inner(copy.deepcopy(env_update), env_index)

    def _update_env_inner(self, env_update, env_index):
        """Update a single environment.

        An `EnvUpdate` is called with the current environment in the slot and
        its return value becomes the new environment. A `gym.Env` replaces
        the slot directly, closing the environment it displaces (if any).
        Either way, all environments are flagged for a reset at the next
        `start_rollout`.

        Args:
            env_update (gym.Env or EnvUpdate): The update to apply to the
                environment slot.
            env_index (int): Number of the environment to update.

        Raises:
            TypeError: If env_update is not one of the documented types
                (this includes None).

        """
        if isinstance(env_update, EnvUpdate):
            self._envs[env_index] = env_update(self._envs[env_index])
            self._needs_env_reset = True
        elif isinstance(env_update, gym.Env):
            if self._envs[env_index] is not None:
                self._envs[env_index].close()
            self._envs[env_index] = env_update
            self._needs_env_reset = True
        else:
            raise TypeError('Unknown environment update type.')

    def start_rollout(self):
        """Begin a new rollout.

        Does nothing unless the agent or the environments have been updated
        since the previous call. Otherwise resets the agent for every
        environment, resets the environments that need it, and clears all
        per-environment sample buffers.
        """
        if self._needs_agent_reset or self._needs_env_reset:
            n = len(self._envs)
            self.agent.reset([True] * n)
            if self._needs_env_reset:
                self._prev_obs = np.asarray(
                    [env.reset() for env in self._envs])
            else:
                # Avoid calling reset on environments that are already at the
                # start of a rollout.
                for i, env in enumerate(self._envs):
                    if self._path_lengths[i] > 0:
                        self._prev_obs[i] = env.reset()
            self._path_lengths = [0] * n
            self._observations = [[] for _ in range(n)]
            self._actions = [[] for _ in range(n)]
            self._rewards = [[] for _ in range(n)]
            self._terminals = [[] for _ in range(n)]
            self._env_infos = [
                collections.defaultdict(list) for _ in range(n)
            ]
            self._agent_infos = [
                collections.defaultdict(list) for _ in range(n)
            ]
            self._needs_agent_reset = False
            self._needs_env_reset = False

    def _gather_rollout(self, rollout_number, last_observation):
        """Package one environment's buffered samples as a finished rollout.

        Builds a `TrajectoryBatch` from the samples buffered for the given
        environment, appends it to the completed rollouts, clears that
        environment's buffers and resets the environment so that it is ready
        for its next rollout.

        Args:
            rollout_number (int): Index of the environment whose rollout
                just finished.
            last_observation (object): The observation returned by the final
                step of the rollout.

        """
        assert 0 < self._path_lengths[rollout_number] <= self._max_path_length
        env_infos = self._env_infos[rollout_number]
        agent_infos = self._agent_infos[rollout_number]
        for k, v in env_infos.items():
            env_infos[k] = np.asarray(v)
        for k, v in agent_infos.items():
            agent_infos[k] = np.asarray(v)
        traj = TrajectoryBatch(
            self._envs[rollout_number].spec,
            np.asarray(self._observations[rollout_number]),
            np.asarray([last_observation]),
            np.asarray(self._actions[rollout_number]),
            np.asarray(self._rewards[rollout_number]),
            np.asarray(self._terminals[rollout_number]), dict(env_infos),
            dict(agent_infos),
            np.asarray([self._path_lengths[rollout_number]], dtype='l'))
        self._completed_rollouts.append(traj)
        self._observations[rollout_number] = []
        self._actions[rollout_number] = []
        self._rewards[rollout_number] = []
        self._terminals[rollout_number] = []
        self._path_lengths[rollout_number] = 0
        self._prev_obs[rollout_number] = self._envs[rollout_number].reset()
        self._env_infos[rollout_number] = collections.defaultdict(list)
        self._agent_infos[rollout_number] = collections.defaultdict(list)

    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Asks the agent for one batch of actions, steps every environment
        once, and gathers the rollout of each environment that terminated or
        reached the maximum path length. The agent is then reset for exactly
        those environments.

        Returns:
            bool: True iff at least one of the paths was completed.

        """
        finished = False
        actions, agent_info = self.agent.get_actions(self._prev_obs)
        completes = [False] * len(self._envs)
        for i, action in enumerate(actions):
            # Every path length is reset to zero when its rollout is
            # gathered, so for a positive max_path_length this branch runs
            # for every environment and binds next_o and d before the
            # completion check below reads them.
            if self._path_lengths[i] < self._max_path_length:
                next_o, r, d, env_info = self._envs[i].step(action)
                prev_o = self._prev_obs[i]
                if isinstance(prev_o, np.ndarray):
                    # Indexing a multi-dimensional array yields a view of
                    # its row. Store a copy, otherwise the assignment to
                    # self._prev_obs[i] below would overwrite every
                    # observation already buffered for this environment.
                    prev_o = prev_o.copy()
                self._observations[i].append(prev_o)
                self._rewards[i].append(r)
                self._actions[i].append(action)
                for k, v in agent_info.items():
                    self._agent_infos[i][k].append(v[i])
                for k, v in env_info.items():
                    self._env_infos[i][k].append(v)
                self._path_lengths[i] += 1
                self._terminals[i].append(d)
                self._prev_obs[i] = next_o
            if self._path_lengths[i] >= self._max_path_length or d:
                self._gather_rollout(i, next_o)
                completes[i] = True
                finished = True
        if finished:
            self.agent.reset(completes)
        return finished

    def collect_rollout(self):
        """Collect all completed rollouts.

        Clears the internal list of completed rollouts before returning.

        Returns:
            garage.TrajectoryBatch: A batch of the trajectories completed
                since the last call to collect_rollout().

        """
        if len(self._completed_rollouts) == 1:
            # A single trajectory is returned as-is, without concatenation.
            result = self._completed_rollouts[0]
        else:
            result = TrajectoryBatch.concatenate(*self._completed_rollouts)
        self._completed_rollouts = []
        return result

    def shutdown(self):
        """Close the worker's environments."""
        for env in self._envs:
            # Slots are None until update_env() has provided an environment,
            # so shutting down an unused worker must not touch them.
            if env is not None:
                env.close()