garage.sampler¶
Samplers which run agents in environments.
- class InProgressEpisode(env, initial_observation=None, episode_info=None)¶
An in-progress episode.
Compared to EpisodeBatch, this datatype does less checking, only contains one episode, and uses lists instead of numpy arrays to make stepping faster.
- Parameters
env (Environment) – The environment the trajectory is being collected in.
initial_observation (np.ndarray) – The first observation. If None, the environment will be reset to generate this observation.
episode_info (dict[str, np.ndarray]) – Info for this episode.
- Raises
ValueError – If one of initial_observation and episode_info is passed in without the other. Either both or neither should be passed in.
- property last_obs¶
The last observation in the episode.
- Type
np.ndarray
- step(action, agent_info)¶
Step the episode using an action from an agent.
- to_batch()¶
Convert this in-progress episode into an EpisodeBatch.
- Returns
This episode as a batch.
- Return type
EpisodeBatch
- Raises
AssertionError – If this episode contains no time steps.
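Putting the pieces above together, a minimal sketch of stepping an InProgressEpisode by hand. Here env (a garage Environment) and policy (a garage Policy) are assumed to exist, and a real loop would also stop once the environment reports termination:

```python
# Minimal sketch of driving an InProgressEpisode manually.
# `env` (garage Environment) and `policy` (garage Policy) are assumed
# to be constructed elsewhere.
from garage.sampler import InProgressEpisode

episode = InProgressEpisode(env)  # env is reset to produce the first observation
for _ in range(100):
    action, agent_info = policy.get_action(episode.last_obs)
    episode.step(action, agent_info)
batch = episode.to_batch()  # raises AssertionError if no steps were taken
```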
- class EnvUpdate¶
A callable that “updates” an environment.
Implementors of this interface can be called on environments to update them. The passed-in environment should then be ignored, and the returned one used instead.
Since no new environment needs to be passed in, this type can also be used to construct new environments.
- class ExistingEnvUpdate(env)¶
Bases:
EnvUpdate
EnvUpdate that carries an already constructed environment.
- Parameters
env (Environment) – The environment.
- class NewEnvUpdate(env_constructor)¶
Bases:
EnvUpdate
EnvUpdate that creates a new environment every update.
- Parameters
env_constructor (Callable[Environment]) – Callable that constructs an environment.
- class SetTaskUpdate(env_type, task, wrapper_constructor)¶
Bases:
EnvUpdate
EnvUpdate that calls set_task with the provided task.
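As an illustration of how these updates are consumed, a minimal sketch; GymEnv and the CartPole task are illustrative choices, not requirements:

```python
# Sketch: an EnvUpdate is a callable that returns the environment to use
# next; any environment passed in should be discarded afterwards.
from garage.envs import GymEnv
from garage.sampler import ExistingEnvUpdate, NewEnvUpdate

update = NewEnvUpdate(lambda: GymEnv('CartPole-v1'))  # fresh env per update
# update = ExistingEnvUpdate(GymEnv('CartPole-v1'))   # carry a built env
env = update()  # no environment needs to be passed in
```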
- class FragmentWorker(*, seed, max_episode_length, worker_number, n_envs=DEFAULT_N_ENVS, timesteps_per_call=1)¶
Bases:
garage.sampler.default_worker.DefaultWorker
Vectorized Worker that collects partial episodes.
Useful for off-policy RL.
- Parameters
seed (int) – The seed to use to initialize random number generators.
max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity.
worker_number (int) – The number of the worker this update is occurring in. This argument is used to set a different seed for each worker.
n_envs (int) – Number of environment copies to use.
timesteps_per_call (int) – Maximum number of timesteps to gather per env per call to the worker. Defaults to 1 (i.e. gather 1 timestep per env each call, or n_envs timesteps in total each call). This also sets the length of fragments before they're transmitted out of the worker.
- DEFAULT_N_ENVS = 8¶
- update_env(env_update)¶
Update the environments.
If passed a list (inside this list passed to the Sampler itself), distributes the environments across the “vectorization” dimension.
- Parameters
env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.
- Raises
TypeError – If env_update is not one of the documented types.
ValueError – If the wrong number of updates is passed.
- start_episode()¶
Resets all agents if the environment was updated.
- step_episode()¶
Take a single time-step in the current episode.
- Returns
True iff at least one of the episodes was completed.
- Return type
bool
- collect_episode()¶
Gather fragments from all in-progress episodes.
- Returns
A batch of the episode fragments.
- Return type
EpisodeBatch
- rollout()¶
Sample a single episode of the agent in the environment.
- Returns
The collected episode.
- Return type
EpisodeBatch
- shutdown()¶
Close the worker’s environments.
- worker_init()¶
Initialize a worker.
- update_agent(agent_update)¶
Update an agent, assuming it implements Policy.
- Parameters
agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray, these should be parameters to agent, which should have been generated by calling Policy.get_param_values. Alternatively, a policy itself. Note that other implementations of Worker may take different types for this parameter.
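FragmentWorker is normally selected through a sampler's worker_class argument rather than constructed directly. A sketch using LocalSampler (another sampler in garage.sampler), with policy and env assumed to exist:

```python
# Sketch: plugging FragmentWorker into a sampler for off-policy training.
# `policy` and `env` are assumed to be constructed elsewhere.
from garage.sampler import FragmentWorker, LocalSampler

sampler = LocalSampler(
    agents=policy,
    envs=env,
    max_episode_length=env.spec.max_episode_length,
    worker_class=FragmentWorker,
    worker_args=dict(n_envs=4, timesteps_per_call=1),
)
```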
- class MultiprocessingSampler(agents, envs, *, worker_factory=None, max_episode_length=None, is_tf_worker=False, seed=get_seed(), n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)¶
Bases:
garage.sampler.sampler.Sampler
Sampler that uses multiprocessing to distribute workers.
The sampler needs to be created either from a worker factory or from parameters which can construct a worker factory. See the __init__ method of WorkerFactory for details on these parameters.
- Parameters
agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there. Either this param or params after this are required to construct a sampler.
max_episode_length (int) – Param used to construct a worker factory. The maximum length of episodes which will be sampled.
is_tf_worker (bool) – Whether the workers are for TFTrainer.
seed (int) – The seed to use to initialize random number generators.
n_workers (int) – The number of workers to use.
worker_class (type) – Class of the workers. Instances should implement the Worker interface.
worker_args (dict or None) – Additional arguments that should be passed to the worker.
- classmethod from_worker_factory(worker_factory, agents, envs)¶
Construct this sampler.
- Parameters
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
An instance of cls.
- Return type
Sampler
- obtain_samples(itr, num_samples, agent_update, env_update=None)¶
Collect at least a given number of transitions (timesteps).
- Parameters
itr (int) – The current iteration number. Using this argument is deprecated.
num_samples (int) – Minimum number of transitions / timesteps to sample.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
The batch of collected episodes.
- Return type
EpisodeBatch
- Raises
AssertionError – On internal errors.
- obtain_exact_episodes(n_eps_per_worker, agent_update, env_update=None)¶
Sample an exact number of episodes per worker.
- Parameters
n_eps_per_worker (int) – Exact number of episodes to gather for each worker.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.
- Return type
EpisodeBatch
- Raises
AssertionError – On internal errors.
- shutdown_worker()¶
Shutdown the workers.
- start_worker()¶
Initialize the sampler.
i.e., launch parallel workers if necessary.
This method is deprecated; please launch workers in construct instead.
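A minimal usage sketch, with policy and env assumed to exist:

```python
# Sketch: sampling at least 1000 timesteps across 4 worker processes.
# `policy` and `env` are assumed to be constructed elsewhere.
from garage.sampler import MultiprocessingSampler

sampler = MultiprocessingSampler(
    agents=policy,
    envs=env,
    max_episode_length=env.spec.max_episode_length,
    n_workers=4,
)
episodes = sampler.obtain_samples(
    itr=0,  # deprecated argument
    num_samples=1000,
    agent_update=policy.get_param_values(),
)
sampler.shutdown_worker()
```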
- class RaySampler(agents, envs, *, worker_factory=None, max_episode_length=None, is_tf_worker=False, seed=get_seed(), n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)¶
Bases:
garage.sampler.sampler.Sampler
Samples episodes in a data-parallel fashion using a Ray cluster.
The sampler needs to be created either from a worker factory or from parameters which can construct a worker factory. See the __init__ method of WorkerFactory for details on these parameters.
- Parameters
agents (list[Policy]) – Agents to distribute across workers.
envs (list[Environment]) – Environments to distribute across workers.
worker_factory (WorkerFactory) – Used for worker behavior. Either this param or params after this are required to construct a sampler.
max_episode_length (int) – Param used to construct a worker factory. The maximum length of episodes which will be sampled.
is_tf_worker (bool) – Whether the workers are for TFTrainer.
seed (int) – The seed to use to initialize random number generators.
n_workers (int) – The number of workers to use.
worker_class (type) – Class of the workers. Instances should implement the Worker interface.
worker_args (dict or None) – Additional arguments that should be passed to the worker.
- classmethod from_worker_factory(worker_factory, agents, envs)¶
Construct this sampler.
- Parameters
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
agents (Policy or List[Policy]) – Agent(s) to use to sample episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
An instance of cls.
- Return type
Sampler
- start_worker()¶
Initialize a new ray worker.
- obtain_samples(itr, num_samples, agent_update, env_update=None)¶
Sample the policy for new episodes.
- Parameters
itr (int) – Iteration number.
num_samples (int) – Number of steps the sampler should collect.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes.
- Return type
EpisodeBatch
- obtain_exact_episodes(n_eps_per_worker, agent_update, env_update=None)¶
Sample an exact number of episodes per worker.
- Parameters
n_eps_per_worker (int) – Exact number of episodes to gather for each worker.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
Batch of gathered episodes. Always in worker order. In other words, first all episodes from worker 0, then all episodes from worker 1, etc.
- Return type
EpisodeBatch
- shutdown_worker()¶
Shuts down the worker.
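A minimal usage sketch; RaySampler requires a running Ray instance, and policy and env are assumed to exist:

```python
# Sketch: data-parallel sampling on a Ray cluster.
# `policy` and `env` are assumed to be constructed elsewhere.
import ray

from garage.sampler import RaySampler

ray.init()  # start (or connect to) a Ray cluster
sampler = RaySampler(
    agents=policy,
    envs=env,
    max_episode_length=env.spec.max_episode_length,
)
episodes = sampler.obtain_exact_episodes(
    n_eps_per_worker=2,
    agent_update=policy.get_param_values(),
)
sampler.shutdown_worker()
```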
- class Sampler(algo, env)¶
Bases:
abc.ABC
Abstract base class of all samplers.
Implementations of this class should override construct, obtain_samples, and shutdown_worker. construct takes a WorkerFactory, which implements most of the RL-specific functionality a Sampler needs. Specifically, it specifies how to construct `Worker`s, which know how to collect episodes and update both agents and environments.
Currently, __init__ is also part of the interface, but calling it is deprecated. start_worker is also deprecated, and does not need to be implemented.
- classmethod from_worker_factory(worker_factory, agents, envs)¶
Construct this sampler.
- Parameters
worker_factory (WorkerFactory) – Pickleable factory for creating workers. Should be transmitted to other processes / nodes where work needs to be done, then workers should be constructed there.
agents (Policy or List[Policy]) – Agent(s) to use to collect episodes. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
envs (Environment or List[Environment]) – Environment from which episodes are sampled. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- Returns
An instance of cls.
- Return type
- start_worker()¶
Initialize the sampler.
i.e., launch parallel workers if necessary.
This method is deprecated; please launch workers in construct instead.
- abstract obtain_samples(itr, num_samples, agent_update, env_update=None)¶
Collect at least a given number of transitions (TimeSteps).
- Parameters
itr (int) – The current iteration number. Using this argument is deprecated.
num_samples (int) – Minimum number of TimeSteps to sample.
agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- Returns
The batch of collected episodes.
- Return type
EpisodeBatch
- abstract shutdown_worker()¶
Terminate workers if necessary.
Because Python object destruction can be somewhat unpredictable, this method isn’t deprecated.
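A skeleton of what a concrete subclass must provide; the bodies below are placeholders, not garage's implementations:

```python
# Skeleton of the abstract Sampler surface; bodies are placeholders.
from garage.sampler import Sampler

class MySampler(Sampler):
    """Toy subclass illustrating the required overrides."""

    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        # Must gather at least `num_samples` timesteps and return an
        # EpisodeBatch.
        raise NotImplementedError

    def shutdown_worker(self):
        # Terminate any workers this sampler launched.
        pass
```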
- class VecWorker(*, seed, max_episode_length, worker_number, n_envs=DEFAULT_N_ENVS)¶
Bases:
garage.sampler.default_worker.DefaultWorker
Worker with a single policy and multiple environments.
Alternates between taking a single step in all environments and asking the policy for an action for every environment. This allows computing a batch of actions, which is generally much more efficient than computing a single action when using neural networks.
- Parameters
seed (int) – The seed to use to initialize random number generators.
max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity.
worker_number (int) – The number of the worker this update is occurring in. This argument is used to set a different seed for each worker.
n_envs (int) – Number of environment copies to use.
- DEFAULT_N_ENVS = 8¶
- update_agent(agent_update)¶
Update an agent, assuming it implements Policy.
- Parameters
agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray, these should be parameters to agent, which should have been generated by calling Policy.get_param_values. Alternatively, a policy itself. Note that other implementations of Worker may take different types for this parameter.
- update_env(env_update)¶
Update the environments.
If passed a list (inside this list passed to the Sampler itself), distributes the environments across the “vectorization” dimension.
- Parameters
env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.
- Raises
TypeError – If env_update is not one of the documented types.
ValueError – If the wrong number of updates is passed.
- start_episode()¶
Begin a new episode.
- step_episode()¶
Take a single time-step in the current episode.
- Returns
True iff at least one of the episodes was completed.
- Return type
bool
- collect_episode()¶
Collect all completed episodes.
- Returns
A batch of the episodes completed since the last call to collect_episode().
- Return type
EpisodeBatch
- shutdown()¶
Close the worker’s environments.
- worker_init()¶
Initialize a worker.
- rollout()¶
Sample a single episode of the agent in the environment.
- Returns
The collected episode.
- Return type
EpisodeBatch
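A sketch constructing VecWorker through a WorkerFactory and from_worker_factory; policy and env are assumed to exist, and the seed and worker counts are arbitrary:

```python
# Sketch: routing VecWorker through a WorkerFactory.
# `policy` and `env` are assumed to be constructed elsewhere.
from garage.sampler import LocalSampler, VecWorker, WorkerFactory

factory = WorkerFactory(
    max_episode_length=env.spec.max_episode_length,
    seed=0,
    n_workers=2,
    worker_class=VecWorker,
    worker_args=dict(n_envs=8),
)
sampler = LocalSampler.from_worker_factory(factory, agents=policy, envs=env)
```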
- class Worker(*, seed, max_episode_length, worker_number)¶
Bases:
abc.ABC
Worker class used in all Samplers.
- update_agent(agent_update)¶
Update the worker’s agent, using agent_update.
- Parameters
agent_update (object) – An agent update. The exact type of this argument depends on the Worker implementation.
- update_env(env_update)¶
Update the worker’s env, using env_update.
- Parameters
env_update (object) – An environment update. The exact type of this argument depends on the Worker implementation.
- rollout()¶
Sample a single episode of the agent in the environment.
- Returns
Batch of sampled episodes. May be truncated if max_episode_length is set.
- Return type
EpisodeBatch
- start_episode()¶
Begin a new episode.
- step_episode()¶
Take a single time-step in the current episode.
- Returns
True iff the episode is done, either due to the environment indicating termination or due to reaching max_episode_length.
- collect_episode()¶
Collect the current episode, clearing the internal buffer.
- Returns
Batch of sampled episodes. May be truncated if the episodes haven't completed yet.
- Return type
EpisodeBatch
- shutdown()¶
Shutdown the worker.
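The methods above compose into a standard sampling loop. A sketch of the contract rollout() fulfils, stated in terms of the other Worker methods (not garage's actual implementation):

```python
# Sketch: how a Worker's methods compose into rollout()-like behavior.
def rollout_like(worker):
    worker.start_episode()
    # step_episode() returns True once the episode is done, whether by
    # environment termination or by reaching max_episode_length.
    while not worker.step_episode():
        pass
    return worker.collect_episode()  # a batch of sampled episodes
```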
- class WorkerFactory(*, max_episode_length, is_tf_worker=False, seed=get_seed(), n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)¶
Constructs workers for Samplers.
The intent is that this object should be sufficient to avoid subclassing the sampler. Instead of subclassing the sampler for e.g. a specific backend, implement a specialized WorkerFactory (or specify appropriate functions to this one). Note that this object must be picklable, since it may be passed to workers. However, its fields individually need not be.
All arguments to this type must be passed by keyword.
- Parameters
max_episode_length (int) – The maximum length of episodes which will be sampled.
is_tf_worker (bool) – Whether the workers are for TFTrainer.
seed (int) – The seed to use to initialize random number generators.
n_workers (int) – The number of workers to use.
worker_class (type) – Class of the workers. Instances should implement the Worker interface.
worker_args (dict or None) – Additional arguments that should be passed to the worker.
- prepare_worker_messages(objs, preprocess=identity_function)¶
Take an argument and canonicalize it into a list for all workers.
This helper function is used to handle arguments in the sampler API which may (optionally) be lists. Specifically, these are agent, env, agent_update, and env_update. Checks that the number of parameters is correct.
- Parameters
objs (object or list) – Either a single object or a list of length n_workers.
preprocess (function) – Function to call on each object before adding it to the list.
- Raises
ValueError – If a list of a length other than n_workers is passed.
- Returns
A list of length self.n_workers.
- Return type
List[object]
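For illustration, a sketch of the canonicalization behavior, assuming a factory built with n_workers=2 (the seed and max_episode_length values are arbitrary):

```python
# Sketch of prepare_worker_messages' canonicalization behavior.
from garage.sampler import WorkerFactory

factory = WorkerFactory(max_episode_length=100, seed=0, n_workers=2)
factory.prepare_worker_messages('cpu')             # -> ['cpu', 'cpu']
factory.prepare_worker_messages(['a', 'b'])        # -> ['a', 'b']
factory.prepare_worker_messages(['a', 'b', 'c'])   # raises ValueError
```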