garage.sampler.ray_sampler module

This is an implementation of an on-policy batch sampler.

Uses a data-parallel design. Included is a sampler that deploys sampler workers. Each sampler worker must implement a function for setting agent parameters and a rollout function.

class RaySampler(worker_factory, agents, envs)[source]

Bases: garage.sampler.sampler.Sampler

Collects policy rollouts in a data-parallel fashion.

Parameters:
  • worker_factory (garage.sampler.WorkerFactory) – Used for worker behavior.
  • agents (list[garage.Policy]) – Agents to distribute across workers.
  • envs (list[gym.Env]) – Environments to distribute across workers.
classmethod from_worker_factory(worker_factory, agents, envs)[source]

Construct this sampler.

Parameters:
  • worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
  • agents (Agent or List[Agent]) – Agent(s) to use to perform rollouts. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
  • envs (gym.Env or List[gym.Env]) – Environment rollouts are performed in. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
Returns:

An instance of cls.

Return type:

Sampler
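The distribution rule stated for `agents` and `envs` (a single value is replicated to every worker; a list must have length exactly `worker_factory.n_workers`) can be sketched as a small helper. The helper name is hypothetical, not part of the garage API:

```python
def spread_across_workers(value, n_workers):
    """Sketch of spreading an agent/env (or list of them) across workers.

    Hypothetical helper mirroring the documented rule: a single value is
    replicated once per worker; a list must already have exactly
    n_workers entries.
    """
    if isinstance(value, list):
        if len(value) != n_workers:
            raise ValueError(
                f"Expected exactly {n_workers} values, got {len(value)}")
        return list(value)
    return [value] * n_workers
```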

obtain_exact_trajectories(n_traj_per_worker, agent_update, env_update=None)[source]

Sample an exact number of trajectories per worker.

Parameters:
  • n_traj_per_worker (int) – Exact number of trajectories to gather for each worker.
  • agent_update (object) – Value which will be passed into the agent_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

Batch of gathered trajectories. Always in worker order. In other words, first all trajectories from worker 0, then all trajectories from worker 1, and so on.

Return type:

TrajectoryBatch
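The worker-order guarantee above matters because parallel workers may finish out of order. A minimal sketch of reassembling results by worker id (hypothetical helper, not the garage implementation):

```python
def gather_in_worker_order(results, n_workers):
    """Sketch of the worker-order guarantee on returned trajectories.

    `results` is a list of (worker_id, trajectories) pairs, possibly
    completed out of order by the parallel workers. The returned batch
    always lists worker 0's trajectories first, then worker 1's, etc.
    """
    by_worker = {wid: trajs for wid, trajs in results}
    ordered = []
    for wid in range(n_workers):
        ordered.extend(by_worker[wid])
    return ordered
```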

obtain_samples(itr, num_samples, agent_update, env_update=None)[source]

Sample the policy for new trajectories.

Parameters:
  • itr (int) – Iteration number.
  • num_samples (int) – Number of steps the sampler should collect.
  • agent_update (object) – Value which will be passed into the agent_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
  • env_update (object) – Value which will be passed into the env_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns:

Batch of gathered trajectories.

Return type:

TrajectoryBatch
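Unlike `obtain_exact_trajectories`, this method is driven by a step budget rather than a trajectory count. A sketch of that loop, assuming whole trajectories are collected until at least `num_samples` steps have been gathered (hypothetical helper, not the garage implementation; the last rollout may overshoot the target):

```python
def obtain_samples_sketch(rollout_fn, num_samples):
    """Sketch of sampling until at least `num_samples` environment steps.

    `rollout_fn` returns one trajectory as a list of steps. Trajectories
    are accumulated whole, so the total step count can exceed the budget.
    """
    batch, total_steps = [], 0
    while total_steps < num_samples:
        traj = rollout_fn()
        batch.append(traj)
        total_steps += len(traj)
    return batch
```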

shutdown_worker()[source]

Shuts down the worker.

start_worker()[source]

Initialize a new Ray worker.

class SamplerWorker(worker_id, env, agent_pkl, worker_factory)[source]

Bases: object

Constructs a single sampler worker.

Parameters:
  • worker_id (int) – The ID of the sampler worker.
  • env (gym.Env) – The gym environment.
  • agent_pkl (bytes) – The pickled agent.
  • worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
rollout()[source]

Compute one rollout of the agent in the environment.

Returns:

Worker ID and batch of samples.

Return type:

tuple[int, garage.TrajectoryBatch]
shutdown()[source]

Shuts down the worker.

update(agent_update, env_update)[source]

Update the agent and environment.

Parameters:
  • agent_update (object) – Agent update.
  • env_update (object) – Environment update.
Returns:

The worker id.

Return type:

int