garage.tf.samplers.batch_sampler module

Collects samples in parallel using a stateful pool of workers.

class BatchSampler(algo, env, n_envs)[source]

Bases: garage.sampler.sampler_deprecated.BaseSampler

Collects samples in parallel using a stateful pool of workers.

obtain_samples(itr, batch_size=None, whole_paths=True)[source]

Collect samples for the given iteration number.

Parameters:
  • itr (int) – Iteration number.
  • batch_size (int) – Number of environment steps in one batch.
  • whole_paths (bool) – Whether to return whole paths (True) or truncate them to fit the batch size (False).
Returns:

A list of paths.

Return type:

list[dict]
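The difference between whole and truncated paths can be illustrated with a minimal sketch. It is not garage's implementation: truncate_paths_sketch is a hypothetical helper, and the dictionary keys "observations" and "rewards" are assumed for illustration only. It keeps whole paths while they fit and cuts the last one so that the total number of steps equals batch_size.

```python
def truncate_paths_sketch(paths, batch_size):
    """Hypothetical helper: trim a list of path dicts to exactly batch_size steps."""
    kept, total = [], 0
    for path in paths:
        length = len(path["rewards"])  # assumed key; one entry per environment step
        if total + length <= batch_size:
            # The whole path still fits in the batch.
            kept.append(path)
            total += length
        else:
            # Cut every per-step sequence of this path to the remaining budget.
            remaining = batch_size - total
            if remaining > 0:
                kept.append({key: values[:remaining] for key, values in path.items()})
            break
    return kept


paths = [
    {"observations": [0, 1, 2], "rewards": [1.0, 1.0, 1.0]},
    {"observations": [3, 4, 5, 6], "rewards": [0.5, 0.5, 0.5, 0.5]},
]
truncated = truncate_paths_sketch(paths, batch_size=5)
total_steps = sum(len(p["rewards"]) for p in truncated)
```

With whole paths the two example paths would contribute 7 steps; the truncated result holds 5, with the second path cut to its first two steps.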

shutdown_worker()[source]

Terminate workers if necessary.

start_worker()[source]

Initialize the sampler.
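The methods documented on this page suggest a start, sample, shut down lifecycle. The sketch below shows that call order with a stand-in class so that it runs without garage installed. ToySampler and run_sampling_loop are hypothetical names, and only the method names and the obtain_samples signature are taken from this page; a real BatchSampler would be constructed with algo, env, and n_envs.

```python
class ToySampler:
    """Hypothetical stand-in mirroring the method names documented here."""

    def __init__(self):
        self.running = False

    def start_worker(self):
        # A real sampler would set up its worker pool here.
        self.running = True

    def obtain_samples(self, itr, batch_size=None, whole_paths=True):
        if not self.running:
            raise RuntimeError("call start_worker() before obtain_samples()")
        n_steps = batch_size or 1
        # Return a list of path dicts; the "rewards" key is an assumption.
        return [{"rewards": [0.0] * n_steps}]

    def shutdown_worker(self):
        # A real sampler would terminate its workers here.
        self.running = False


def run_sampling_loop(sampler, n_itr, batch_size):
    """Collect one batch per iteration and always shut the workers down."""
    sampler.start_worker()
    batches = []
    try:
        for itr in range(n_itr):
            batches.append(sampler.obtain_samples(itr, batch_size=batch_size))
    finally:
        # Runs even if sampling raises, so workers are not left behind.
        sampler.shutdown_worker()
    return batches


sampler = ToySampler()
batches = run_sampling_loop(sampler, n_itr=3, batch_size=4)
```

Wrapping the loop in try/finally is a design choice of this sketch: it guarantees shutdown_worker() is called even when an iteration fails.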

worker_init_tf(g)[source]

Initialize the tf.Session on a worker.

Parameters: g (object) – Global state object.

worker_init_tf_vars(g)[source]

Initialize the policy parameters on a worker.

Parameters: g (object) – Global state object.