garage.sampler.ray_sampler

This is an implementation of an on-policy batch sampler.

It uses a data-parallel design: the sampler deploys sampler workers, each of which must implement a function for setting agent parameters and a rollout function.

class RaySampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler

Samples episodes in a data-parallel fashion using a Ray cluster.

Parameters:
  • worker_factory (WorkerFactory) – Used for worker behavior.
  • agents (list[Policy]) – Agents to distribute across workers.
  • envs (list[Environment]) – Environments to distribute across workers.
classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters:
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
Returns:

An instance of cls.

Return type:

Sampler
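
The list-or-single-value rule for agents and envs can be illustrated with a small stand-alone helper. This is a sketch only, not garage's code: `spread_across_workers` is a hypothetical name, the strings stand in for real Policy and Environment objects, and whether the real sampler copies a single object for each worker is not covered here.

```python
def spread_across_workers(value, n_workers):
    """Return one entry per worker (hypothetical helper, not part of garage).

    A list must contain exactly n_workers items; any other value is
    repeated so that every worker receives the same object.
    """
    if isinstance(value, list):
        if len(value) != n_workers:
            raise ValueError(
                f'Expected {n_workers} items, got {len(value)}')
        return list(value)
    return [value] * n_workers


# A single environment (here just a name) is shared by all three workers.
per_worker_envs = spread_across_workers('CartPole', 3)
# A list is assigned element by element, worker 0 first.
per_worker_agents = spread_across_workers(['pi_0', 'pi_1', 'pi_2'], 3)
```

Passing a list of the wrong length raises ValueError in this sketch, which mirrors the "length exactly worker_factory.n_workers" requirement stated above.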

start_worker(self)

Initialize a new Ray worker.

obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Sample the policy for new episodes.

Parameters:
  • itr (int) – Iteration number.
  • num_samples (int) – Number of steps the sampler should collect.
  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

Batch of gathered episodes.

Return type:

EpisodeBatch
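
The num_samples contract can be pictured as a loop that keeps gathering whole episodes until enough steps have been collected. The sketch below is an assumption about that behavior, written without Ray: `collect_until` and `make_fake_rollout` are hypothetical names, and plain lists of steps stand in for EpisodeBatch objects. Because episodes are kept whole in this sketch, the total can exceed num_samples.

```python
import itertools


def collect_until(num_samples, rollout_fns):
    """Gather whole episodes round-robin until num_samples steps exist.

    rollout_fns holds one zero-argument callable per simulated worker.
    Each call returns (worker_id, episode), where episode is a list of steps.
    """
    episodes = []
    collected = 0
    for rollout in itertools.cycle(rollout_fns):
        if collected >= num_samples:
            break
        _, episode = rollout()
        episodes.append(episode)
        collected += len(episode)
    return episodes, collected


def make_fake_rollout(worker_id, episode_length):
    """Build a stand-in worker whose episodes have a fixed length."""
    def fake_rollout():
        steps = [f'step-{worker_id}-{t}' for t in range(episode_length)]
        return worker_id, steps
    return fake_rollout


# Two simulated workers producing episodes of 4 and 3 steps.
workers = [make_fake_rollout(0, 4), make_fake_rollout(1, 3)]
episodes, collected = collect_until(10, workers)
```

A real Ray-backed sampler would run the workers concurrently rather than round-robin; the stopping condition is the part this sketch is meant to show.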

obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)

Sample an exact number of episodes per worker.

Parameters:
  • n_eps_per_worker (int) – Exact number of episodes to gather for each worker.
  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.

Return type:

EpisodeBatch
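
Because remote workers can finish in any order, returning results in worker order implies a grouping step. A minimal sketch, assuming results arrive as (worker_id, episode) pairs like those returned by SamplerWorker.rollout; `order_by_worker` is a hypothetical helper and the strings stand in for episodes:

```python
from collections import defaultdict


def order_by_worker(results, n_workers):
    """Flatten (worker_id, episode) pairs into worker order."""
    by_worker = defaultdict(list)
    for worker_id, episode in results:
        by_worker[worker_id].append(episode)
    ordered = []
    for worker_id in range(n_workers):
        # Episodes from one worker keep their arrival order.
        ordered.extend(by_worker[worker_id])
    return ordered


# Results in the order the simulated workers happened to finish.
arrived = [(1, 'b0'), (0, 'a0'), (1, 'b1'), (0, 'a1')]
ordered = order_by_worker(arrived, n_workers=2)
```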

shutdown_worker(self)

Shuts down the worker.

class SamplerWorker(worker_id, env, agent_pkl, worker_factory)

Constructs a single sampler worker.

Parameters:
  • worker_id (int) – The ID of this worker.
  • env (Environment) – Environment to sample from.
  • agent_pkl (bytes) – Pickled Policy to sample with.
  • worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
update(self, agent_update, env_update)

Update the agent and environment.

Parameters:
  • agent_update (object) – Agent update.
  • env_update (object) – Environment update.
Returns:

The worker id.

Return type:

int

rollout(self)

Sample one episode of the agent in the environment.

Returns:

Worker ID and batch of samples.

Return type:

tuple[int, EpisodeBatch]

shutdown(self)

Shuts down the worker.
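
The return values of update and rollout let a caller match each result to the worker that produced it. The toy class below mirrors only that interface shape. It is not the garage implementation: `ToyWorker` is a hypothetical name, nothing runs remotely, and a list of integers stands in for EpisodeBatch.

```python
class ToyWorker:
    """Stand-in with the same method shape as SamplerWorker (illustrative)."""

    def __init__(self, worker_id, episode_length):
        self.worker_id = worker_id
        self.episode_length = episode_length
        self.agent_params = None

    def update(self, agent_update, env_update):
        """Store the new agent parameters and report which worker updated."""
        self.agent_params = agent_update
        # env_update is ignored in this toy version.
        return self.worker_id

    def rollout(self):
        """Return the worker ID together with one fake episode."""
        episode = list(range(self.episode_length))
        return self.worker_id, episode


workers = [ToyWorker(0, 2), ToyWorker(1, 3)]
# Each update call returns the ID of the worker that applied it.
updated_ids = [w.update({'weights': [0.1]}, None) for w in workers]
# Each rollout is tagged with the ID of the worker that produced it.
results = [w.rollout() for w in workers]
```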