garage.tf.policies.task_embedding_policy

Policy class for Task Embedding envs.

class TaskEmbeddingPolicy

Bases: garage.tf.policies.policy.Policy

Inheritance diagram of garage.tf.policies.task_embedding_policy.TaskEmbeddingPolicy

Base class for Task Embedding policies in TensorFlow.

This policy needs a task id in addition to observation to sample an action.
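The one-hot task id is concatenated with the environment observation before the policy consumes it. A minimal sketch of that data flow, using numpy only and hypothetical dimensions (obs_dim, num_tasks and task_index are placeholders, not part of the API):

    import numpy as np

    obs_dim, num_tasks = 4, 3                   # O and N in the shape notation below
    task_index = 1                              # which task the current episode belongs to

    obs = np.zeros(obs_dim)                     # environment observation, shape (O, )
    one_hot = np.eye(num_tasks)[task_index]     # one-hot task id, shape (N, )
    augmented = np.concatenate([obs, one_hot])  # policy input, shape (O + N, )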

property encoder

Encoder.

Type

garage.tf.embeddings.encoder.Encoder

property latent_space

Space of latent.

Type

akro.Box

property task_space

One-hot space of task id.

Type

akro.Box

property augmented_observation_space

Concatenated observation space and one-hot task id.

Type

akro.Box

property encoder_distribution

Encoder distribution.

Type

tfp.distributions.MultivariateNormalDiag

property state_info_specs

State info specification.

Returns

Keys and shapes for the information related to the module’s state when taking an action.

Return type

List[str]

property state_info_keys

State info keys.

Returns

Keys for the information related to the module’s state when taking an input.

Return type

List[str]

property name

Name of policy.

Returns

Name of the policy.

Return type

str

property env_spec

Policy environment specification.

Returns

Environment specification.

Return type

garage.EnvSpec

property observation_space

Observation space.

Returns

The observation space of the environment.

Return type

akro.Space

property action_space

Action space.

Returns

The action space of the environment.

Return type

akro.Space

get_latent(task_id)

Get embedded task id in latent space.

Parameters

task_id (np.ndarray) – One-hot task id, with shape \((N, )\). N is the number of tasks.

Returns

An embedding sampled from the embedding distribution, with shape \((Z, )\). Z is the dimension of the latent embedding.

dict: Embedding distribution information.

Return type

np.ndarray
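A hypothetical call, where policy stands for any concrete TaskEmbeddingPolicy subclass (for example garage.tf.policies.GaussianMLPTaskEmbeddingPolicy) and num_tasks follows the sketch above:

    task_id = np.eye(num_tasks)[0]                    # one-hot for the first task, shape (N, )
    latent, latent_info = policy.get_latent(task_id)  # latent has shape (Z, )
    assert latent.shape == (policy.latent_space.flat_dim,)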

abstract get_action(observation)

Get action sampled from the policy.

Parameters

observation (np.ndarray) – Augmented observation from the environment, with shape \((O+N, )\). O is the dimension of observation, N is the number of tasks.

Returns

Action sampled from the policy, with shape \((A, )\). A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray
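A hedged sketch, reusing obs and one_hot from the example at the top of this page:

    aug_obs = np.concatenate([obs, one_hot])         # augmented observation, shape (O + N, )
    action, agent_info = policy.get_action(aug_obs)  # action has shape (A, )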

abstract get_actions(observations)

Get actions sampled from the policy.

Parameters

observations (np.ndarray) – Augmented observation from the environment, with shape \((T, O+N)\). T is the number of environment steps, O is the dimension of observation, N is the number of tasks.

Returns

Actions sampled from the policy, with shape \((T, A)\). T is the number of environment steps, A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray
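A batched variant of the same sketch, assuming obs_batch has shape \((T, O)\) and one_hot_batch has shape \((T, N)\) (both are hypothetical arrays):

    aug_obs_batch = np.concatenate([obs_batch, one_hot_batch], axis=1)  # shape (T, O + N)
    actions, agent_infos = policy.get_actions(aug_obs_batch)            # actions have shape (T, A)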

abstract get_action_given_task(observation, task_id)

Sample an action given observation and task id.

Parameters
  • observation (np.ndarray) – Observation from the environment, with shape \((O, )\). O is the dimension of the observation.

  • task_id (np.ndarray) – One-hot task id, with shape \((N, )\). N is the number of tasks.

Returns

Action sampled from the policy, with shape \((A, )\). A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray
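Unlike get_action, this variant takes the plain observation and the one-hot task id separately, so no manual concatenation is needed; a hypothetical call:

    action, agent_info = policy.get_action_given_task(obs, one_hot)  # action has shape (A, )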

abstract get_actions_given_tasks(observations, task_ids)

Sample a batch of actions given observations and task ids.

Parameters
  • observations (np.ndarray) – Observations from the environment, with shape \((T, O)\). T is the number of environment steps, O is the dimension of observation.

  • task_ids (np.ndarray) – One-hot task ids, with shape \((T, N)\). T is the number of environment steps, N is the number of tasks.

Returns

Actions sampled from the policy, with shape \((T, A)\). T is the number of environment steps, A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray

abstract get_action_given_latent(observation, latent)

Sample an action given observation and latent.

Parameters
  • observation (np.ndarray) – Observation from the environment, with shape \((O, )\). O is the dimension of observation.

  • latent (np.ndarray) – Latent, with shape \((Z, )\). Z is the dimension of latent embedding.

Returns

Action sampled from the policy, with shape \((A, )\). A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray
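A common usage pattern, sketched here under the same assumptions, is to embed the task once per episode and reuse the latent at every step:

    latent, _ = policy.get_latent(one_hot)   # embed the task id once, shape (Z, )
    for step_obs in episode_observations:    # hypothetical iterable of shape-(O, ) observations
        action, agent_info = policy.get_action_given_latent(step_obs, latent)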

abstract get_actions_given_latents(observations, latents)

Sample a batch of actions given observations and latents.

Parameters
  • observations (np.ndarray) – Observations from the environment, with shape \((T, O)\). T is the number of environment steps, O is the dimension of observation.

  • latents (np.ndarray) – Latents, with shape \((T, Z)\). T is the number of environment steps, Z is the dimension of latent embedding.

Returns

Actions sampled from the policy, with shape \((T, A)\). T is the number of environment steps, A is the dimension of action.

dict: Action distribution information.

Return type

np.ndarray

split_augmented_observation(collated)

Splits an augmented observation into the environment observation and the one-hot task id.

Parameters

collated (np.ndarray) – Environment observation concatenated with task one-hot, with shape \((O+N, )\). O is the dimension of observation, N is the number of tasks.

Returns

Vanilla environment observation, with shape \((O, )\). O is the dimension of observation.

np.ndarray: Task one-hot, with shape \((N, )\). N is the number of tasks.

Return type

np.ndarray
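This is the inverse of the concatenation shown at the top of this page; a hedged sketch:

    obs, one_hot = policy.split_augmented_observation(augmented)
    # obs has shape (O, ) and one_hot has shape (N, ), i.e. the trailing N
    # entries of the augmented vector are treated as the one-hot task id.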

reset(do_resets=None)

Reset the policy.

This is effective only for recurrent policies.

do_resets is an array of booleans indicating which internal states should be reset. The length of do_resets should equal the batch size, i.e. the number of parallel inputs.

Parameters

do_resets (numpy.ndarray) – Bool array indicating which states should be reset.
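For example, in a hypothetical vectorized-sampling setup where only the workers whose episodes just ended should be reset:

    do_resets = np.array([True, False, True])  # reset internal state for workers 0 and 2 only
    policy.reset(do_resets)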