PEARL and PEARLWorker in PyTorch.

Code is adapted from

class PEARL(env, inner_policy, qf, vf, sampler, *, num_train_tasks, num_test_tasks=None, latent_dim, encoder_hidden_sizes, test_env_sampler, policy_class=ContextConditionedPolicy, encoder_class=MLPEncoder, policy_lr=0.0003, qf_lr=0.0003, vf_lr=0.0003, context_lr=0.0003, policy_mean_reg_coeff=0.001, policy_std_reg_coeff=0.001, policy_pre_activation_coeff=0.0, soft_target_tau=0.005, kl_lambda=0.1, optimizer_class=torch.optim.Adam, use_information_bottleneck=True, use_next_obs_in_context=False, meta_batch_size=64, num_steps_per_epoch=1000, num_initial_steps=100, num_tasks_sample=100, num_steps_prior=100, num_steps_posterior=0, num_extra_rl_steps_posterior=100, batch_size=1024, embedding_batch_size=1024, embedding_mini_batch_size=1024, discount=0.99, replay_buffer_size=1000000, reward_scale=1, update_post_train=1)


Inheritance diagram of garage.torch.algos.pearl.PEARL

A PEARL model based on

PEARL, which stands for Probabilistic Embeddings for Actor-Critic Reinforcement Learning, is an off-policy meta-RL algorithm. It is built on top of SAC, using two Q-functions and a value function, with the addition of an inference network that estimates the posterior \(q(z \mid c)\). The policy is conditioned on the latent variable \(z\) in order to adapt its behavior to specific tasks.
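The PEARL paper models the posterior \(q(z \mid c)\) as a product of independent Gaussian factors, one per context transition. The following is a minimal pure-Python sketch of that combination for a single latent dimension. The function name `product_of_gaussians` and the scalar inputs are illustrative assumptions, not the garage implementation, which works on batched tensors.

```python
def product_of_gaussians(means, variances):
    """Combine independent 1-D Gaussian factors into a single Gaussian.

    The product of Gaussian densities is (up to normalization) Gaussian,
    with precision equal to the sum of the factor precisions and a
    precision-weighted mean.
    """
    if not means or len(means) != len(variances):
        raise ValueError("need equally many means and variances, at least one")
    precisions = [1.0 / v for v in variances]
    variance = 1.0 / sum(precisions)
    mean = variance * sum(m * p for m, p in zip(means, precisions))
    return mean, variance


# Two unit-variance factors centered at 0 and 2 (illustrative values only).
mean, variance = product_of_gaussians([0.0, 2.0], [1.0, 1.0])
print(mean, variance)
```

Each additional context factor adds precision, so the posterior can only tighten as more transitions are observed.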

  • env (list[Environment]) – Batch of sampled environment updates (EnvUpdate), which, when invoked on environments, will configure them with new tasks.

  • policy_class (type) – Class implementing ContextConditionedPolicy.

  • encoder_class (garage.torch.embeddings.ContextEncoder) – Encoder class for the encoder in context-conditioned policy.

  • inner_policy (garage.torch.policies.Policy) – Policy.

  • qf (torch.nn.Module) – Q-function.

  • vf (torch.nn.Module) – Value function.

  • sampler (garage.sampler.Sampler) – Sampler.

  • num_train_tasks (int) – Number of tasks for training.

  • num_test_tasks (int or None) – Number of tasks for testing.

  • latent_dim (int) – Size of latent context vector.

  • encoder_hidden_sizes (list[int]) – Output dimension of dense layer(s) of the context encoder.

  • test_env_sampler (garage.experiment.SetTaskSampler) – Sampler for test tasks.

  • policy_lr (float) – Policy learning rate.

  • qf_lr (float) – Q-function learning rate.

  • vf_lr (float) – Value function learning rate.

  • context_lr (float) – Inference network learning rate.

  • policy_mean_reg_coeff (float) – Policy mean regularization weight.

  • policy_std_reg_coeff (float) – Policy std regularization weight.

  • policy_pre_activation_coeff (float) – Policy pre-activation weight.

  • soft_target_tau (float) – Interpolation parameter for doing the soft target update.

  • kl_lambda (float) – KL lambda value.

  • optimizer_class (type) – Type of optimizer for training networks.

  • use_information_bottleneck (bool) – False means latent context is deterministic.

  • use_next_obs_in_context (bool) – Whether or not to use next observation in distinguishing between tasks.

  • meta_batch_size (int) – Meta batch size.

  • num_steps_per_epoch (int) – Number of iterations per epoch.

  • num_initial_steps (int) – Number of transitions obtained per task before training.

  • num_tasks_sample (int) – Number of random tasks to obtain data for each iteration.

  • num_steps_prior (int) – Number of transitions to obtain per task with z ~ prior.

  • num_steps_posterior (int) – Number of transitions to obtain per task with z ~ posterior.

  • num_extra_rl_steps_posterior (int) – Number of additional transitions to obtain per task with z ~ posterior that are only used to train the policy and NOT the encoder.

  • batch_size (int) – Number of transitions in RL batch.

  • embedding_batch_size (int) – Number of transitions in context batch.

  • embedding_mini_batch_size (int) – Number of transitions in mini context batch; should be same as embedding_batch_size for non-recurrent encoder.

  • discount (float) – RL discount factor.

  • replay_buffer_size (int) – Maximum samples in replay buffer.

  • reward_scale (int) – Reward scale.

  • update_post_train (int) – How often to resample context when obtaining data during training (in episodes).
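To make the role of soft_target_tau concrete, here is a minimal sketch of a soft target update, assuming the conventional interpolation used in SAC-style algorithms (target ← tau · source + (1 − tau) · target). The helper `soft_update` and the plain-list parameters are hypothetical; the actual implementation operates on network parameters.

```python
def soft_update(target_params, source_params, tau=0.005):
    """Move each target parameter a fraction tau toward its source value."""
    if len(target_params) != len(source_params):
        raise ValueError("parameter lists must have the same length")
    return [tau * s + (1.0 - tau) * t
            for t, s in zip(target_params, source_params)]


# Illustrative values: the target starts at 0 and the source sits at 1.
target = [0.0, 0.0]
source = [1.0, 1.0]
target = soft_update(target, source, tau=0.005)
print(target)
```

A small tau such as the default 0.005 means the target network trails the trained network slowly, which is the usual motivation for this kind of update.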

property policy

Return the policy within the model.


Policy within the model.

Return type


property networks

Return all the networks within the model.


A list of networks.

Return type



Obtain samples, train, and evaluate for each epoch.


trainer (Trainer) – Gives the algorithm access to Trainer.step_epochs(), which provides services such as snapshotting and sampler control.


Return a policy used before adaptation to a specific task.

Each time it is retrieved, this policy should only be evaluated in one task.


The policy used to obtain samples that are later used for meta-RL adaptation.

Return type


adapt_policy(exploration_policy, exploration_episodes)

Produce a policy adapted for a task.

  • exploration_policy (Policy) – A policy which was returned from get_exploration_policy(), and which generated exploration_episodes by interacting with an environment. The caller may not use this object after passing it into this method.

  • exploration_episodes (EpisodeBatch) – Episodes to which to adapt, generated by exploration_policy exploring the environment.


A policy adapted to the task represented by the exploration_episodes.


Return type



Put all the networks within the model on device.


device (str) – ID of GPU or CPU.

classmethod augment_env_spec(env_spec, latent_dim)

Augment the environment spec by the size of the latent dimension.

  • env_spec (EnvSpec) – Environment specs to be augmented.

  • latent_dim (int) – Latent dimension.


Augmented environment specs.

Return type

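One way to read "augment by the latent dimension" is that observations gain latent_dim extra entries, so the policy can take the latent context as additional input. The sketch below illustrates only that size arithmetic. The helper `augment_observation` is hypothetical and is not part of the documented API.

```python
def augment_observation(observation, latent):
    """Concatenate an observation with a latent context vector.

    Hypothetical helper: it only illustrates that the augmented input
    has obs_dim + latent_dim entries.
    """
    return list(observation) + list(latent)


obs_dim, latent_dim = 4, 2  # illustrative sizes
augmented = augment_observation([0.1] * obs_dim, [0.0] * latent_dim)
print(len(augmented))
```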

classmethod get_env_spec(env_spec, latent_dim, module)

Get environment specs of encoder with latent dimension.

  • env_spec (EnvSpec) – Environment specification.

  • latent_dim (int) – Latent dimension.

  • module (str) – Module to get environment specs for.


Module environment specs with latent dimension.

Return type


class PEARLWorker(*, seed, max_episode_length, worker_number, deterministic=False, accum_context=False)

Bases: garage.sampler.DefaultWorker

Inheritance diagram of garage.torch.algos.pearl.PEARLWorker

A worker class used in sampling for PEARL.

It stores context and resamples the belief in the policy at every step.

  • seed (int) – The seed to use to initialize random number generators.

  • max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity.

  • worker_number (int) – The number of the worker where this update is occurring. This argument is used to set a different seed for each worker.

  • deterministic (bool) – If True, use the mean action returned by the stochastic policy instead of sampling from the returned action distribution.

  • accum_context (bool) – If True, update context of the agent.
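The seed and worker_number parameters together give each worker its own random stream. A minimal sketch follows, assuming the per-worker seed is derived by adding worker_number to the base seed. That derivation and the helper `worker_seed` are assumptions made for illustration.

```python
import random


def worker_seed(seed, worker_number):
    """Derive a distinct seed per worker (assumed scheme: base + index)."""
    return seed + worker_number


base_seed = 1  # illustrative value
seeds = [worker_seed(base_seed, n) for n in range(3)]
# One independent generator per worker, each reproducible from its seed.
rngs = [random.Random(s) for s in seeds]
print(seeds)
```

Under this scheme, rerunning with the same base seed reproduces every worker's stream, while different workers do not share one.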


The worker’s agent.


Policy or None


The worker’s environment.


Environment or None


Begin a new episode.


Take a single time-step in the current episode.


True iff the episode is done, either due to the environment indicating termination or due to reaching max_episode_length.

Return type

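The two ways an episode can end can be sketched as a simple loop. Here `step_env` is a hypothetical callable standing in for the environment, and `run_episode` is an illustrative helper, not the worker's actual method. An infinite max_episode_length disables the length cap, as the parameter description allows.

```python
import math


def run_episode(step_env, max_episode_length):
    """Step until the environment terminates or the length cap is reached.

    step_env takes the current step index and returns True when the
    (hypothetical) environment signals termination.
    """
    steps_taken = 0
    done = False
    while not done:
        terminated = step_env(steps_taken)
        steps_taken += 1
        done = terminated or steps_taken >= max_episode_length
    return steps_taken


# Never terminates on its own, so the cap of 5 steps ends the episode.
capped = run_episode(lambda t: False, max_episode_length=5)
# No cap; the environment terminates on its third step (index 2).
early = run_episode(lambda t: t == 2, max_episode_length=math.inf)
print(capped, early)
```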


Sample a single episode of the agent in the environment.


The collected episode.

Return type



Initialize a worker.


Update an agent, assuming it implements Policy.


agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray, these should be parameters to agent, which should have been generated by calling Policy.get_param_values. Alternatively, a policy itself. Note that other implementations of Worker may take different types for this parameter.


Use any non-None env_update as a new environment.

A simple env update function. If env_update is not None, it should be the complete new environment.

This allows changing environments by passing the new environment as env_update into obtain_samples.


env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.


TypeError – If env_update is not one of the documented types.


Collect the current episode, clearing the internal buffer.


A batch of the episodes completed since the last call to collect_episode().

Return type



Close the worker’s environment.