garage.sampler.multiprocessing_sampler module

A multiprocessing sampler which avoids waiting as much as possible.

class MultiprocessingSampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler

Sampler that uses multiprocessing to distribute workers.

Parameters:
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
  • agents (Agent or List[Agent]) – Agent(s) to use to perform rollouts. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
  • envs (gym.Env or List[gym.Env]) – Environment rollouts are performed in. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
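The "single object or list of exactly n_workers" rule described for agents and envs can be sketched as follows. This is a minimal illustration, not garage's implementation: the helper name spread_across_workers is hypothetical, and it simply reuses a single object for every worker, whereas a real multiprocessing sampler would send a copy to each process.

```python
from typing import Any, List


def spread_across_workers(value: Any, n_workers: int) -> List[Any]:
    """Return one entry per worker (hypothetical helper, not garage API)."""
    if isinstance(value, list):
        # A list must supply exactly one entry for each worker.
        if len(value) != n_workers:
            raise ValueError(
                f"expected exactly {n_workers} entries, got {len(value)}")
        return list(value)
    # A single object is assigned to every worker.
    return [value for _ in range(n_workers)]
```

Under these assumptions, passing one environment with n_workers=3 yields three entries, while passing a two-element list with n_workers=3 is rejected.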
classmethod from_worker_factory(worker_factory, agents, envs)

Construct this sampler.

Parameters:
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
  • agents (Agent or List[Agent]) – Agent(s) to use to perform rollouts. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
  • envs (gym.Env or List[gym.Env]) – Environment rollouts are performed in. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
Returns:

An instance of cls.

Return type:

Sampler

obtain_exact_trajectories(n_traj_per_worker, agent_update, env_update=None)

Sample an exact number of trajectories per worker.

Parameters:
  • n_traj_per_worker (int) – Exact number of trajectories to gather for each worker.
  • agent_update (object) – Value which will be passed into the agent_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

Batch of gathered trajectories, always in worker order. In other words, first all trajectories from worker 0, then all trajectories from worker 1, etc.

Return type:

TrajectoryBatch

Raises:

AssertionError – On internal errors.
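The worker-order guarantee of obtain_exact_trajectories can be illustrated with a small sketch. It assumes trajectories arrive from the workers in arbitrary order as (worker_number, trajectory) pairs; the function gather_in_worker_order and the use of strings as stand-ins for trajectories are hypothetical and not part of garage.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def gather_in_worker_order(arrivals: Iterable[Tuple[int, str]],
                           n_workers: int,
                           n_traj_per_worker: int) -> List[str]:
    """Group out-of-order arrivals by worker (hypothetical helper)."""
    per_worker: Dict[int, List[str]] = defaultdict(list)
    for worker_number, trajectory in arrivals:
        # Keep only the first n_traj_per_worker trajectories of each worker.
        if len(per_worker[worker_number]) < n_traj_per_worker:
            per_worker[worker_number].append(trajectory)
    ordered: List[str] = []
    # Emit everything from worker 0, then worker 1, and so on.
    for worker_number in range(n_workers):
        ordered.extend(per_worker[worker_number])
    return ordered
```

Regardless of the order in which the pairs arrive, the result lists worker 0's trajectories first, and any surplus trajectory beyond the per-worker quota is dropped.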

obtain_samples(itr, num_samples, agent_update, env_update=None)

Collect at least a given number of transitions (timesteps).

Parameters:
  • itr (int) – The current iteration number. Using this argument is deprecated.
  • num_samples (int) – Minimum number of transitions / timesteps to sample.
  • agent_update (object) – Value which will be passed into the agent_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

The batch of collected trajectories.

Return type:

garage.TrajectoryBatch

Raises:

AssertionError – On internal errors.
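Because num_samples is a minimum, the batch returned by obtain_samples may contain more timesteps than requested: whole trajectories are kept until the threshold is met. The sketch below illustrates that stopping rule only. It assumes each trajectory is represented by its length in timesteps, and collect_at_least is a hypothetical name, not a garage function.

```python
from typing import Iterable, List


def collect_at_least(trajectory_lengths: Iterable[int],
                     num_samples: int) -> List[int]:
    """Keep whole trajectories until num_samples timesteps are gathered."""
    collected: List[int] = []
    total = 0
    for length in trajectory_lengths:
        # Stop as soon as the minimum number of timesteps has been reached.
        if total >= num_samples:
            break
        collected.append(length)
        total += length
    return collected
```

With trajectories of 40 timesteps each and num_samples=100, three trajectories (120 timesteps) are kept, since two would fall short of the minimum. If the stream ends early, this sketch simply returns what it has.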

shutdown_worker()

Shut down the workers.

run_worker(factory, to_worker, to_sampler, worker_number, agent, env)

Run the streaming worker state machine.

Starts in the “not streaming” state. Enters the “streaming” state when the “start” or “continue” message is received. While in the “streaming” state, it streams rollouts back to the parent process. When it receives a “stop” message, or the queue back to the parent process is full, it enters the “not streaming” state. When it receives the “exit” message, it terminates.

Critically, the worker never blocks on sending messages back to the sampler, to ensure it remains responsive to messages.
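The state machine described above can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not garage's actual run_worker: the function run_sketch_worker and its rollout_fn argument are hypothetical, commands are plain strings, and agent or environment updates are omitted. It relies only on get, get_nowait and put_nowait, which both queue.Queue and multiprocessing.Queue provide.

```python
import queue


def run_sketch_worker(to_worker, to_sampler, rollout_fn, worker_number=0):
    """Stream rollouts without ever blocking on the outgoing queue."""
    streaming = False
    while True:
        if streaming:
            # Poll for commands without blocking so rollouts keep flowing.
            try:
                command = to_worker.get_nowait()
            except queue.Empty:
                command = None
        else:
            # Not streaming: nothing else to do, so waiting here is fine.
            command = to_worker.get()

        if command in ("start", "continue"):
            streaming = True
        elif command == "stop":
            streaming = False
        elif command == "exit":
            return

        if streaming:
            try:
                to_sampler.put_nowait((worker_number, rollout_fn()))
            except queue.Full:
                # The queue back to the sampler is full: stop streaming
                # instead of blocking, and wait for the next command.
                streaming = False
```

The design point is that the only blocking call is the read from to_worker while idle. Sends use put_nowait, so a full queue moves the worker to the "not streaming" state rather than stalling it.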

Parameters:
  • factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
  • to_worker (multiprocessing.Queue) – Queue to send commands to the worker.
  • to_sampler (multiprocessing.Queue) – Queue to send rollouts back to the sampler.
  • worker_number (int) – Number of this worker.
  • agent (Agent) – Agent this worker uses to perform rollouts.
  • env (gym.Env) – Environment this worker performs rollouts in.
Raises:

AssertionError – On internal errors.