garage.sampler.ray_sampler
This is an implementation of an on-policy batch sampler that uses a data-parallel design.
It includes a sampler that deploys sampler workers. Each sampler worker must implement a function for setting agent parameters and a rollout function.
- class RaySampler(agents, envs, *, worker_factory=None, max_episode_length=None, is_tf_worker=False, seed=get_seed(), n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)
Bases: garage.sampler.sampler.Sampler
Samples episodes in a data-parallel fashion using a Ray cluster.
The sampler needs to be created either from a worker factory or from parameters which can construct a worker factory. See the __init__ method of WorkerFactory for details of these parameters.
- Parameters
agents (list[Policy]) – Agents to distribute across workers.
envs (list[Environment]) – Environments to distribute across workers.
worker_factory (WorkerFactory) – Defines worker behavior. Either this parameter or the parameters that follow it are required to construct a sampler.
max_episode_length (int) – Used to construct a worker factory. The maximum length of episodes that will be sampled.
is_tf_worker (bool) – Whether the workers are for TFTrainer.
seed (int) – The seed to use to initialize random number generators.
n_workers (int) – The number of workers to use.
worker_class (type) – Class of the workers. Instances should implement the Worker interface.
worker_args (dict or None) – Additional arguments that should be passed to the worker.
- classmethod from_worker_factory(worker_factory, agents, envs)
Construct this sampler.
- Parameters
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
An instance of cls.
- Return type
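The rule for agents and envs described above (a single object, or a list of length exactly n_workers) can be illustrated with a minimal sketch. The helper name spread_across_workers is hypothetical and is not part of garage. Deep-copying a single object so that each worker gets its own instance is an assumption of this sketch, not a statement about how garage distributes objects.

```python
import copy


def spread_across_workers(value, n_workers):
    """Return one item per worker (hypothetical helper, not garage's API)."""
    if isinstance(value, list):
        # A list must match the number of workers exactly.
        if len(value) != n_workers:
            raise ValueError(
                f"Expected a list of length {n_workers}, got {len(value)}")
        return list(value)
    # Assumption of this sketch: a single object is copied once per worker.
    return [copy.deepcopy(value) for _ in range(n_workers)]
```

With this helper, passing a single environment yields n_workers independent copies, while passing a list of the wrong length fails early with a ValueError.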
- start_worker()
Initialize a new Ray worker.
- obtain_samples(itr, num_samples, agent_update, env_update=None)
Sample the policy for new episodes.
- Parameters
itr (int) – Iteration number.
num_samples (int) – Number of steps the sampler should collect.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes.
- Return type
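The role of num_samples can be sketched without Ray as a loop that keeps requesting episodes until enough steps have been gathered. ToyWorker and collect_until are hypothetical names used only for illustration. The sketch runs the workers sequentially and keeps whole episodes, so the total may overshoot num_samples; both choices are assumptions rather than a description of RaySampler's internals.

```python
class ToyWorker:
    """Stand-in worker whose rollout returns a fixed-length episode."""

    def __init__(self, worker_id, episode_length):
        self.worker_id = worker_id
        self.episode_length = episode_length

    def rollout(self):
        # An episode is modeled as a plain list with one entry per step.
        return self.worker_id, [0.0] * self.episode_length


def collect_until(workers, num_samples):
    """Gather whole episodes until at least num_samples steps are collected."""
    if not workers:
        raise ValueError("At least one worker is required")
    episodes = []
    collected = 0
    while collected < num_samples:
        for worker in workers:
            _, episode = worker.rollout()
            episodes.append(episode)
            collected += len(episode)
            if collected >= num_samples:
                break
    return episodes
```

For example, two workers that each produce 4-step episodes need three rollouts to satisfy a request for 10 steps.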
- obtain_exact_episodes(n_eps_per_worker, agent_update, env_update=None)
Sample an exact number of episodes per worker.
- Parameters
n_eps_per_worker (int) – Exact number of episodes to gather for each worker.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.
- Return type
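The worker-order guarantee can be illustrated with a short sketch. It assumes results arrive as (worker_id, episode) pairs in arbitrary completion order, as parallel workers might return them. order_by_worker is a hypothetical helper, not a garage function.

```python
from collections import defaultdict


def order_by_worker(results, n_workers):
    """Group (worker_id, episode) pairs and concatenate them by worker ID."""
    per_worker = defaultdict(list)
    for worker_id, episode in results:
        per_worker[worker_id].append(episode)
    ordered = []
    # First all episodes from worker 0, then worker 1, and so on.
    for worker_id in range(n_workers):
        ordered.extend(per_worker[worker_id])
    return ordered
```

Within each worker, episodes keep the order in which they arrived; only the interleaving between workers is removed.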
- shutdown_worker()
Shuts down the worker.
- class SamplerWorker(worker_id, env, agent_pkl, worker_factory)
Constructs a single sampler worker.
- Parameters
worker_id (int) – The ID of this worker.
env (Environment) – Environment to sample from.
agent_pkl (bytes) – Pickled Policy to sample with.
worker_factory (WorkerFactory) – Factory to construct this worker’s behavior.
- update(agent_update, env_update)
Update the agent and environment.
- rollout()
Sample one episode of the agent in the environment.
- Returns
Worker ID and batch of samples.
- Return type
- shutdown()
Shuts down the worker.
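The update/rollout contract of a sampler worker can be sketched with a toy stand-in. ToyEnv and ToySamplerWorker are hypothetical classes, and the "policy" here is just a pickled dict of parameters; none of this reflects garage's actual Policy or Environment interfaces. The sketch omits worker_factory and only shows the shape of the interaction: unpickle the agent on construction, apply updates, then return the worker ID together with the sampled episode.

```python
import pickle


class ToyEnv:
    """Toy environment that terminates after a fixed number of steps."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        self.t += 1
        # Returns (observation, reward, done).
        return self.t, float(action), self.t >= self.horizon


class ToySamplerWorker:
    """Stand-in for a sampler worker (not garage's SamplerWorker)."""

    def __init__(self, worker_id, env, agent_pkl):
        self.worker_id = worker_id
        self.env = env
        # The agent arrives as bytes and is rebuilt inside the worker.
        self.agent_params = pickle.loads(agent_pkl)

    def update(self, agent_update, env_update):
        # Assumption: None means "leave unchanged".
        if agent_update is not None:
            self.agent_params = agent_update
        if env_update is not None:
            self.env = env_update

    def rollout(self):
        self.env.reset()
        rewards = []
        done = False
        while not done:
            action = self.agent_params["action"]
            _, reward, done = self.env.step(action)
            rewards.append(reward)
        return self.worker_id, rewards
```

Returning the worker ID alongside the samples lets the caller attribute each episode to its worker even when results come back out of order.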