garage.tf.algos.te module

Task Embedding Algorithm.

class TaskEmbeddingWorker(*, seed, max_path_length, worker_number)[source]

Bases: garage.sampler.default_worker.DefaultWorker

A sampler worker for Task Embedding Algorithm.

On top of the behavior inherited from DefaultWorker, this worker adds a one-hot task id to env_info, and adds the latent and latent infos to agent_info.

Parameters:
  • seed (int) – The seed to use to initialize random number generators.
  • max_path_length (int or float) – The maximum length of the paths that will be sampled. Can be (floating point) infinity.
  • worker_number (int) – The number of the worker where this update is occurring. This argument is used to set a different seed for each worker.
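
As a rough illustration of how worker_number can give each worker its own seed, the sketch below offsets a base seed by the worker's index. The helper derive_worker_seed is hypothetical and is not part of garage; the additive scheme is an assumption made for this example, not a statement of what the library does internally.

```python
import random


def derive_worker_seed(seed, worker_number):
    """Hypothetical helper: offset the base seed by the worker's index."""
    return seed + worker_number


# Four workers sharing one base seed end up with four distinct seeds.
worker_seeds = [derive_worker_seed(100, n) for n in range(4)]

# Each worker would then build its own generator from its own seed.
rngs = [random.Random(s) for s in worker_seeds]
```

With a scheme like this, rerunning with the same seed and worker_number reproduces the same random stream for that worker, while different workers draw from different streams.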
agent

The worker’s agent.

Type: Policy or None
env

The worker’s environment.

Type: gym.Env or None
collect_rollout()[source]

Collect the current rollout, clearing the internal buffer.

One-hot task id is saved in env_infos['task_onehot']. Latent is saved in agent_infos['latent']. Latent infos are saved in agent_infos['latent_info_name'], where info_name is the original latent info name.

Returns:
A batch of the trajectories completed since the last call to collect_rollout().
Return type: garage.TrajectoryBatch
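
The key layout described for collect_rollout() can be sketched in plain Python. This is only an illustration of the naming scheme, not the worker's implementation: one_hot and build_infos are hypothetical helpers, plain lists stand in for arrays, and the 'latent_' prefix applied to each latent info name is an assumption based on the description above.

```python
def one_hot(task_id, n_tasks):
    """Return a list of length n_tasks with 1.0 at index task_id."""
    if not 0 <= task_id < n_tasks:
        raise ValueError('task_id out of range')
    vec = [0.0] * n_tasks
    vec[task_id] = 1.0
    return vec


def build_infos(task_ids, n_tasks, latents, latent_infos):
    """Hypothetical helper: lay out the info dicts using the documented keys."""
    # One one-hot vector per time step, stored under 'task_onehot'.
    env_infos = {'task_onehot': [one_hot(t, n_tasks) for t in task_ids]}
    # The sampled latent goes under 'latent'.
    agent_infos = {'latent': list(latents)}
    # Each latent info is stored under its original name with a prefix
    # (assumed here to be 'latent_').
    for name, values in latent_infos.items():
        agent_infos['latent_' + name] = list(values)
    return env_infos, agent_infos


env_infos, agent_infos = build_infos(
    task_ids=[2, 2],
    n_tasks=3,
    latents=[[0.1, -0.3], [0.2, 0.0]],
    latent_infos={'mean': [[0.0, 0.0], [0.1, 0.1]]},
)
```

In this sketch, a latent info originally named mean ends up under agent_infos['latent_mean'], and every time step of task 2 out of 3 carries the vector [0.0, 0.0, 1.0].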
start_rollout()[source]

Begin a new rollout.

step_rollout()[source]

Take a single time-step in the current rollout.

Returns:
True iff the path is done, either due to the environment indicating termination or due to reaching max_path_length.
Return type: bool
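
The three rollout methods are meant to be called in sequence. The sketch below shows one plausible driver loop, assuming the caller steps until step_rollout() returns True and then collects. StubWorker and run_rollout are hypothetical stand-ins written for this example; the stub has no agent or environment and only models the max_path_length cutoff with a finite limit.

```python
class StubWorker:
    """Hypothetical stand-in exposing the three documented method names."""

    def __init__(self, max_path_length):
        self._max_path_length = max_path_length
        self._path_length = 0
        self._buffer = []

    def start_rollout(self):
        # Reset the per-path step counter.
        self._path_length = 0

    def step_rollout(self):
        # Record one time step, then report whether the path is done.
        # This stub never terminates early; it only hits the length limit.
        self._buffer.append(self._path_length)
        self._path_length += 1
        return self._path_length >= self._max_path_length

    def collect_rollout(self):
        # Hand back the buffered steps and clear the internal buffer.
        steps, self._buffer = self._buffer, []
        return steps


def run_rollout(worker):
    """Drive one full rollout: start, step until done, then collect."""
    worker.start_rollout()
    while not worker.step_rollout():
        pass
    return worker.collect_rollout()


worker = StubWorker(max_path_length=3)
first = run_rollout(worker)
second = run_rollout(worker)
```

Because collect_rollout() clears the buffer, the second rollout in this sketch contains only its own three steps rather than accumulating the first rollout's data.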