Source code for garage.sampler.vec_worker

"""Worker that "vectorizes" environments."""
import collections
import copy

import gym
import numpy as np

from garage import TrajectoryBatch
from garage.sampler.default_worker import DefaultWorker
from garage.sampler.env_update import EnvUpdate


class VecWorker(DefaultWorker):
    """Worker with a single policy and multiple environments.

    Alternates between taking a single step in all environments and asking
    the policy for an action for every environment. This allows computing a
    batch of actions, which is generally much more efficient than computing
    a single action when using neural networks.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_path_length (int or float): The maximum length of paths which
            will be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker this update is
            occurring in. This argument is used to set a different seed for
            each worker.
        n_envs (int): Number of environment copies to use.

    """

    DEFAULT_N_ENVS = 8

    def __init__(self,
                 *,
                 seed,
                 max_path_length,
                 worker_number,
                 n_envs=DEFAULT_N_ENVS):
        super().__init__(seed=seed,
                         max_path_length=max_path_length,
                         worker_number=worker_number)
        self._n_envs = n_envs
        self._completed_rollouts = []
        self._needs_agent_reset = True
        self._needs_env_reset = True
        self._envs = [None] * n_envs
        self._agents = [None] * n_envs
        self._path_lengths = [0] * self._n_envs
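A minimal construction sketch (the argument values below are illustrative assumptions, not a prescribed setup): one worker holds `n_envs` environment slots and a single policy, so `n_envs` is the batch size used when asking the policy for actions.

from garage.sampler.vec_worker import VecWorker

# Hypothetical settings: 4 environment copies, paths capped at 200 steps.
worker = VecWorker(seed=100,
                   max_path_length=200,
                   worker_number=0,
                   n_envs=4)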
    def update_agent(self, agent_update):
        """Update an agent, assuming it implements garage.Policy.

        Args:
            agent_update (np.ndarray or dict or garage.Policy): If a tuple,
                dict, or np.ndarray, these should be parameters to agent,
                which should have been generated by calling
                `policy.get_param_values`. Alternatively, a policy itself.
                Note that other implementations of `Worker` may take
                different types for this parameter.

        """
        super().update_agent(agent_update)
        self._needs_agent_reset = True
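A hedged sketch of the two documented ways to pass an agent update, assuming `policy` is a `garage.Policy` (an illustrative stand-in, not defined in this module):

worker.update_agent(policy)                     # pass the policy itself
worker.update_agent(policy.get_param_values())  # or just its parameters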
    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        A simple env update function. If env_update is not None, it should
        be the complete new environment. This allows changing environments
        by passing the new environment as `env_update` into
        `obtain_samples`.

        Args:
            env_update (gym.Env or EnvUpdate or None): The environment to
                replace the existing env with. Note that other
                implementations of `Worker` may take different types for
                this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.
            ValueError: If the wrong number of updates is passed.

        """
        if isinstance(env_update, list):
            if len(env_update) != self._n_envs:
                raise ValueError('If separate environments are passed for '
                                 'each worker, there must be exactly n_envs '
                                 '({}) environments, but received {} '
                                 'environments.'.format(
                                     self._n_envs, len(env_update)))
            for env_index, env_up in enumerate(env_update):
                self._update_env_inner(env_up, env_index)
        elif env_update is not None:
            for env_index in range(self._n_envs):
                self._update_env_inner(copy.deepcopy(env_update), env_index)
    def _update_env_inner(self, env_update, env_index):
        """Update a single environment.

        Args:
            env_update (gym.Env or EnvUpdate or None): The environment to
                replace the existing env with. Note that other
                implementations of `Worker` may take different types for
                this parameter.
            env_index (int): Number of the environment to update.

        Raises:
            TypeError: If env_update is not one of the documented types.

        """
        if isinstance(env_update, EnvUpdate):
            self._envs[env_index] = env_update(self._envs[env_index])
            self._needs_env_reset = True
        elif isinstance(env_update, gym.Env):
            if self._envs[env_index] is not None:
                self._envs[env_index].close()
            self._envs[env_index] = env_update
            self._needs_env_reset = True
        else:
            raise TypeError('Unknown environment update type.')
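Two hedged usage sketches for the update path above, assuming the worker was built with `n_envs=4` and that `make_env` is a hypothetical factory returning a `gym.Env`: a list must contain exactly `n_envs` environments, while a single environment is deep-copied into every slot.

# One environment per slot; the list length must equal n_envs (4 here),
# otherwise update_env raises ValueError.
worker.update_env([make_env() for _ in range(4)])

# A single environment is deep-copied into each of the 4 slots.
worker.update_env(make_env())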
    def start_rollout(self):
        """Begin a new rollout."""
        if self._needs_agent_reset or self._needs_env_reset:
            n = len(self._envs)
            self.agent.reset([True] * n)
            if self._needs_env_reset:
                self._prev_obs = np.asarray(
                    [env.reset() for env in self._envs])
            else:
                # Avoid calling reset on environments that are already at
                # the start of a rollout.
                for i, env in enumerate(self._envs):
                    if self._path_lengths[i] > 0:
                        self._prev_obs[i] = env.reset()
            self._path_lengths = [0 for _ in range(n)]
            self._observations = [[] for _ in range(n)]
            self._actions = [[] for _ in range(n)]
            self._rewards = [[] for _ in range(n)]
            self._terminals = [[] for _ in range(n)]
            self._env_infos = [
                collections.defaultdict(list) for _ in range(n)
            ]
            self._agent_infos = [
                collections.defaultdict(list) for _ in range(n)
            ]
            self._needs_agent_reset = False
            self._needs_env_reset = False
    def _gather_rollout(self, rollout_number, last_observation):
        """Package one completed rollout into a TrajectoryBatch and reset it.

        Args:
            rollout_number (int): Index of the environment whose rollout
                just completed.
            last_observation (np.ndarray): The final observation of that
                rollout.

        """
        assert 0 < self._path_lengths[rollout_number] <= self._max_path_length
        env_infos = self._env_infos[rollout_number]
        agent_infos = self._agent_infos[rollout_number]
        for k, v in env_infos.items():
            env_infos[k] = np.asarray(v)
        for k, v in agent_infos.items():
            agent_infos[k] = np.asarray(v)
        traj = TrajectoryBatch(
            self._envs[rollout_number].spec,
            np.asarray(self._observations[rollout_number]),
            np.asarray([last_observation]),
            np.asarray(self._actions[rollout_number]),
            np.asarray(self._rewards[rollout_number]),
            np.asarray(self._terminals[rollout_number]),
            dict(env_infos),
            dict(agent_infos),
            np.asarray([self._path_lengths[rollout_number]], dtype='l'))
        self._completed_rollouts.append(traj)
        self._observations[rollout_number] = []
        self._actions[rollout_number] = []
        self._rewards[rollout_number] = []
        self._terminals[rollout_number] = []
        self._path_lengths[rollout_number] = 0
        self._prev_obs[rollout_number] = self._envs[rollout_number].reset()
        self._env_infos[rollout_number] = collections.defaultdict(list)
        self._agent_infos[rollout_number] = collections.defaultdict(list)
    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Returns:
            bool: True iff at least one of the paths was completed.

        """
        finished = False
        actions, agent_info = self.agent.get_actions(self._prev_obs)
        completes = [False] * len(self._envs)
        for i, action in enumerate(actions):
            if self._path_lengths[i] < self._max_path_length:
                next_o, r, d, env_info = self._envs[i].step(action)
                self._observations[i].append(self._prev_obs[i])
                self._rewards[i].append(r)
                self._actions[i].append(actions[i])
                for k, v in agent_info.items():
                    self._agent_infos[i][k].append(v[i])
                for k, v in env_info.items():
                    self._env_infos[i][k].append(v)
                self._path_lengths[i] += 1
                self._terminals[i].append(d)
                self._prev_obs[i] = next_o
                if self._path_lengths[i] >= self._max_path_length or d:
                    self._gather_rollout(i, next_o)
                    completes[i] = True
                    finished = True
        if finished:
            self.agent.reset(completes)
        return finished
    def collect_rollout(self):
        """Collect all completed rollouts.

        Returns:
            garage.TrajectoryBatch: A batch of the trajectories completed
                since the last call to collect_rollout().

        """
        if len(self._completed_rollouts) == 1:
            result = self._completed_rollouts[0]
        else:
            result = TrajectoryBatch.concatenate(*self._completed_rollouts)
        self._completed_rollouts = []
        return result
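A sketch of how the methods above fit together; `policy` is assumed to be a `garage.Policy` and the worker's environments are assumed to have been set already via `update_env` (both assumptions, not part of this module). `step_rollout` advances every environment by one step and returns True once at least one path has finished, after which `collect_rollout` drains the completed trajectories.

worker.update_agent(policy)  # a garage.Policy (assumed)
worker.start_rollout()
while not worker.step_rollout():
    pass  # keep stepping all environments in lock-step
batch = worker.collect_rollout()  # garage.TrajectoryBatch of finished paths
worker.shutdown()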
    def shutdown(self):
        """Close the worker's environments."""
        for env in self._envs:
            env.close()