garage.tf.algos.te module
Task Embedding Algorithm.
class TaskEmbeddingWorker(*, seed, max_path_length, worker_number)
Bases: garage.sampler.default_worker.DefaultWorker
A sampler worker for Task Embedding Algorithm.
In addition to what DefaultWorker does, this worker adds a one-hot task ID to env_info, and adds the latent and its latent infos to agent_info.
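The one-hot task ID is a vector with one slot per task, set to 1 at the active task's index and 0 everywhere else. The sketch below only illustrates that encoding. The helper name one_hot and the use of plain Python lists are assumptions for illustration, not garage's API.

```python
from typing import List


def one_hot(task_id: int, n_tasks: int) -> List[float]:
    """Hypothetical helper: length-n_tasks vector, 1.0 at task_id, 0.0 elsewhere."""
    if not 0 <= task_id < n_tasks:
        raise ValueError(f"task_id {task_id} is out of range for {n_tasks} tasks")
    vec = [0.0] * n_tasks
    vec[task_id] = 1.0
    return vec


# Example: the third task (index 2) out of four tasks.
print(one_hot(2, 4))  # [0.0, 0.0, 1.0, 0.0]
```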
Parameters:
- seed (int) – The seed used to initialize random number generators.
- max_path_length (int or float) – The maximum length of paths that will be sampled. Can be (floating-point) infinity.
- worker_number (int) – The number of the worker where this update is occurring. This argument is used to set a different seed for each worker.
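Because worker_number is used to give each worker its own seed, parallel workers sharing one base seed still draw different random streams while staying reproducible. A minimal sketch, assuming the per-worker seed is derived as seed + worker_number; both that derivation and the helper name make_worker_rng are hypothetical here, not taken from garage.

```python
import random


def make_worker_rng(seed: int, worker_number: int) -> random.Random:
    """Hypothetical helper: build an RNG for one worker.

    Assumes the per-worker seed is seed + worker_number, so workers that
    share a base seed still produce different random streams.
    """
    return random.Random(seed + worker_number)


# Three workers sharing base seed 42, each with its own stream.
rngs = [make_worker_rng(seed=42, worker_number=i) for i in range(3)]
first_draws = [rng.random() for rng in rngs]

# Re-creating worker 1 with the same arguments reproduces its stream exactly.
print(make_worker_rng(42, 1).random() == first_draws[1])  # True
```

Deriving seeds additively is the simplest scheme; it keeps each worker reproducible on its own, at the cost that base seed s for worker 1 coincides with base seed s + 1 for worker 0.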
collect_rollout()
Collect the current rollout, clearing the internal buffer.
The one-hot task ID is saved in env_infos['task_onehot']. The latent is saved in agent_infos['latent']. Latent infos are saved in agent_infos['latent_info_name'], where info_name is the original latent info name (so a latent info named mean is stored under 'latent_mean').
Returns: A batch of the trajectories completed since the last call to collect_rollout().
Return type: garage.TrajectoryBatch
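The key layout described for collect_rollout() can be illustrated with a small buffer that records one entry per time step and hands everything over (then empties itself) when collected. This is a plain-Python sketch, not garage's implementation: the class LatentInfoBuffer, its methods add_step and collect, and the latent info names mean and log_std are all hypothetical, and it returns plain dicts rather than a garage.TrajectoryBatch.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple


class LatentInfoBuffer:
    """Hypothetical per-rollout buffer using the key layout described above."""

    def __init__(self) -> None:
        self._tasks: List[List[float]] = []
        self._latents: List[Any] = []
        self._latent_infos: Dict[str, List[Any]] = defaultdict(list)

    def add_step(self, task_onehot: List[float], latent: Any,
                 latent_info: Dict[str, Any]) -> None:
        """Record the task one-hot, latent, and latent infos for one step."""
        self._tasks.append(task_onehot)
        self._latents.append(latent)
        for name, value in latent_info.items():
            self._latent_infos[name].append(value)

    def collect(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Return (env_infos, agent_infos) and clear the internal buffer."""
        env_infos = {'task_onehot': self._tasks}
        agent_infos: Dict[str, Any] = {'latent': self._latents}
        for name, values in self._latent_infos.items():
            # Each latent info is re-keyed with a 'latent_' prefix.
            agent_infos['latent_' + name] = values
        # Start fresh containers so the next rollout does not see old data.
        self._tasks = []
        self._latents = []
        self._latent_infos = defaultdict(list)
        return env_infos, agent_infos


buf = LatentInfoBuffer()
buf.add_step([1.0, 0.0], [0.3, -0.1], {'mean': [0.0, 0.0], 'log_std': [0.0, 0.0]})
buf.add_step([1.0, 0.0], [0.5, 0.2], {'mean': [0.1, 0.0], 'log_std': [0.0, 0.1]})
env_infos, agent_infos = buf.collect()
print(sorted(agent_infos))  # ['latent', 'latent_log_std', 'latent_mean']
```

Replacing the containers instead of calling clear() on them matters here: the dicts already returned to the caller keep referencing the old lists, so emptying those lists in place would wipe the data that was just handed out.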