garage.sampler

Samplers which run agents in environments.

class InProgressEpisode(env, initial_observation=None)

An in-progress episode.

Compared to EpisodeBatch, this datatype does less checking, contains only one episode, and uses lists instead of numpy arrays to make stepping faster.

Parameters
  • env (Environment) – The environment the trajectory is being collected in.

  • initial_observation (np.ndarray) – The first observation. If None, the environment will be reset to generate this observation.

step(self, action, agent_info)

Step the episode using an action from an agent.

Parameters
  • action (np.ndarray) – The action taken by the agent.

  • agent_info (dict[str, np.ndarray]) – Extra agent information.

Returns

The new observation from the environment.

Return type

np.ndarray

to_batch(self)

Convert this in-progress episode into an EpisodeBatch.

Returns

This episode as a batch.

Return type

EpisodeBatch

Raises

AssertionError – If this episode contains no time steps.

property last_obs(self)

np.ndarray: The last observation in the episode.
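
A minimal usage sketch (assuming env is a garage Environment and policy is a garage Policy, both constructed elsewhere; a real worker would also stop early on terminal steps):

    from garage.sampler import InProgressEpisode

    # Sketch only: `env` and `policy` are assumed to exist. Omitting
    # initial_observation makes InProgressEpisode reset the environment.
    episode = InProgressEpisode(env)
    for _ in range(100):  # arbitrary fragment length, for illustration
        action, agent_info = policy.get_action(episode.last_obs)
        episode.step(action, agent_info)
    batch = episode.to_batch()  # AssertionError if no steps were taken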

class EnvUpdate

A callable that “updates” an environment.

Implementors of this interface can be called on environments to update them. The passed-in environment should then be ignored, and the returned one used instead.

Since no new environment needs to be passed in, this type can also be used to construct new environments.
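
A sketch of the calling convention (the function name is illustrative only):

    def apply_env_update(old_env, env_update):
        # An EnvUpdate is applied by calling it on the current (possibly
        # None) environment. The old environment must not be used after
        # this call; only the returned one is valid.
        return env_update(old_env)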

class ExistingEnvUpdate(env)

Bases: garage.sampler.env_update.EnvUpdate

EnvUpdate that carries an already constructed environment.

Parameters

env (Environment) – The environment.

class NewEnvUpdate(env_constructor)

Bases: garage.sampler.env_update.EnvUpdate

EnvUpdate that creates a new environment every update.

Parameters

env_constructor (Callable[Environment]) – Callable that constructs an environment.

class SetTaskUpdate(env_type, task, wrapper_constructor)

Bases: garage.sampler.env_update.EnvUpdate

EnvUpdate that calls set_task with the provided task.

Parameters
  • env_type (type) – Type of environment.

  • task (object) – Opaque task type.

  • wrapper_constructor (Callable[garage.Env, garage.Env] or None) – Callable that wraps constructed environments.
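
A sketch constructing each concrete update type. GymEnv is assumed to be available from garage.envs; MyTaskEnv and my_task are hypothetical stand-ins for a multi-task environment type and an opaque task it understands:

    from garage.envs import GymEnv
    from garage.sampler import (ExistingEnvUpdate, NewEnvUpdate,
                                SetTaskUpdate)

    ship_existing = ExistingEnvUpdate(GymEnv('CartPole-v1'))
    build_fresh = NewEnvUpdate(lambda: GymEnv('CartPole-v1'))
    set_task = SetTaskUpdate(MyTaskEnv, my_task, None)  # hypothetical names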

class FragmentWorker(*, seed, max_episode_length, worker_number, n_envs=DEFAULT_N_ENVS, timesteps_per_call=1)

Bases: garage.sampler.default_worker.DefaultWorker

Vectorized Worker that collects partial episodes.

Useful for off-policy RL.

Parameters
  • seed (int) – The seed to use to initialize random number generators.

  • max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity. Sets the maximum length of fragments before they're transmitted out of the worker.

  • worker_number (int) – The number of the worker this update is occurring in. This argument is used to set a different seed for each worker.

  • n_envs (int) – Number of environment copies to use.

  • timesteps_per_call (int) – Maximum number of timesteps to gather per env per call to the worker. Defaults to 1 (i.e. gather 1 timestep per env each call, or n_envs timesteps in total each call).
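
FragmentWorker is normally constructed indirectly through a WorkerFactory. A sketch, assuming policy and env are a pre-built garage Policy and Environment:

    from garage.sampler import FragmentWorker, LocalSampler, WorkerFactory

    factory = WorkerFactory(seed=1, max_episode_length=200, n_workers=2,
                            worker_class=FragmentWorker,
                            worker_args=dict(n_envs=4, timesteps_per_call=1))
    sampler = LocalSampler.from_worker_factory(factory, policy, env)
    # Each call to a worker now yields fragments of at most
    # timesteps_per_call steps from each of its n_envs environment copies.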

DEFAULT_N_ENVS = 8

update_env(self, env_update)

Update the environments.

If passed a list (nested inside the list passed to the Sampler itself), this distributes the environments across the “vectorization” dimension.

Parameters

env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.

Raises
  • TypeError – If env_update is not one of the documented types.

  • ValueError – If the wrong number of updates is passed.

start_episode(self)

Resets all agents if the environment was updated.

step_episode(self)

Take a single time-step in the current episode.

Returns

True iff at least one of the episodes was completed.

Return type

bool

collect_episode(self)

Gather fragments from all in-progress episodes.

Returns

A batch of the episode fragments.

Return type

EpisodeBatch

rollout(self)

Sample a single episode of the agent in the environment.

Returns

The collected episode.

Return type

EpisodeBatch

shutdown(self)

Close the worker’s environments.

worker_init(self)

Initialize a worker.

update_agent(self, agent_update)

Update an agent, assuming it implements Policy.

Parameters

agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray is passed, it should contain parameters for the agent, generated by calling Policy.get_param_values. Alternatively, a policy itself may be passed. Note that other implementations of Worker may take different types for this parameter.

class LocalSampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler

Sampler that runs workers in the main process.

This is probably the simplest possible sampler. It’s called the “Local” sampler because it runs everything in the same process and thread it was called from.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
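
A construction sketch, assuming policy and env are a garage Policy and Environment built elsewhere:

    from garage.sampler import LocalSampler, WorkerFactory

    factory = WorkerFactory(seed=1, max_episode_length=500, n_workers=4)
    sampler = LocalSampler.from_worker_factory(factory, agents=policy,
                                               envs=env)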

classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Agent or List[Agent]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

Returns

An instance of cls.

Return type

Sampler

obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Collect at least a given number of transitions (timesteps).

Parameters
  • itr (int) – The current iteration number. Using this argument is deprecated.

  • num_samples (int) – Minimum number of transitions / timesteps to sample.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

The batch of collected episodes.

Return type

EpisodeBatch
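
A typical call passes fresh policy parameters so that workers act with the current policy (sketch, continuing the example above):

    episodes = sampler.obtain_samples(
        itr=0,                                   # deprecated, but accepted
        num_samples=2000,                        # at least 2000 timesteps
        agent_update=policy.get_param_values())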

obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)

Sample an exact number of episodes per worker.

Parameters
  • n_eps_per_worker (int) – Exact number of episodes to gather for each worker.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.

Return type

EpisodeBatch
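
Sketch, continuing the example above: with 4 workers, the returned batch holds exactly 20 episodes, grouped by worker.

    episodes = sampler.obtain_exact_episodes(
        n_eps_per_worker=5,
        agent_update=policy.get_param_values())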

shutdown_worker(self)

Shutdown the workers.

start_worker(self)

Initialize the sampler, e.g. by launching parallel workers if necessary.

This method is deprecated; please launch workers in construct instead.

class MultiprocessingSampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler

Sampler that uses multiprocessing to distribute workers.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
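
The interface mirrors LocalSampler, but workers run in child processes, so anything passed to them must be picklable and the sampler should be shut down explicitly. A sketch, with policy and env assumed constructed elsewhere:

    from garage.sampler import MultiprocessingSampler, WorkerFactory

    factory = WorkerFactory(seed=1, max_episode_length=500, n_workers=4)
    sampler = MultiprocessingSampler.from_worker_factory(factory, policy, env)
    try:
        episodes = sampler.obtain_samples(0, 4000, policy.get_param_values())
    finally:
        sampler.shutdown_worker()  # terminate the child processes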

classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

Returns

An instance of cls.

Return type

Sampler

obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Collect at least a given number of transitions (timesteps).

Parameters
  • itr (int) – The current iteration number. Using this argument is deprecated.

  • num_samples (int) – Minimum number of transitions / timesteps to sample.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

The batch of collected episodes.

Return type

EpisodeBatch

Raises

AssertionError – On internal errors.

obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)

Sample an exact number of episodes per worker.

Parameters
  • n_eps_per_worker (int) – Exact number of episodes to gather for each worker.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.

Return type

EpisodeBatch

Raises

AssertionError – On internal errors.

shutdown_worker(self)

Shutdown the workers.

start_worker(self)

Initialize the sampler, e.g. by launching parallel workers if necessary.

This method is deprecated; please launch workers in construct instead.

class RaySampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler

Samples episodes in a data-parallel fashion using a Ray cluster.

Parameters
  • worker_factory (WorkerFactory) – Used for worker behavior.

  • agents (list[Policy]) – Agents to distribute across workers.

  • envs (list[Environment]) – Environments to distribute across workers.
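
A sketch; a Ray cluster must be available, and ray.init() with no arguments starts a local one (policy and env assumed constructed elsewhere):

    import ray

    from garage.sampler import RaySampler, WorkerFactory

    ray.init()
    factory = WorkerFactory(seed=1, max_episode_length=500, n_workers=4)
    sampler = RaySampler.from_worker_factory(factory, policy, env)
    episodes = sampler.obtain_samples(0, 4000, policy.get_param_values())
    sampler.shutdown_worker()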

classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

Returns

An instance of cls.

Return type

Sampler

start_worker(self)

Initialize a new ray worker.

obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Sample the policy for new episodes.

Parameters
  • itr (int) – Iteration number.

  • num_samples (int) – Number of steps the sampler should collect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

Batch of gathered episodes.

Return type

EpisodeBatch

obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)

Sample an exact number of episodes per worker.

Parameters
  • n_eps_per_worker (int) – Exact number of episodes to gather for each worker.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.

Return type

EpisodeBatch

shutdown_worker(self)

Shuts down the worker.

class Sampler(algo, env)

Bases: abc.ABC

Abstract base class of all samplers.

Implementations of this class should override construct, obtain_samples, and shutdown_worker. construct takes a WorkerFactory, which implements most of the RL-specific functionality a Sampler needs. Specifically, it specifies how to construct Workers, which know how to collect episodes and update both agents and environments.

Currently, __init__ is also part of the interface, but calling it is deprecated. start_worker is also deprecated, and does not need to be implemented.
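
A minimal skeleton of the surface a subclass is expected to cover (a sketch, not a working sampler):

    from garage.sampler import Sampler

    class MySampler(Sampler):
        """Sketch of a custom Sampler."""

        @classmethod
        def from_worker_factory(cls, worker_factory, agents, envs):
            # Build workers here (possibly on other processes or nodes),
            # then return an instance of cls.
            raise NotImplementedError

        def obtain_samples(self, itr, num_samples, agent_update,
                           env_update=None):
            # Push agent_update / env_update to the workers, gather at
            # least num_samples timesteps, and return an EpisodeBatch.
            raise NotImplementedError

        def shutdown_worker(self):
            # Terminate workers, if any were launched.
            raise NotImplementedError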

classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to collect episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

Returns

An instance of cls.

Return type

Sampler

start_worker(self)

Initialize the sampler, e.g. by launching parallel workers if necessary.

This method is deprecated; please launch workers in construct instead.

abstract obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Collect at least a given number of transitions (TimeSteps).

Parameters
  • itr (int) – The current iteration number. Using this argument is deprecated.

  • num_samples (int) – Minimum number of TimeSteps to sample.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Returns

The batch of collected episodes.

Return type

EpisodeBatch

abstract shutdown_worker(self)

Terminate workers if necessary.

Because Python object destruction can be somewhat unpredictable, this method isn’t deprecated.

class VecWorker(*, seed, max_episode_length, worker_number, n_envs=DEFAULT_N_ENVS)

Bases: garage.sampler.default_worker.DefaultWorker

Worker with a single policy and multiple environments.

Alternates between taking a single step in all environments and asking the policy for an action for every environment. This allows computing a batch of actions, which is generally much more efficient than computing a single action when using neural networks.

Parameters
  • seed (int) – The seed to use to initialize random number generators.

  • max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity.

  • worker_number (int) – The number of the worker this update is occurring in. This argument is used to set a different seed for each worker.

  • n_envs (int) – Number of environment copies to use.
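
Like FragmentWorker, VecWorker is selected through a WorkerFactory (sketch; policy and env assumed constructed elsewhere):

    from garage.sampler import LocalSampler, VecWorker, WorkerFactory

    factory = WorkerFactory(seed=1, max_episode_length=500, n_workers=2,
                            worker_class=VecWorker,
                            worker_args=dict(n_envs=8))
    sampler = LocalSampler.from_worker_factory(factory, policy, env)
    # Each worker now steps 8 environment copies at once and asks the
    # policy for a batch of 8 actions per step.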

DEFAULT_N_ENVS = 8

update_agent(self, agent_update)

Update an agent, assuming it implements Policy.

Parameters

agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray is passed, it should contain parameters for the agent, generated by calling Policy.get_param_values. Alternatively, a policy itself may be passed. Note that other implementations of Worker may take different types for this parameter.

update_env(self, env_update)

Update the environments.

If passed a list (nested inside the list passed to the Sampler itself), this distributes the environments across the “vectorization” dimension.

Parameters

env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.

Raises
  • TypeError – If env_update is not one of the documented types.

  • ValueError – If the wrong number of updates is passed.

start_episode(self)

Begin a new episode.

step_episode(self)

Take a single time-step in the current episode.

Returns

True iff at least one of the episodes was completed.

Return type

bool

collect_episode(self)

Collect all completed episodes.

Returns

A batch of the episodes completed since the last call to collect_episode().

Return type

EpisodeBatch

shutdown(self)

Close the worker’s environments.

worker_init(self)

Initialize a worker.

rollout(self)

Sample a single episode of the agent in the environment.

Returns

The collected episode.

Return type

EpisodeBatch

class Worker(*, seed, max_episode_length, worker_number)

Bases: abc.ABC

Worker class used in all Samplers.

update_agent(self, agent_update)

Update the worker’s agent, using agent_update.

Parameters

agent_update (object) – An agent update. The exact type of this argument depends on the Worker implementation.

update_env(self, env_update)

Update the worker’s env, using env_update.

Parameters

env_update (object) – An environment update. The exact type of this argument depends on the Worker implementation.

rollout(self)

Sample a single episode of the agent in the environment.

Returns

Batch of sampled episodes. May be truncated if max_episode_length is set.

Return type

EpisodeBatch

start_episode(self)

Begin a new episode.

step_episode(self)

Take a single time-step in the current episode.

Returns

True iff the episode is done, either due to the environment indicating termination or due to reaching max_episode_length.

Return type

bool

collect_episode(self)

Collect the current episode, clearing the internal buffer.

Returns

Batch of sampled episodes. May be truncated if the episodes haven’t completed yet.

Return type

EpisodeBatch

shutdown(self)

Shutdown the worker.

class WorkerFactory(*, seed, max_episode_length, n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)

Constructs workers for Samplers.

The intent is that this object should be sufficient to avoid subclassing the sampler. Instead of subclassing the sampler for e.g. a specific backend, implement a specialized WorkerFactory (or specify appropriate functions to this one). Note that this object must be picklable, since it may be passed to workers. However, its fields individually need not be.

All arguments to this type must be passed by keyword.

Parameters
  • seed (int) – The seed to use to initialize random number generators.

  • n_workers (int) – The number of workers to use.

  • max_episode_length (int) – The maximum length of episodes which will be sampled.

  • worker_class (type) – Class of the workers. Instances should implement the Worker interface.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.

prepare_worker_messages(self, objs, preprocess=identity_function)

Take an argument and canonicalize it into a list for all workers.

This helper function is used to handle arguments in the sampler API which may (optionally) be lists. Specifically, these are agent, env, agent_update, and env_update. It also checks that the number of parameters is correct.

Parameters
  • objs (object or list) – Must be either a single object or a list of length n_workers.

  • preprocess (function) – Function to call on each single object before creating the list.

Raises

ValueError – If a list of a length other than n_workers is passed.

Returns

A list of length self.n_workers.

Return type

List[object]
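
A behavioral sketch: a single object is repeated once per worker, while a list must already have exactly n_workers entries:

    from garage.sampler import WorkerFactory

    factory = WorkerFactory(seed=1, max_episode_length=100, n_workers=4)

    messages = factory.prepare_worker_messages('an update')
    assert len(messages) == 4   # the single object, once per worker

    # A list whose length is not n_workers raises ValueError:
    #   factory.prepare_worker_messages(['a', 'b'])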