garage.sampler.ray_sampler

This is an implementation of an on-policy batch sampler.

It uses a data-parallel design: the sampler deploys a set of sampler workers, each of which must implement a function that sets the agent parameters and a rollout function.
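The worker contract described above can be sketched in plain Python. This is a hypothetical, in-process illustration, not garage code: `ToyPolicy`, `ToyWorker`, `set_agent_params`, and `sample_once` are invented names, the "environment" is just a fixed-length step counter, and the workers run one after another rather than on a Ray cluster.

```python
class ToyPolicy:
    """Hypothetical policy whose only parameter is a constant action."""

    def __init__(self, action=0):
        self.action = action


class ToyWorker:
    """Hypothetical worker exposing the two functions the sampler relies on."""

    def __init__(self, worker_id, max_episode_length=3):
        self.worker_id = worker_id
        self.max_episode_length = max_episode_length
        self.policy = ToyPolicy()

    def set_agent_params(self, params):
        # Overwrite this worker's local copy of the policy parameters.
        self.policy.action = params["action"]

    def rollout(self):
        # Run one fixed-length episode; each step is (worker_id, t, action).
        return [(self.worker_id, t, self.policy.action)
                for t in range(self.max_episode_length)]


def sample_once(workers, params):
    """Broadcast parameters to every worker, then collect one episode each."""
    for worker in workers:
        worker.set_agent_params(params)
    return [worker.rollout() for worker in workers]


workers = [ToyWorker(worker_id=i) for i in range(2)]
episodes = sample_once(workers, {"action": 7})
print(len(episodes), episodes[1][0])  # 2 (1, 0, 7)
```

Because each worker only needs the current parameters and a way to roll out, the same pattern scales to many processes: only the parameter broadcast and the returned episodes cross process boundaries.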

class RaySampler(worker_factory, agents, envs)

Bases: garage.sampler.sampler.Sampler


Samples episodes in a data-parallel fashion using a Ray cluster.

Parameters
  • worker_factory (WorkerFactory) – Used for worker behavior.

  • agents (list[Policy]) – Agents to distribute across workers.

  • envs (list[Environment]) – Environments to distribute across workers.

classmethod from_worker_factory(cls, worker_factory, agents, envs)

Construct this sampler.

Parameters
  • worker_factory (WorkerFactory) – Picklable factory for creating workers. It should be transmitted to the other processes or nodes where work needs to be done, and the workers should then be constructed there.

  • agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
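The "single object or list of length `n_workers`" rule for `agents` and `envs` can be illustrated with a small helper. This is a minimal sketch; `spread_across_workers` is a hypothetical name and not part of the garage API.

```python
def spread_across_workers(objs, n_workers):
    """Return one message per worker.

    A list must have exactly ``n_workers`` entries and is split one entry
    per worker; any other object is broadcast to every worker.
    """
    if isinstance(objs, list):
        if len(objs) != n_workers:
            raise ValueError(
                f"Expected a list of length {n_workers}, got {len(objs)}")
        return list(objs)
    return [objs] * n_workers


print(spread_across_workers("env", 3))            # ['env', 'env', 'env']
print(spread_across_workers(["a", "b", "c"], 3))  # ['a', 'b', 'c']
```

In this in-process sketch the broadcast entries are the same object; once the messages are serialized to separate processes or nodes, each worker ends up with its own copy.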

Returns

An instance of cls.

Return type

Sampler

start_worker(self)

Initialize a new Ray worker.

obtain_samples(self, itr, num_samples, agent_update, env_update=None)

Sample the policy for new episodes.

Parameters
  • itr (int) – Iteration number.

  • num_samples (int) – Number of steps the sampler should collect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
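The step budget `num_samples` can be pictured as a loop that keeps collecting whole episodes until at least that many steps exist. This is a hypothetical sketch: `toy_rollout` and this `obtain_samples` are stand-ins, workers are visited round-robin rather than kept busy concurrently, and the sketch assumes episodes are never truncated to hit the budget exactly, so the total may overshoot `num_samples`.

```python
import itertools
import random


def toy_rollout(worker_id, rng):
    """Hypothetical rollout: an episode of random length 1 to 5 steps."""
    length = rng.randint(1, 5)
    return [(worker_id, t) for t in range(length)]


def obtain_samples(n_workers, num_samples, seed=0):
    """Collect whole episodes until at least ``num_samples`` steps exist."""
    rng = random.Random(seed)
    episodes = []
    completed = 0
    # Cycle over workers; a real sampler would keep all workers busy at once.
    for worker_id in itertools.cycle(range(n_workers)):
        if completed >= num_samples:
            break
        episode = toy_rollout(worker_id, rng)
        episodes.append(episode)
        completed += len(episode)
    return episodes, completed


episodes, completed = obtain_samples(n_workers=4, num_samples=20)
print(completed >= 20)  # True
```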

Returns

Batch of gathered episodes.

Return type

EpisodeBatch

obtain_exact_episodes(self, n_eps_per_worker, agent_update, env_update=None)

Sample an exact number of episodes per worker.

Parameters
  • n_eps_per_worker (int) – Exact number of episodes to gather for each worker.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.

Returns

Batch of gathered episodes, always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, and so on.
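One way to guarantee worker order when workers finish at different times is to key each result by worker ID and concatenate only at the end. The sketch below is hypothetical: `run_worker` is an invented stand-in, and a `concurrent.futures` thread pool replaces the Ray cluster so the example runs with the standard library alone.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_worker(worker_id, n_eps, delay):
    """Hypothetical worker: returns its ID and ``n_eps`` labelled episodes."""
    time.sleep(delay)
    return worker_id, [f"w{worker_id}-ep{i}" for i in range(n_eps)]


def obtain_exact_episodes(n_workers, n_eps_per_worker):
    by_worker = {}
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # Later workers get shorter delays, so they tend to finish first.
        futures = [pool.submit(run_worker, i, n_eps_per_worker,
                               0.01 * (n_workers - i))
                   for i in range(n_workers)]
        for future in as_completed(futures):
            worker_id, episodes = future.result()
            by_worker[worker_id] = episodes
    # Concatenate in worker order, regardless of completion order.
    return [ep for wid in sorted(by_worker) for ep in by_worker[wid]]


print(obtain_exact_episodes(n_workers=3, n_eps_per_worker=2))
# ['w0-ep0', 'w0-ep1', 'w1-ep0', 'w1-ep1', 'w2-ep0', 'w2-ep1']
```

Sorting by worker ID at the end, instead of appending in arrival order, is what makes the output independent of scheduling.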

Return type

EpisodeBatch

shutdown_worker(self)

Shuts down the worker.

class SamplerWorker(worker_id, env, agent_pkl, worker_factory)

Constructs a single sampler worker.

Parameters
  • worker_id (int) – The ID of this worker.

  • env (Environment) – Environment to sample from.

  • agent_pkl (bytes) – Pickled Policy to sample with.

  • worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
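Passing the agent as pickled bytes (`agent_pkl`) means it can be shipped to another process or node and rebuilt there, with each worker holding an independent copy. The sketch below is a hypothetical illustration using the standard `pickle` module; `ToySamplerWorker` is an invented name and a plain dict stands in for a garage Policy.

```python
import pickle


class ToySamplerWorker:
    """Hypothetical worker that receives its agent as pickled bytes."""

    def __init__(self, worker_id, agent_pkl):
        self.worker_id = worker_id
        # Each worker deserializes its own independent copy of the agent.
        self.agent = pickle.loads(agent_pkl)


# A plain dict stands in for a garage Policy in this sketch.
policy = {"weights": [0.1, 0.2]}
agent_pkl = pickle.dumps(policy)  # serialize once on the driver
workers = [ToySamplerWorker(i, agent_pkl) for i in range(2)]

workers[0].agent["weights"][0] = 9.9  # mutate worker 0's copy only
print(workers[1].agent["weights"])    # [0.1, 0.2]
print(policy["weights"])              # [0.1, 0.2]
```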

update(self, agent_update, env_update)

Update the agent and environment.

Parameters
  • agent_update (object) – Agent update.

  • env_update (object) – Environment update.

Returns

The worker id.

Return type

int
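Returning the worker ID from `update` (and alongside the batch from `rollout`) lets the caller tell which worker has finished, even when results arrive out of order. In this hypothetical sketch, a thread pool stands in for Ray remote calls: as soon as a worker reports its ID after updating, a rollout is started on that same worker. `ToyWorker` is an invented name, and its `update` ignores `env_update`.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


class ToyWorker:
    """Hypothetical worker mirroring the update/rollout return values."""

    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.agent_params = None

    def update(self, agent_update, env_update):
        # Apply the agent update (env_update is ignored in this toy),
        # then report which worker is now ready.
        self.agent_params = agent_update
        return self.worker_id

    def rollout(self):
        # One fake three-step episode tagged with the current parameters.
        return self.worker_id, [self.agent_params] * 3


workers = [ToyWorker(i) for i in range(3)]
results = {}
with ThreadPoolExecutor(max_workers=3) as pool:
    updates = [pool.submit(w.update, "params-v2", None) for w in workers]
    # As soon as a worker reports its ID, start a rollout on that worker.
    rollouts = [pool.submit(workers[f.result()].rollout)
                for f in as_completed(updates)]
    for f in as_completed(rollouts):
        worker_id, episode = f.result()
        results[worker_id] = episode

print(sorted(results))  # [0, 1, 2]
```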

rollout(self)

Sample one episode of the agent in the environment.

Returns

Worker ID and batch of samples.

Return type

tuple[int, EpisodeBatch]

shutdown(self)

Shuts down the worker.