garage.sampler.ray_sampler module
This module implements an on-policy batch sampler.
It uses a data-parallel design and includes a sampler that deploys sampler workers. Each sampler worker must implement a function that sets the agent's parameters and a rollout function.
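The data-parallel pattern described above can be sketched without Ray. In this minimal sketch, ToyWorker, update_agent, and sample_round are hypothetical names, and a thread pool stands in for Ray's remote workers. It only illustrates the two operations each worker must provide: setting the agent's parameters, then performing a rollout.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyWorker:
    """Hypothetical worker exposing the two operations a sampler relies on."""

    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.params = None

    def update_agent(self, params):
        # Stand-in for the "set agent parameters" function.
        self.params = params

    def rollout(self):
        # Stand-in rollout: a 3-step "trajectory" tagged with the parameters used.
        steps = [(self.worker_id, t, self.params) for t in range(3)]
        return self.worker_id, steps


def sample_round(workers, params):
    """Broadcast parameters, then run one rollout per worker in parallel."""
    for worker in workers:
        worker.update_agent(params)
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        # pool.map returns results in the order of `workers`.
        results = list(pool.map(lambda w: w.rollout(), workers))
    return results


workers = [ToyWorker(i) for i in range(4)]
results = sample_round(workers, params={"version": 1})
```

Because every worker receives the same parameters before rolling out, the collected data is on-policy for that parameter version.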
class RaySampler(worker_factory, agents, envs)
Bases: garage.sampler.sampler.Sampler
Collects policy rollouts in a data-parallel fashion.
Parameters:
- worker_factory (garage.sampler.WorkerFactory) – Factory that defines the workers' behavior.
- agents (list[garage.Policy]) – Agents to distribute across workers.
- envs (list[gym.Env]) – Environments to distribute across workers.
classmethod from_worker_factory(worker_factory, agents, envs)
Construct this sampler.
Parameters:
- worker_factory (WorkerFactory) – Pickleable factory for creating workers. It should be transmitted to the processes or nodes where work needs to be done, and the workers should then be constructed there.
- agents (Agent or List[Agent]) – Agent(s) used to perform rollouts. If a list is passed in, it must have exactly worker_factory.n_workers elements and will be spread across the workers.
- envs (gym.Env or List[gym.Env]) – Environment(s) in which rollouts are performed. If a list is passed in, it must have exactly worker_factory.n_workers elements and will be spread across the workers.
Returns: An instance of cls.
Return type: Sampler
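The rule for agents and envs (a list must match worker_factory.n_workers; a single value, by implication, is used by every worker) can be sketched as a small helper. The name spread is hypothetical and this is not garage's internal function.

```python
def spread(value, n_workers):
    """Return one item per worker (hypothetical helper).

    A list must have exactly n_workers entries and is used as-is.
    Any other value is repeated for every worker (the same object, not a copy).
    """
    if isinstance(value, list):
        if len(value) != n_workers:
            raise ValueError(
                f"Expected a list of length {n_workers}, got {len(value)}")
        return value
    return [value] * n_workers


per_worker_envs = spread("env", 3)          # ["env", "env", "env"]
per_worker_agents = spread(["a", "b", "c"], 3)  # used unchanged
```

Failing fast on a length mismatch avoids silently leaving some workers without an agent or environment.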
obtain_exact_trajectories(n_traj_per_worker, agent_update, env_update=None)
Sample an exact number of trajectories per worker.
Parameters:
- n_traj_per_worker (int) – Exact number of trajectories to gather for each worker.
- agent_update (object) – Value passed into the agent_update_fn before rollouts are performed. If a list is passed in, it must have exactly factory.n_workers elements and will be spread across the workers.
- env_update (object) – Value passed into the env_update_fn before rollouts are performed. If a list is passed in, it must have exactly factory.n_workers elements and will be spread across the workers.
Returns: Batch of gathered trajectories, always in worker order: first all trajectories from worker 0, then all trajectories from worker 1, and so on.
Return type: garage.TrajectoryBatch
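The worker-order guarantee means results must be reordered when workers finish at different times. A minimal sketch, assuming results arrive as hypothetical (worker_id, trajectories) pairs in completion order, with plain strings standing in for trajectories; concat_in_worker_order is a hypothetical name.

```python
def concat_in_worker_order(results, n_workers, n_traj_per_worker):
    """Concatenate per-worker trajectories in worker order (hypothetical sketch).

    `results` holds (worker_id, trajectories) pairs in whatever order the
    workers finished.
    """
    by_worker = {}
    for worker_id, trajs in results:
        by_worker.setdefault(worker_id, []).extend(trajs)
    ordered = []
    for worker_id in range(n_workers):
        trajs = by_worker.get(worker_id, [])
        # Enforce the "exact number per worker" contract.
        if len(trajs) != n_traj_per_worker:
            raise ValueError(
                f"worker {worker_id} returned {len(trajs)} trajectories")
        ordered.extend(trajs)
    return ordered


# Workers finish out of order; each eventually returns 2 trajectories.
arrivals = [
    (1, ["w1-t0"]),
    (0, ["w0-t0", "w0-t1"]),
    (1, ["w1-t1"]),
    (2, ["w2-t0", "w2-t1"]),
]
batch = concat_in_worker_order(arrivals, n_workers=3, n_traj_per_worker=2)
# batch == ["w0-t0", "w0-t1", "w1-t0", "w1-t1", "w2-t0", "w2-t1"]
```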
obtain_samples(itr, num_samples, agent_update, env_update=None)
Sample the policy for new trajectories.
Parameters:
- itr (int) – Iteration number.
- num_samples (int) – Number of steps the sampler should collect.
- agent_update (object) – Value passed into the agent_update_fn before rollouts are performed. If a list is passed in, it must have exactly factory.n_workers elements and will be spread across the workers.
- env_update (object) – Value passed into the env_update_fn before rollouts are performed. If a list is passed in, it must have exactly factory.n_workers elements and will be spread across the workers.
Returns: Batch of gathered trajectories.
Return type: garage.TrajectoryBatch
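Unlike obtain_exact_trajectories, obtain_samples is specified in steps (num_samples) rather than trajectories. A minimal sketch of one possible stopping rule, assuming whole trajectories are kept and workers are polled round-robin; a real sampler may instead hand new rollouts to whichever worker is idle. collect_steps and make_rollout are hypothetical names.

```python
import itertools


def collect_steps(rollout_fns, num_samples):
    """Gather whole trajectories until at least num_samples steps are collected.

    Hypothetical sketch: each callable in rollout_fns plays the role of one
    worker and returns one trajectory (a list of steps). Trajectories are
    never truncated, so the total can exceed num_samples by up to one
    trajectory's length.
    """
    trajectories = []
    completed = 0
    for rollout in itertools.cycle(rollout_fns):
        if completed >= num_samples:
            break
        traj = rollout()
        if not traj:
            # An empty trajectory would never make progress toward num_samples.
            raise ValueError("rollout returned an empty trajectory")
        trajectories.append(traj)
        completed += len(traj)
    return trajectories, completed


def make_rollout(length):
    """Hypothetical worker whose every trajectory has `length` steps."""
    return lambda: list(range(length))


trajs, steps = collect_steps([make_rollout(4), make_rollout(3)], num_samples=10)
# Collects 4 + 3 + 4 = 11 steps in 3 trajectories.
```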
class SamplerWorker(worker_id, env, agent_pkl, worker_factory)
Bases: object
Constructs a single sampler worker.
Parameters:
- worker_id (int) – The ID of this sampler worker.
- env (gym.Env) – The Gym environment.
- agent_pkl (bytes) – The pickled agent.
- worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
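SamplerWorker receives its agent as pickled bytes (agent_pkl) rather than as a live object. A minimal sketch of that hand-off using Python's standard pickle module; ConstantAgent and SketchWorker are hypothetical stand-ins, not garage classes.

```python
import pickle


class ConstantAgent:
    """Hypothetical stand-in for a policy; any picklable object works."""

    def __init__(self, action):
        self.action = action

    def get_action(self, observation):
        return self.action


class SketchWorker:
    """Hypothetical worker that rebuilds its agent from pickled bytes."""

    def __init__(self, worker_id, agent_pkl):
        self.worker_id = worker_id
        # Each worker deserializes its own independent copy of the agent.
        self.agent = pickle.loads(agent_pkl)


agent_pkl = pickle.dumps(ConstantAgent(action=1))  # serialized once by the driver
workers = [SketchWorker(i, agent_pkl) for i in range(2)]
```

Since each worker unpickles its own copy, later changes to one worker's agent do not affect the others; parameter updates have to be sent to every worker explicitly.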
rollout()
Compute one rollout of the agent in the environment.
Returns: Worker ID and batch of samples.
Return type: tuple[int, garage.TrajectoryBatch]
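rollout returns a (worker ID, batch) pair. The sketch below shows one plausible shape of a single rollout loop under stated assumptions: CountdownEnv, the rollout function, and max_path_length are hypothetical; the environment follows the classic Gym reset()/step() four-tuple convention; and a plain dict stands in for garage.TrajectoryBatch.

```python
class CountdownEnv:
    """Hypothetical toy env using the classic Gym reset()/step() 4-tuple API."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.remaining = horizon

    def reset(self):
        self.remaining = self.horizon
        return self.remaining

    def step(self, action):
        self.remaining -= 1
        reward = float(action)
        done = self.remaining <= 0
        return self.remaining, reward, done, {}


def rollout(worker_id, env, policy, max_path_length=100):
    """Run one episode and return (worker_id, trajectory).

    A plain dict stands in for garage.TrajectoryBatch in this sketch.
    """
    observations, actions, rewards = [], [], []
    obs = env.reset()
    for _ in range(max_path_length):
        action = policy(obs)
        next_obs, reward, done, _ = env.step(action)
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        obs = next_obs
        if done:
            break
    return worker_id, {"observations": observations,
                       "actions": actions,
                       "rewards": rewards}


wid, traj = rollout(7, CountdownEnv(horizon=3), policy=lambda obs: 1)
# wid == 7; traj["observations"] == [3, 2, 1]
```

Returning the worker ID alongside the batch lets the sampler attribute results to workers, which is what the worker-order guarantee of obtain_exact_trajectories relies on.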