Source code for garage.sampler.ray_sampler

"""This is an implementation of an on policy batch sampler.

Uses a data parallel design.
Included is a sampler that deploys sampler workers.
The sampler workers must implement some type of set agent parameters
function, and a rollout function.

"""
from collections import defaultdict
import itertools

import click
import cloudpickle
import ray

from garage import TrajectoryBatch
from garage.sampler.sampler import Sampler


class RaySampler(Sampler):
    """Collects Policy Rollouts in a data parallel fashion.

    Args:
        worker_factory(garage.sampler.WorkerFactory): Used for worker
            behavior.
        agents(list[garage.Policy]): Agents to distribute across workers.
        envs(list[gym.Env]): Environments to distribute across workers.

    """

    def __init__(self, worker_factory, agents, envs):
        # pylint: disable=super-init-not-called
        if not ray.is_initialized():
            ray.init(log_to_driver=False)
        self._sampler_worker = ray.remote(SamplerWorker)
        self._worker_factory = worker_factory
        self._agents = agents
        self._envs = self._worker_factory.prepare_worker_messages(envs)
        self._all_workers = defaultdict(None)
        self._workers_started = False
        self.start_worker()

    @classmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        """Construct this sampler.

        Args:
            worker_factory(WorkerFactory): Pickleable factory for creating
                workers. Should be transmitted to other processes / nodes
                where work needs to be done, then workers should be
                constructed there.
            agents(Agent or List[Agent]): Agent(s) to use to perform rollouts.
                If a list is passed in, it must have length exactly
                `worker_factory.n_workers`, and will be spread across the
                workers.
            envs(gym.Env or List[gym.Env]): Environment rollouts are performed
                in. If a list is passed in, it must have length exactly
                `worker_factory.n_workers`, and will be spread across the
                workers.

        Returns:
            Sampler: An instance of `cls`.

        """
        return cls(worker_factory, agents, envs)

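    # Usage sketch (not part of the original module). The `WorkerFactory`
    # arguments shown (`seed`, `max_path_length`, `n_workers`) and the
    # `policy` / `env` objects are illustrative assumptions, not values
    # defined in this file:
    #
    #     from garage.sampler import WorkerFactory
    #     factory = WorkerFactory(seed=100, max_path_length=500, n_workers=4)
    #     sampler = RaySampler.from_worker_factory(factory,
    #                                              agents=policy,
    #                                              envs=env)
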
    def start_worker(self):
        """Initialize a new ray worker."""
        if self._workers_started:
            return
        self._workers_started = True
        # We need to pickle the agent so that we can e.g. set up the
        # TF.Session in the worker *before* unpickling it.
        agent_pkls = self._worker_factory.prepare_worker_messages(
            self._agents, cloudpickle.dumps)
        for worker_id in range(self._worker_factory.n_workers):
            self._all_workers[worker_id] = self._sampler_worker.remote(
                worker_id, self._envs[worker_id], agent_pkls[worker_id],
                self._worker_factory)

    def _update_workers(self, agent_update, env_update):
        """Update all of the workers.

        Args:
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            list[ray._raylet.ObjectID]: Remote values of worker ids.

        """
        updating_workers = []
        param_ids = self._worker_factory.prepare_worker_messages(
            agent_update, ray.put)
        env_ids = self._worker_factory.prepare_worker_messages(
            env_update, ray.put)
        for worker_id in range(self._worker_factory.n_workers):
            worker = self._all_workers[worker_id]
            updating_workers.append(
                worker.update.remote(param_ids[worker_id],
                                     env_ids[worker_id]))
        return updating_workers

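    # Sketch of the broadcast behavior this method relies on (an assumption
    # about `WorkerFactory.prepare_worker_messages`, inferred from how it is
    # used in this module): a single update is replicated once per worker,
    # while a list must already have length `n_workers`. Each entry is passed
    # through the given preprocessor, here `ray.put`, so each worker receives
    # an object-store ID rather than its own copy of the update:
    #
    #     param_ids = factory.prepare_worker_messages(agent_update, ray.put)
    #     assert len(param_ids) == factory.n_workers
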
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Sample the policy for new trajectories.

        Args:
            itr(int): Iteration number.
            num_samples(int): Number of steps the sampler should collect.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories.

        """
        active_workers = []
        completed_samples = 0
        batches = []
        # Update the policy params of each worker before sampling
        # for the current iteration.
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)
        with click.progressbar(length=num_samples, label='Sampling') as pbar:
            while completed_samples < num_samples:
                # If any workers are still being updated, find the ones that
                # have finished updating and mark them as idle so they can
                # start collecting trajectories.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # If there are idle workers, use them to collect trajectories,
                # marking the newly busy workers as active.
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # Check which workers have finished collecting a sample.
                # Process the trajectories of finished workers; keep polling
                # the rest.
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, trajectory_batch = ray.get(result)
                    idle_worker_ids.append(ready_worker_id)
                    num_returned_samples = trajectory_batch.lengths.sum()
                    completed_samples += num_returned_samples
                    batches.append(trajectory_batch)
                    pbar.update(num_returned_samples)
        return TrajectoryBatch.concatenate(*batches)

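    # Usage sketch (not part of the original module). `sampler` and `policy`
    # are carried over from the construction example above, and `itr` would
    # come from the training loop; whether the worker accepts a policy object
    # or fresh parameter values depends on the inner worker's update function:
    #
    #     batch = sampler.obtain_samples(itr=itr, num_samples=4000,
    #                                    agent_update=policy)
    #     # Workers return whole rollouts, so the batch can overshoot:
    #     assert batch.lengths.sum() >= 4000
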
    def obtain_exact_trajectories(self,
                                  n_traj_per_worker,
                                  agent_update,
                                  env_update=None):
        """Sample an exact number of trajectories per worker.

        Args:
            n_traj_per_worker (int): Exact number of trajectories to gather
                for each worker.
            agent_update(object): Value which will be passed into the
                `agent_update_fn` before doing rollouts. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update(object): Value which will be passed into the
                `env_update_fn` before doing rollouts. If a list is passed in,
                it must have length exactly `factory.n_workers`, and will be
                spread across the workers.

        Returns:
            TrajectoryBatch: Batch of gathered trajectories. Always in worker
                order. In other words, first all trajectories from worker 0,
                then all trajectories from worker 1, etc.

        """
        active_workers = []
        trajectories = defaultdict(list)
        # Update the policy params of each worker before sampling
        # for the current iteration.
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        with click.progressbar(length=self._worker_factory.n_workers,
                               label='Sampling') as pbar:
            while any(
                    len(trajectories[i]) < n_traj_per_worker
                    for i in range(self._worker_factory.n_workers)):
                # If any workers are still being updated, find the ones that
                # have finished updating and mark them as idle so they can
                # start collecting trajectories.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # If there are idle workers, use them to collect trajectories,
                # marking the newly busy workers as active.
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # Check which workers have finished collecting a sample.
                # Process the trajectories of finished workers; keep polling
                # the rest.
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, trajectory_batch = ray.get(result)
                    trajectories[ready_worker_id].append(trajectory_batch)
                    if len(trajectories[ready_worker_id]) < n_traj_per_worker:
                        idle_worker_ids.append(ready_worker_id)
                    pbar.update(1)

        ordered_trajectories = list(
            itertools.chain(*[
                trajectories[i]
                for i in range(self._worker_factory.n_workers)
            ]))
        return TrajectoryBatch.concatenate(*ordered_trajectories)

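    # Usage sketch (not part of the original module), under the same
    # assumptions as the examples above. With n_workers=4 this gathers
    # exactly 4 * 2 = 8 trajectories, ordered by worker id:
    #
    #     batch = sampler.obtain_exact_trajectories(n_traj_per_worker=2,
    #                                               agent_update=policy)
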
    def shutdown_worker(self):
        """Shuts down the worker."""
        for worker in self._all_workers.values():
            worker.shutdown.remote()
        ray.shutdown()

class SamplerWorker:
    """Constructs a single sampler worker.

    Args:
        worker_id(int): The id of the sampler_worker.
        env(gym.Env): The gym env.
        agent_pkl(bytes): The pickled agent.
        worker_factory(WorkerFactory): Factory to construct this worker's
            behavior.

    """

    def __init__(self, worker_id, env, agent_pkl, worker_factory):
        # Must be called before cloudpickle.loads below.
        self.inner_worker = worker_factory(worker_id)
        self.worker_id = worker_id
        self.inner_worker.update_env(env)
        self.inner_worker.update_agent(cloudpickle.loads(agent_pkl))

    def update(self, agent_update, env_update):
        """Update the agent and environment.

        Args:
            agent_update(object): Agent update.
            env_update(object): Environment update.

        Returns:
            int: The worker id.

        """
        self.inner_worker.update_agent(agent_update)
        self.inner_worker.update_env(env_update)
        return self.worker_id

    def rollout(self):
        """Compute one rollout of the agent in the environment.

        Returns:
            tuple[int, garage.TrajectoryBatch]: Worker ID and batch of
                samples.

        """
        return (self.worker_id, self.inner_worker.rollout())

    def shutdown(self):
        """Shuts down the worker."""
        self.inner_worker.shutdown()
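

# Illustrative sketch (not part of the original module): the minimal
# inner-worker interface that SamplerWorker relies on, inferred from the
# calls made above. `MinimalInnerWorker` and its trivial bodies are
# hypothetical; real inner workers are constructed by the WorkerFactory.
class MinimalInnerWorker:
    """Duck-typed stand-in showing the methods SamplerWorker calls."""

    def update_agent(self, agent_update):
        """Receive a new agent (or new agent parameters) before rollouts."""

    def update_env(self, env_update):
        """Receive a new environment (or an environment update)."""

    def rollout(self):
        """Run one episode and return it as a garage.TrajectoryBatch."""
        raise NotImplementedError

    def shutdown(self):
        """Release any held resources (e.g. close the environment)."""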