Source code for garage.tf.policies.task_embedding_policy

"""Policy class for Task Embedding envs."""
import abc

import akro

from garage.tf.policies.policy import StochasticPolicy


class TaskEmbeddingPolicy(StochasticPolicy):
    """Base class for Task Embedding policies in TensorFlow.

    This policy needs a task id in addition to the observation to sample an
    action.

    Args:
        name (str): Policy name, also the variable scope.
        env_spec (garage.envs.EnvSpec): Environment specification.
        encoder (garage.tf.embeddings.StochasticEncoder): An encoder that
            embeds a task id to a latent.

    """

    # pylint: disable=too-many-public-methods

    def __init__(self, name, env_spec, encoder):
        super().__init__(name, env_spec)
        self._encoder = encoder
        self._augmented_observation_space = akro.concat(
            self._env_spec.observation_space, self.task_space)

    @property
    def encoder(self):
        """garage.tf.embeddings.encoder.Encoder: Encoder."""
        return self._encoder
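    # A minimal sketch of how the augmented observation space above is
    # assembled, assuming akro.concat joins two Box spaces into one flat Box
    # (as its use in __init__ suggests). The shapes are illustrative only.
    #
    #     import akro
    #     obs_space = akro.Box(low=-1.0, high=1.0, shape=(3,))   # O = 3
    #     task_space = akro.Box(low=0.0, high=1.0, shape=(2,))   # N = 2
    #     augmented = akro.concat(obs_space, task_space)
    #     assert augmented.flat_dim == 5                          # O + N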
    def get_latent(self, task_id):
        """Get embedded task id in latent space.

        Args:
            task_id (np.ndarray): One-hot task id, with shape :math:`(N, )`.
                N is the number of tasks.

        Returns:
            np.ndarray: An embedding sampled from embedding distribution,
                with shape :math:`(Z, )`. Z is the dimension of the latent
                embedding.
            dict: Embedding distribution information.

        """
        return self.encoder.get_latent(task_id)
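    # A minimal usage sketch for get_latent, assuming N=2 tasks and a
    # hypothetical concrete subclass instance ``policy`` (not defined in
    # this module): the one-hot id selects the task, and the encoder
    # returns a sampled latent plus distribution info.
    #
    #     import numpy as np
    #     task_id = np.array([1., 0.])  # one-hot id for task 0 of N=2
    #     latent, latent_info = policy.get_latent(task_id)
    #     assert latent.shape == (policy.latent_space.flat_dim, )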
    @property
    def latent_space(self):
        """akro.Box: Space of latent."""
        return self.encoder.spec.output_space

    @property
    def task_space(self):
        """akro.Box: One-hot space of task id."""
        return self.encoder.spec.input_space

    @property
    def augmented_observation_space(self):
        """akro.Box: Concatenated observation space and one-hot task id."""
        return self._augmented_observation_space

    @property
    def encoder_distribution(self):
        """garage.tf.distributions.DiagonalGaussian: Encoder distribution."""
        return self.encoder.distribution
    @abc.abstractmethod
    def get_action(self, observation):
        """Get action sampled from the policy.

        Args:
            observation (np.ndarray): Augmented observation from the
                environment, with shape :math:`(O+N, )`. O is the dimension
                of observation, N is the number of tasks.

        Returns:
            np.ndarray: Action sampled from the policy, with shape
                :math:`(A, )`. A is the dimension of action.
            dict: Action distribution information.

        """
    @abc.abstractmethod
    def get_actions(self, observations):
        """Get actions sampled from the policy.

        Args:
            observations (np.ndarray): Augmented observations from the
                environment, with shape :math:`(T, O+N)`. T is the number of
                environment steps, O is the dimension of observation, N is
                the number of tasks.

        Returns:
            np.ndarray: Actions sampled from the policy, with shape
                :math:`(T, A)`. T is the number of environment steps, A is
                the dimension of action.
            dict: Action distribution information.

        """
    @abc.abstractmethod
    def get_action_given_task(self, observation, task_id):
        """Sample an action given observation and task id.

        Args:
            observation (np.ndarray): Observation from the environment, with
                shape :math:`(O, )`. O is the dimension of the observation.
            task_id (np.ndarray): One-hot task id, with shape :math:`(N, )`.
                N is the number of tasks.

        Returns:
            np.ndarray: Action sampled from the policy, with shape
                :math:`(A, )`. A is the dimension of action.
            dict: Action distribution information.

        """
    @abc.abstractmethod
    def get_actions_given_tasks(self, observations, task_ids):
        """Sample a batch of actions given observations and task ids.

        Args:
            observations (np.ndarray): Observations from the environment,
                with shape :math:`(T, O)`. T is the number of environment
                steps, O is the dimension of observation.
            task_ids (np.ndarray): One-hot task ids, with shape
                :math:`(T, N)`. T is the number of environment steps, N is
                the number of tasks.

        Returns:
            np.ndarray: Actions sampled from the policy, with shape
                :math:`(T, A)`. T is the number of environment steps, A is
                the dimension of action.
            dict: Action distribution information.

        """
    @abc.abstractmethod
    def get_action_given_latent(self, observation, latent):
        """Sample an action given observation and latent.

        Args:
            observation (np.ndarray): Observation from the environment, with
                shape :math:`(O, )`. O is the dimension of observation.
            latent (np.ndarray): Latent, with shape :math:`(Z, )`. Z is the
                dimension of latent embedding.

        Returns:
            np.ndarray: Action sampled from the policy, with shape
                :math:`(A, )`. A is the dimension of action.
            dict: Action distribution information.

        """
    @abc.abstractmethod
    def get_actions_given_latents(self, observations, latents):
        """Sample a batch of actions given observations and latents.

        Args:
            observations (np.ndarray): Observations from the environment,
                with shape :math:`(T, O)`. T is the number of environment
                steps, O is the dimension of observation.
            latents (np.ndarray): Latents, with shape :math:`(T, Z)`. T is
                the number of environment steps, Z is the dimension of
                latent embedding.

        Returns:
            np.ndarray: Actions sampled from the policy, with shape
                :math:`(T, A)`. T is the number of environment steps, A is
                the dimension of action.
            dict: Action distribution information.

        """
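    # A hedged sketch of how a concrete subclass might wire these abstract
    # hooks together. The base class does not prescribe this delegation;
    # the method bodies below are assumptions for illustration only.
    #
    #     def get_action(self, observation):
    #         obs, task_id = self.split_augmented_observation(observation)
    #         return self.get_action_given_task(obs, task_id)
    #
    #     def get_action_given_task(self, observation, task_id):
    #         latent, _ = self.get_latent(task_id)
    #         return self.get_action_given_latent(observation, latent)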
    def get_trainable_vars(self):
        """Get trainable variables.

        The trainable vars of a multitask policy should be the trainable
        vars of its model and the trainable vars of its embedding model.

        Returns:
            List[tf.Variable]: A list of trainable variables in the current
                variable scope.

        """
        return (self._variable_scope.trainable_variables() +
                self.encoder.get_trainable_vars())
    def get_global_vars(self):
        """Get global variables.

        The global vars of a multitask policy should be the global vars of
        its model and the global vars of its embedding model.

        Returns:
            List[tf.Variable]: A list of global variables in the current
                variable scope.

        """
        return (self._variable_scope.global_variables() +
                self.encoder.get_global_vars())
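    # Why both variable lists are concatenated: a single optimizer can then
    # update the policy network and the task encoder jointly. A minimal
    # TF1-style sketch; ``loss``, ``policy``, and the optimizer choice are
    # assumptions, not part of this module.
    #
    #     optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)
    #     train_op = optimizer.minimize(
    #         loss, var_list=policy.get_trainable_vars())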
    def split_augmented_observation(self, collated):
        """Splits up observation into one-hot task and environment observation.

        Args:
            collated (np.ndarray): Environment observation concatenated with
                task one-hot, with shape :math:`(O+N, )`. O is the dimension
                of observation, N is the number of tasks.

        Returns:
            np.ndarray: Vanilla environment observation, with shape
                :math:`(O, )`. O is the dimension of observation.
            np.ndarray: Task one-hot, with shape :math:`(N, )`. N is the
                number of tasks.

        """
        task_dim = self.task_space.flat_dim
        return collated[:-task_dim], collated[-task_dim:]
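# A minimal end-to-end sketch of the augmented observation convention,
# assuming O=3 observation dims, N=2 tasks, and a hypothetical concrete
# subclass instance ``policy``; the values are illustrative only.
#
#     import numpy as np
#     collated = np.array([0.1, 0.2, 0.3, 0., 1.])  # obs (O=3) + one-hot (N=2)
#     obs, task = policy.split_augmented_observation(collated)
#     # obs  -> array([0.1, 0.2, 0.3])
#     # task -> array([0., 1.])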