garage.sampler.ray_sampler
This module implements an on-policy batch sampler with a data-parallel design. It provides a sampler that deploys sampler workers; each sampler worker must implement a function that sets the agent parameters and a rollout function.
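The worker contract described above can be pictured with a small plain-Python sketch. It assumes hypothetical method names set_agent_params and rollout (the text only requires some parameter-setting function and a rollout function), uses a toy scalar "policy", and runs the workers serially rather than on a Ray cluster. It illustrates the data flow only, not garage's implementation.

```python
from typing import List, Protocol, Tuple


class SamplerWorkerProtocol(Protocol):
    """Hypothetical minimal contract for a data-parallel sampler worker."""

    def set_agent_params(self, params: List[float]) -> None:
        """Load new agent parameters before sampling."""

    def rollout(self) -> List[Tuple[float, float]]:
        """Sample one episode and return its (observation, action) steps."""


class ConstantWorker:
    """Toy worker whose 'policy' is a single scalar gain (hypothetical)."""

    def __init__(self, worker_id: int, episode_length: int = 3):
        self.worker_id = worker_id
        self.episode_length = episode_length
        self.gain = 0.0

    def set_agent_params(self, params: List[float]) -> None:
        self.gain = params[0]

    def rollout(self) -> List[Tuple[float, float]]:
        # Observations are step indices offset by the worker ID;
        # the action is the gain times the observation.
        return [(float(self.worker_id + t), self.gain * (self.worker_id + t))
                for t in range(self.episode_length)]


def sample_data_parallel(workers: List[SamplerWorkerProtocol],
                         params: List[float]) -> List[List[Tuple[float, float]]]:
    """Broadcast the same parameters, then collect one episode per worker."""
    for worker in workers:
        worker.set_agent_params(params)
    return [worker.rollout() for worker in workers]


episodes = sample_data_parallel([ConstantWorker(i) for i in range(2)], [2.0])
print(episodes)
```

Every worker receives identical parameters and samples independently, which is what makes the design data-parallel: adding workers adds throughput without changing the policy being evaluated.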
-
class RaySampler(worker_factory, agents, envs)
Bases: garage.sampler.sampler.Sampler
Samples episodes in a data-parallel fashion using a Ray cluster.
- Parameters
worker_factory (WorkerFactory) – Used for worker behavior.
agents (list[Policy]) – Agents to distribute across workers.
envs (list[Environment]) – Environments to distribute across workers.
-
classmethod from_worker_factory(cls, worker_factory, agents, envs)
Construct this sampler.
- Parameters
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
An instance of cls.
- Return type
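The rule that a list argument must have exactly worker_factory.n_workers entries, while a single object is presumably used by every worker, can be sketched as follows. spread_to_workers is a hypothetical helper written for illustration, not a garage function.

```python
from typing import Any, List


def spread_to_workers(value: Any, n_workers: int) -> List[Any]:
    """Hypothetical helper: produce one entry per worker.

    A list must have exactly n_workers entries (one per worker);
    any other value is replicated (by reference) for every worker.
    """
    if isinstance(value, list):
        if len(value) != n_workers:
            raise ValueError(
                f"expected {n_workers} entries, got {len(value)}")
        return list(value)
    return [value] * n_workers


print(spread_to_workers("policy", 3))      # ['policy', 'policy', 'policy']
print(spread_to_workers(["e0", "e1"], 2))  # ['e0', 'e1']
```

Checking the length up front turns a mismatch between the argument and the worker count into an immediate error instead of a silently unused or missing entry.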
-
start_worker(self)
Initialize a new Ray worker.
-
obtain_samples(self, itr, num_samples, agent_update, env_update=None)
Sample the policy for new episodes.
- Parameters
itr (int) – Iteration number.
num_samples (int) – Number of steps the sampler should collect.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes.
- Return type
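One plausible way to honour the num_samples step budget is to keep every worker busy and stop re-dispatching once enough steps have arrived. The sketch below imitates that pattern with a local thread pool standing in for the Ray cluster, with concurrent.futures.wait and FIRST_COMPLETED playing the role a call such as ray.wait would. fake_rollout and obtain_samples_sketch are hypothetical, and the re-dispatch policy is an assumption, not a description of RaySampler's code.

```python
import concurrent.futures as cf
import random
import time
from typing import Dict, List


def fake_rollout(worker_id: int) -> List[int]:
    """Stand-in for a remote rollout: returns an episode of 2-4 steps."""
    time.sleep(random.uniform(0.0, 0.01))  # simulate uneven worker speed
    return [worker_id] * random.randint(2, 4)


def obtain_samples_sketch(n_workers: int, num_samples: int) -> List[List[int]]:
    """Keep every worker busy until at least num_samples steps are gathered."""
    episodes: List[List[int]] = []
    completed_steps = 0
    with cf.ThreadPoolExecutor(max_workers=n_workers) as pool:
        # Start one rollout per worker, remembering which worker owns each future.
        running: Dict[cf.Future, int] = {
            pool.submit(fake_rollout, w): w for w in range(n_workers)}
        while completed_steps < num_samples:
            done, _ = cf.wait(running, return_when=cf.FIRST_COMPLETED)
            for future in done:
                worker_id = running.pop(future)
                episode = future.result()
                episodes.append(episode)
                completed_steps += len(episode)
                # Re-dispatch the now idle worker only if the budget is unmet.
                if completed_steps < num_samples:
                    running[pool.submit(fake_rollout, worker_id)] = worker_id
    # Rollouts still in flight when the budget is met are waited for by the
    # executor on exit and then discarded.
    return episodes


batch = obtain_samples_sketch(n_workers=3, num_samples=20)
print(sum(len(ep) for ep in batch), "steps in", len(batch), "episodes")
```

Because whole episodes are returned, the step count usually overshoots num_samples slightly; the budget is a lower bound, not an exact size.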
-
obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)
Sample an exact number of episodes per worker.
- Parameters
n_eps_per_worker (int) – Exact number of episodes to gather for each worker.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
- Batch of gathered episodes, always in worker order: first all episodes from worker 0, then all episodes from worker 1, and so on.
- Return type
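The worker-order guarantee can hold even when episodes finish out of order, by filing each result under its worker ID and concatenating at the end. This is a hypothetical sketch with a thread pool in place of Ray; each simulated worker has at most one episode in flight, so episodes also stay in order within a worker. fake_rollout and obtain_exact_episodes_sketch are illustrative names only.

```python
import concurrent.futures as cf
import random
import time
from typing import Dict, List


def fake_rollout(worker_id: int, episode_index: int) -> str:
    """Stand-in for a remote rollout; completion order is deliberately random."""
    time.sleep(random.uniform(0.0, 0.01))
    return f"w{worker_id}-ep{episode_index}"


def obtain_exact_episodes_sketch(n_workers: int, n_eps_per_worker: int) -> List[str]:
    """Gather exactly n_eps_per_worker episodes per worker, in worker order."""
    per_worker: List[List[str]] = [[] for _ in range(n_workers)]
    with cf.ThreadPoolExecutor(max_workers=n_workers) as pool:
        running: Dict[cf.Future, int] = {}
        if n_eps_per_worker > 0:
            running = {pool.submit(fake_rollout, w, 0): w for w in range(n_workers)}
        while running:
            done, _ = cf.wait(running, return_when=cf.FIRST_COMPLETED)
            for future in done:
                worker_id = running.pop(future)
                # File the episode under its worker, whatever order it finished in.
                per_worker[worker_id].append(future.result())
                n_done = len(per_worker[worker_id])
                if n_done < n_eps_per_worker:
                    running[pool.submit(fake_rollout, worker_id, n_done)] = worker_id
    # Concatenate: all of worker 0's episodes, then worker 1's, and so on.
    return [ep for episodes in per_worker for ep in episodes]


print(obtain_exact_episodes_sketch(n_workers=3, n_eps_per_worker=2))
```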
-
shutdown_worker(self)
Shuts down the worker.
-
class SamplerWorker(worker_id, env, agent_pkl, worker_factory)
Constructs a single sampler worker.
- Parameters
worker_id (int) – The ID of this worker.
env (Environment) – Environment to sample from.
agent_pkl (bytes) – Pickled Policy to sample with.
worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
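Passing the agent as pickled bytes means each worker rebuilds its own private copy of the policy. The sketch below uses the standard pickle module with a plain dict standing in for a Policy; ToySamplerWorker is hypothetical and shows only the deserialization step.

```python
import pickle


class ToySamplerWorker:
    """Hypothetical worker that rebuilds its agent from pickled bytes."""

    def __init__(self, worker_id: int, agent_pkl: bytes):
        self.worker_id = worker_id
        # Each worker deserializes its own private copy of the policy.
        self.agent = pickle.loads(agent_pkl)


policy = {"weights": [0.1, 0.2]}     # plain dict standing in for a Policy
agent_pkl = pickle.dumps(policy)     # serialize once on the driver
workers = [ToySamplerWorker(i, agent_pkl) for i in range(2)]

workers[0].agent["weights"][0] = 9.9  # mutate one worker's copy
print(workers[1].agent["weights"])   # [0.1, 0.2]: the copies are independent
```

Shipping bytes rather than a live object also makes the payload easy to send to another process or node.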
-
update(self, agent_update, env_update)
Update the agent and environment.
-
rollout(self)
Sample one episode of the agent in the environment.
- Returns
Worker ID and batch of samples.
- Return type
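The update and rollout methods can be sketched with a toy worker. ToyRolloutWorker, its dict-based agent and environment settings, and the choice that None leaves a setting unchanged are all assumptions for illustration. What it does show is rollout returning the worker ID alongside the samples, so the caller can tell which worker finished.

```python
from typing import Dict, List, Optional, Tuple


class ToyRolloutWorker:
    """Hypothetical worker mirroring the update/rollout/shutdown interface."""

    def __init__(self, worker_id: int):
        self.worker_id = worker_id
        self.agent_params: Dict[str, float] = {"gain": 1.0}
        self.env_config: Dict[str, int] = {"horizon": 2}

    def update(self, agent_update: Optional[Dict[str, float]],
               env_update: Optional[Dict[str, int]]) -> None:
        # In this sketch, None means "keep the current agent / environment".
        if agent_update is not None:
            self.agent_params.update(agent_update)
        if env_update is not None:
            self.env_config.update(env_update)

    def rollout(self) -> Tuple[int, List[float]]:
        # One episode: a reward per step, tagged with the worker ID.
        rewards = [self.agent_params["gain"] * t
                   for t in range(self.env_config["horizon"])]
        return self.worker_id, rewards

    def shutdown(self) -> None:
        # Release the worker's state; it should not be used afterwards.
        self.agent_params.clear()


worker = ToyRolloutWorker(worker_id=3)
worker.update({"gain": 0.5}, {"horizon": 4})
print(worker.rollout())  # (3, [0.0, 0.5, 1.0, 1.5])
worker.shutdown()
```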
-
shutdown(self)
Shuts down the worker.