garage.sampler.sampler module
Base sampler class.
class Sampler(algo, env)
Bases: abc.ABC
Abstract base class of all samplers.
Implementations of this class should override construct, obtain_samples, and shutdown_worker. construct takes a WorkerFactory, which implements most of the RL-specific functionality a Sampler needs. Specifically, it specifies how to construct `Worker`s, which know how to perform rollouts and update both agents and environments.
Currently, __init__ is also part of the interface, but calling it is deprecated. start_worker is also deprecated, and does not need to be implemented.
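The interface described above can be sketched without garage installed. The snippet below is a minimal illustration, not garage's implementation: `SamplerSketch` and `NullSampler` are hypothetical names, and the sketch uses `from_worker_factory`, the classmethod documented in this section, as the constructor.

```python
import abc


class SamplerSketch(abc.ABC):
    """Hypothetical stand-in mirroring the interface described above."""

    @classmethod
    @abc.abstractmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        """Construct this sampler from a worker factory."""

    @abc.abstractmethod
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least num_samples transitions."""

    @abc.abstractmethod
    def shutdown_worker(self):
        """Release any resources held by the workers."""


class NullSampler(SamplerSketch):
    """Toy subclass (an assumption for illustration) that runs no rollouts."""

    def __init__(self, worker_factory, agents, envs):
        self.worker_factory = worker_factory
        self.agents = agents
        self.envs = envs
        self.running = True

    @classmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        # Preferred entry point; calling __init__ directly is deprecated
        # in the interface this sketch imitates.
        return cls(worker_factory, agents, envs)

    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        # A real sampler would perform rollouts here; this toy returns none.
        return []

    def shutdown_worker(self):
        self.running = False
```

Because the three methods are abstract, a subclass that omits any of them cannot be instantiated, which is one way to make the "should override" requirement checkable.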
classmethod from_worker_factory(worker_factory, agents, envs)
Construct this sampler.
Parameters:
- worker_factory (WorkerFactory) – Pickleable factory for creating workers. It should be transmitted to the other processes / nodes where work needs to be done, and the workers should then be constructed there.
- agents (Agent or List[Agent]) – Agent(s) to use to perform rollouts. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
- envs (gym.Env or List[gym.Env]) – Environment(s) in which rollouts are performed. If a list is passed in, it must have length exactly worker_factory.n_workers, and will be spread across the workers.
Returns: An instance of cls.
Return type:
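The list-length rule for agents and envs can be illustrated with a small helper. This is a sketch under stated assumptions: `spread_across_workers` is a hypothetical function, not part of garage, and it shares a single non-list object among all workers, whereas a real implementation may copy it per worker.

```python
def spread_across_workers(value, n_workers):
    """Return one entry per worker (hypothetical helper, not garage's API)."""
    if isinstance(value, list):
        # A list must supply exactly one entry per worker.
        if len(value) != n_workers:
            raise ValueError(
                f"expected a list of length {n_workers}, got {len(value)}")
        return list(value)
    # Assumption: a single object is reused by every worker in this sketch.
    return [value] * n_workers
```

For example, `spread_across_workers("agent", 3)` yields three references to the same agent, while a two-element list with `n_workers=3` raises `ValueError`.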
obtain_samples(itr, num_samples, agent_update, env_update=None)
Collect at least a given number of transitions (timesteps).
Parameters:
- itr (int) – The current iteration number. Using this argument is deprecated.
- num_samples (int) – Minimum number of transitions / timesteps to sample.
- agent_update (object) – Value which will be passed into the agent_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
- env_update (object) – Value which will be passed into the env_update_fn before doing rollouts. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
Returns: The batch of collected trajectories.
Return type:
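The "at least num_samples" contract means a sampler keeps performing whole rollouts until the running total of transitions reaches the minimum, so the result may overshoot. The loop below is a rough sketch of that idea only: `collect_at_least` and `rollout_fn` are hypothetical names, trajectories are modeled as plain lists, and garage's actual batching is not reproduced.

```python
def collect_at_least(rollout_fn, num_samples):
    """Gather whole trajectories until num_samples transitions are reached.

    rollout_fn is a hypothetical zero-argument callable that returns one
    trajectory as a list of transitions.
    """
    trajectories = []
    collected = 0
    while collected < num_samples:
        trajectory = rollout_fn()
        if len(trajectory) == 0:
            # Guard against looping forever on empty rollouts.
            raise ValueError("rollout_fn returned an empty trajectory")
        trajectories.append(trajectory)
        collected += len(trajectory)
    return trajectories


def fixed_length_rollout():
    """Toy rollout that always yields four placeholder transitions."""
    return [0, 1, 2, 3]


batch = collect_at_least(fixed_length_rollout, 10)
total = sum(len(t) for t in batch)  # 12: three rollouts of four steps each
```

With four-step rollouts and `num_samples=10`, the loop stops after the third rollout, returning 12 transitions rather than exactly 10.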
-
classmethod