garage.sampler.worker_factory

Worker factory used by Samplers to construct Workers.

identity_function(value)

Return the value unchanged.

This function exists because, unlike a lambda, a module-level function can be pickled.

Parameters

value (object) – A value.

Returns

The value.

Return type

object
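A minimal sketch of why a named, module-level function is needed here: pickle serializes functions by reference (module plus qualified name), and a lambda's qualified name, "<lambda>", cannot be looked up again. The helper is_picklable is hypothetical and exists only for this illustration; the expected outputs assume the snippet is run as a script.

```python
import pickle


def identity_function(value):
    """Return value unchanged."""
    return value


def is_picklable(obj):
    """Hypothetical helper: report whether pickle.dumps accepts obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A module-level function is pickled by reference and round-trips to itself.
restored = pickle.loads(pickle.dumps(identity_function))
print(restored is identity_function)  # True
print(restored(42))                   # 42

# A lambda has no importable name, so pickling it fails.
print(is_picklable(lambda value: value))  # False
```

This is why a default such as preprocess=identity_function is used instead of preprocess=lambda x: x in objects that may be sent to worker processes.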

class WorkerFactory(*, max_episode_length, is_tf_worker=False, seed=get_seed(), n_workers=psutil.cpu_count(logical=False), worker_class=DefaultWorker, worker_args=None)

Constructs workers for Samplers.

The intent is that this object should make subclassing the sampler unnecessary. Instead of subclassing the sampler for, e.g., a specific backend, implement a specialized WorkerFactory (or pass appropriate functions to this one). Note that this object must be picklable, since it may be passed to workers. However, its fields need not be individually picklable.

All arguments to this type must be passed by keyword.

Parameters
  • max_episode_length (int) – The maximum length of episodes that will be sampled.

  • is_tf_worker (bool) – Whether the workers are for use with TFTrainer.

  • seed (int) – The seed to use to initialize random number generators.

  • n_workers (int) – The number of workers to use.

  • worker_class (type) – Class of the workers. Instances should implement the Worker interface.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.
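The sketch below illustrates two properties stated above: all constructor arguments are keyword-only (the bare * in the signature), and the factory is picklable so it can be sent to workers. It is not garage's implementation. WorkerFactorySketch, its make_worker method, and the DefaultWorker stand-in are hypothetical. It uses os.cpu_count(), which counts logical CPUs, as a dependency-free substitute for psutil.cpu_count(logical=False), and a fixed seed=0 in place of get_seed().

```python
import os
import pickle


class DefaultWorker:
    """Hypothetical stand-in for a class implementing the Worker interface."""

    def __init__(self, *, seed, max_episode_length, worker_number, **kwargs):
        self.seed = seed
        self.max_episode_length = max_episode_length
        self.worker_number = worker_number
        self.extra_args = kwargs


class WorkerFactorySketch:
    """Keyword-only, picklable factory holding the settings workers need."""

    # The bare * makes every following parameter keyword-only.
    def __init__(self, *, max_episode_length, is_tf_worker=False, seed=0,
                 n_workers=os.cpu_count() or 1, worker_class=DefaultWorker,
                 worker_args=None):
        self.max_episode_length = max_episode_length
        self.is_tf_worker = is_tf_worker
        self.seed = seed
        self.n_workers = n_workers
        self.worker_class = worker_class
        # None means "no extra arguments"; copy to avoid sharing the caller's dict.
        self.worker_args = dict(worker_args) if worker_args else {}

    def make_worker(self, worker_number):
        """Hypothetical method: build the worker with index worker_number."""
        if not 0 <= worker_number < self.n_workers:
            raise ValueError(f'worker_number must be in [0, {self.n_workers})')
        return self.worker_class(seed=self.seed,
                                 max_episode_length=self.max_episode_length,
                                 worker_number=worker_number,
                                 **self.worker_args)


factory = WorkerFactorySketch(max_episode_length=100, n_workers=2,
                              worker_args={'verbose': True})
# Round-trip through pickle, as happens when the factory is shipped to workers.
restored = pickle.loads(pickle.dumps(factory))
worker = restored.make_worker(1)
print(worker.worker_number, worker.max_episode_length, worker.extra_args)
# 1 100 {'verbose': True}
```

Calling WorkerFactorySketch(100) raises TypeError, because the bare * forbids positional arguments. Keeping only plain data and classes (which pickle by reference) as fields is what keeps the round trip cheap.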

prepare_worker_messages(self, objs, preprocess=identity_function)

Take an argument and canonicalize it into a list for all workers.

This helper handles arguments in the sampler API that may optionally be lists: specifically agent, env, agent_update, and env_update. If a list is passed, it checks that the list has exactly one entry per worker.

Parameters
  • objs (object or list) – Must be either a single object or a list of length n_workers.

  • preprocess (function) – Function to call on each single object before creating the list.

Raises

ValueError – If a list is passed of a length other than n_workers.

Returns

A list of length self.n_workers.

Return type

List[object]
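A standalone sketch of the behavior described above, assuming that n_workers is passed explicitly instead of being read from self.n_workers, that only a real list (not a tuple) counts as per-worker input, and that preprocess is applied to every element, including when a list is passed. These are assumptions of the sketch, not confirmed details of garage's method.

```python
import copy


def identity_function(value):
    """Return value unchanged."""
    return value


def prepare_worker_messages(objs, n_workers, preprocess=identity_function):
    """Broadcast a single object, or validate a per-worker list.

    Standalone sketch: n_workers is passed explicitly instead of read
    from self.n_workers.
    """
    if isinstance(objs, list):
        # One entry per worker was supplied; its length must match exactly.
        if len(objs) != n_workers:
            raise ValueError(f'Expected a list of length {n_workers}, '
                             f'got {len(objs)}')
        return [preprocess(obj) for obj in objs]
    # A single object: call preprocess once per worker so each worker
    # receives its own result (e.g. its own deep copy).
    return [preprocess(objs) for _ in range(n_workers)]


shared = prepare_worker_messages({'lr': 0.1}, 3)
copied = prepare_worker_messages({'lr': 0.1}, 3, preprocess=copy.deepcopy)
print(shared[0] is shared[1])   # True: identity_function shares one object
print(copied[0] is copied[1])   # False: deepcopy gives independent copies
print(prepare_worker_messages(['a', 'b'], 2, preprocess=str.upper))  # ['A', 'B']
```

Calling preprocess once per worker, rather than once overall, is what lets a copying preprocess give each worker an object it can mutate without affecting the others.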