# Source code for garage.tf.algos.te

"""Task Embedding Algorithm."""
from collections import defaultdict

import numpy as np

from garage import TrajectoryBatch
from garage.sampler import DefaultWorker


class TaskEmbeddingWorker(DefaultWorker):
    """A sampler worker for the Task Embedding Algorithm.

    In addition to DefaultWorker, this worker adds a one-hot task id to
    env_info, and adds the latent and latent infos to agent_info.

    Args:
        seed(int): The seed to use to initialize random number generators.
        max_path_length(int or float): The maximum length of paths which
            will be sampled. Can be (floating point) infinity.
        worker_number(int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for
            each worker.

    Attributes:
        agent(Policy or None): The worker's agent.
        env(gym.Env or None): The worker's environment.

    """

    def __init__(
            self,
            *,  # Require passing by keyword, since everything's an int.
            seed,
            max_path_length,
            worker_number):
        self._latents = []
        self._tasks = []
        self._latent_infos = defaultdict(list)
        self._z, self._t, self._latent_info = None, None, None
        super().__init__(seed=seed,
                         max_path_length=max_path_length,
                         worker_number=worker_number)
    def start_rollout(self):
        """Begin a new rollout."""
        # pylint: disable=protected-access
        self._t = self.env._active_task_one_hot()
        self._z, self._latent_info = self.agent.get_latent(self._t)
        self._z = self.agent.latent_space.flatten(self._z)
        super().start_rollout()
    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Returns:
            bool: True iff the path is done, either due to the environment
                indicating termination or due to reaching
                `max_path_length`.

        """
        if self._path_length < self._max_path_length:
            a, agent_info = self.agent.get_action_given_latent(
                self._prev_obs, self._z)
            next_o, r, d, env_info = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._rewards.append(r)
            self._actions.append(a)
            self._tasks.append(self._t)
            self._latents.append(self._z)
            for k, v in self._latent_info.items():
                self._latent_infos[k].append(v)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            for k, v in env_info.items():
                self._env_infos[k].append(v)
            self._path_length += 1
            self._terminals.append(d)
            if not d:
                self._prev_obs = next_o
                return False
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True
    def collect_rollout(self):
        """Collect the current rollout, clearing the internal buffer.

        The one-hot task id is saved in env_infos['task_onehot']. The latent
        is saved in agent_infos['latent']. Latent infos are saved in
        agent_infos['latent_info_name'], where info_name is the original
        latent info name.

        Returns:
            garage.TrajectoryBatch: A batch of the trajectories completed
                since the last call to collect_rollout().

        """
        observations = self._observations
        self._observations = []
        last_observations = self._last_observations
        self._last_observations = []
        actions = self._actions
        self._actions = []
        rewards = self._rewards
        self._rewards = []
        terminals = self._terminals
        self._terminals = []
        latents = self._latents
        self._latents = []
        tasks = self._tasks
        self._tasks = []
        env_infos = self._env_infos
        self._env_infos = defaultdict(list)
        agent_infos = self._agent_infos
        self._agent_infos = defaultdict(list)
        latent_infos = self._latent_infos
        self._latent_infos = defaultdict(list)
        for k, v in latent_infos.items():
            latent_infos[k] = np.asarray(v)
        for k, v in agent_infos.items():
            agent_infos[k] = np.asarray(v)
        for k, v in env_infos.items():
            env_infos[k] = np.asarray(v)
        env_infos['task_onehot'] = np.asarray(tasks)
        agent_infos['latent'] = np.asarray(latents)
        for k, v in latent_infos.items():
            agent_infos['latent_{}'.format(k)] = v
        lengths = self._lengths
        self._lengths = []
        return TrajectoryBatch(self.env.spec, np.asarray(observations),
                               np.asarray(last_observations),
                               np.asarray(actions), np.asarray(rewards),
                               np.asarray(terminals), dict(env_infos),
                               dict(agent_infos),
                               np.asarray(lengths, dtype='i'))
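
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): how this worker might be
# driven by hand. `env` and `policy` are assumed to already exist and to
# satisfy the interfaces the worker relies on -- an environment exposing
# `_active_task_one_hot()` and `spec` (e.g. a multi-task wrapper), and a
# policy exposing `get_latent`, `latent_space` and `get_action_given_latent`
# (e.g. a task-embedding policy). `update_env`, `update_agent` and `rollout`
# are inherited from DefaultWorker.
#
#     worker = TaskEmbeddingWorker(seed=1, max_path_length=100,
#                                  worker_number=0)
#     worker.update_env(env)
#     worker.update_agent(policy)
#     batch = worker.rollout()
#     # The one-hot task id and the sampled latent are now available as
#     # batch.env_infos['task_onehot'] and batch.agent_infos['latent'].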