Source code for garage.sampler.fragment_worker

"""Worker that "vectorizes" environments."""
import copy

import numpy as np

from garage import EpisodeBatch, StepType
from garage.sampler import _apply_env_update, InProgressEpisode
from garage.sampler.default_worker import DefaultWorker


class FragmentWorker(DefaultWorker):
    """Vectorized Worker that collects partial episodes.

    Useful for off-policy RL.

    The worker keeps ``n_envs`` environments and one in-progress episode
    ("fragment") per environment. Each call to :meth:`rollout` steps every
    environment up to ``timesteps_per_call`` times and returns whatever was
    gathered, whether or not the underlying episodes have finished.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_episode_length (int or float): The maximum length of episodes
            which will be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker this update is
            occurring in. This argument is used to set a different seed for
            each worker.
        n_envs (int): Number of environment copies to use.
        timesteps_per_call (int): Maximum number of timesteps to gather per
            env per call to the worker. Defaults to 1 (i.e. gather 1 timestep
            per env each call, or n_envs timesteps in total each call).

    """

    DEFAULT_N_ENVS = 8

    def __init__(self,
                 *,
                 seed,
                 max_episode_length,
                 worker_number,
                 n_envs=DEFAULT_N_ENVS,
                 timesteps_per_call=1):
        super().__init__(seed=seed,
                         max_episode_length=max_episode_length,
                         worker_number=worker_number)
        self._n_envs = n_envs
        self._timesteps_per_call = timesteps_per_call
        # Forces start_episode to (re)build the fragments on first use.
        self._needs_env_reset = True
        # Slots stay None until update_env fills them in.
        self._envs = [None] * n_envs
        self._agents = [None] * n_envs
        # Steps taken so far in the current episode of each environment.
        self._episode_lengths = [0] * self._n_envs
        # Finished batches waiting to be handed out by collect_episode.
        self._complete_fragments = []
        # Initialized in start_episode
        self._fragments = None

    def update_env(self, env_update):
        """Update the environments.

        If passed a list (*inside* this list passed to the Sampler itself),
        distributes the environments across the "vectorization" dimension.
        A single (non-list) update is deep-copied once per environment slot.

        Args:
            env_update(Environment or EnvUpdate or None): The environment to
                replace the existing env with. Note that other implementations
                of `Worker` may take different types for this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types
                (presumably raised by `_apply_env_update`; it is not raised
                directly here).
            ValueError: If the wrong number of updates is passed.

        """
        if isinstance(env_update, list):
            if len(env_update) != self._n_envs:
                raise ValueError('If separate environments are passed for '
                                 'each worker, there must be exactly n_envs '
                                 '({}) environments, but received {} '
                                 'environments.'.format(
                                     self._n_envs, len(env_update)))
        elif env_update is not None:
            # Give every slot its own independent copy of the update.
            env_update = [
                copy.deepcopy(env_update) for _ in range(self._n_envs)
            ]
        if env_update:
            for env_index, env_up in enumerate(env_update):
                self._envs[env_index], up = _apply_env_update(
                    self._envs[env_index], env_up)
                # Any single updated env is enough to require a reset.
                self._needs_env_reset |= up

    def start_episode(self):
        """Resets all agents if the environment was updated."""
        if self._needs_env_reset:
            self._needs_env_reset = False
            n_envs = len(self._envs)
            self.agent.reset([True] * n_envs)
            self._episode_lengths = [0] * n_envs
            self._fragments = [InProgressEpisode(env) for env in self._envs]

    def step_episode(self):
        """Take a single time-step in the current episode.

        Returns:
            bool: True iff at least one of the episodes was completed.

        """
        prev_obs = np.asarray([frag.last_obs for frag in self._fragments])
        actions, agent_infos = self.agent.get_actions(prev_obs)
        completes = [False] * len(self._envs)
        for i, action in enumerate(actions):
            frag = self._fragments[i]
            if self._episode_lengths[i] < self._max_episode_length:
                # agent_infos maps each key to a per-env sequence; pick out
                # the entry belonging to this environment.
                agent_info = {k: v[i] for (k, v) in agent_infos.items()}
                frag.step(action, agent_info)
                self._episode_lengths[i] += 1
            if (self._episode_lengths[i] >= self._max_episode_length
                    or frag.step_types[-1] == StepType.TERMINAL):
                # The episode is over: bank it and start a fresh one on the
                # same environment.
                self._episode_lengths[i] = 0
                self._complete_fragments.append(frag.to_batch())
                self._fragments[i] = InProgressEpisode(self._envs[i])
                completes[i] = True
        any_complete = any(completes)
        if any_complete:
            # Only the entries flagged True are asked to reset.
            self.agent.reset(completes)
        return any_complete

    def collect_episode(self):
        """Gather fragments from all in-progress episodes.

        Every fragment that has at least one reward is converted to a batch
        and replaced by a new in-progress episode that carries over the old
        fragment's environment, last observation and episode info.

        Returns:
            EpisodeBatch: A batch of the episode fragments.

        """
        for i, frag in enumerate(self._fragments):
            assert frag.env is self._envs[i]
            if len(frag.rewards) > 0:
                self._complete_fragments.append(frag.to_batch())
                self._fragments[i] = InProgressEpisode(frag.env,
                                                       frag.last_obs,
                                                       frag.episode_info)
        assert len(self._complete_fragments) > 0
        result = EpisodeBatch.concatenate(*self._complete_fragments)
        self._complete_fragments = []
        return result

    def rollout(self):
        """Step every environment and return the gathered fragments.

        Calls `start_episode`, then `step_episode` ``timesteps_per_call``
        times, then `collect_episode`.

        Returns:
            EpisodeBatch: The collected episode fragments.

        """
        self.start_episode()
        for _ in range(self._timesteps_per_call):
            self.step_episode()
        return self.collect_episode()

    def shutdown(self):
        """Close the worker's environments.

        Slots that were never populated (still None, as set in `__init__`)
        are skipped so that shutting down an unused worker does not raise.
        """
        for env in self._envs:
            if env is not None:
                env.close()