garage.torch.algos.pearl
PEARL and PEARLWorker in PyTorch.
Code is adapted from https://github.com/katerakelly/oyster.
-
class PEARL(env, inner_policy, qf, vf, *, num_train_tasks, num_test_tasks=None, latent_dim, encoder_hidden_sizes, test_env_sampler, policy_class=ContextConditionedPolicy, encoder_class=MLPEncoder, policy_lr=0.0003, qf_lr=0.0003, vf_lr=0.0003, context_lr=0.0003, policy_mean_reg_coeff=0.001, policy_std_reg_coeff=0.001, policy_pre_activation_coeff=0.0, soft_target_tau=0.005, kl_lambda=0.1, optimizer_class=torch.optim.Adam, use_information_bottleneck=True, use_next_obs_in_context=False, meta_batch_size=64, num_steps_per_epoch=1000, num_initial_steps=100, num_tasks_sample=100, num_steps_prior=100, num_steps_posterior=0, num_extra_rl_steps_posterior=100, batch_size=1024, embedding_batch_size=1024, embedding_mini_batch_size=1024, discount=0.99, replay_buffer_size=1000000, reward_scale=1, update_post_train=1)
Bases: garage.np.algos.MetaRLAlgorithm
A PEARL model based on https://arxiv.org/abs/1903.08254.
PEARL, which stands for Probabilistic Embeddings for Actor-Critic Reinforcement Learning, is an off-policy meta-RL algorithm. It is built on top of SAC, using two Q-functions and a value function, with the addition of an inference network that estimates the posterior \(q(z \mid c)\) over a latent variable \(z\) given context \(c\). The policy is conditioned on \(z\) in order to adapt its behavior to specific tasks.
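With the information bottleneck enabled, the PEARL paper linked above models \(q(z \mid c)\) as a product of independent Gaussian factors, one per context transition, and penalizes its KL divergence from a unit Gaussian prior (weighted by kl_lambda). The sketch below shows that arithmetic in plain Python for a diagonal Gaussian. It is an illustration under those assumptions, not garage's implementation, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

_MIN_VAR = 1e-7  # floor on variances to avoid division by zero


def product_of_gaussians(
        mus: Sequence[Sequence[float]],
        sigmas_sq: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Combine per-transition factors N(mu_i, sigma_i^2) into one Gaussian.

    Both inputs have shape [num_transitions][latent_dim]. The product's
    precision is the sum of the factor precisions, and its mean is the
    precision-weighted average of the factor means.
    """
    latent_dim = len(mus[0])
    mu, var = [], []
    for d in range(latent_dim):
        precisions = [1.0 / max(s[d], _MIN_VAR) for s in sigmas_sq]
        var_d = 1.0 / sum(precisions)
        mu_d = var_d * sum(m[d] * p for m, p in zip(mus, precisions))
        mu.append(mu_d)
        var.append(var_d)
    return mu, var


def kl_to_standard_normal(mu: Sequence[float], var: Sequence[float]) -> float:
    """KL(N(mu, diag(var)) || N(0, I)), summed over latent dimensions."""
    return sum(0.5 * (v + m * m - 1.0 - math.log(v)) for m, v in zip(mu, var))


# Two transitions that each suggest a different value for a 1-D latent.
post_mu, post_var = product_of_gaussians([[1.0], [3.0]], [[1.0], [1.0]])
kl_penalty = 0.1 * kl_to_standard_normal(post_mu, post_var)  # kl_lambda = 0.1
```

More context transitions add more precision, so the posterior variance shrinks as evidence about the task accumulates.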
- Parameters
env (list[Environment]) – Batch of sampled environment updates (EnvUpdate), which, when invoked on environments, configure them with new tasks.
policy_class (type) – Class implementing ContextConditionedPolicy.
encoder_class (garage.torch.embeddings.ContextEncoder) – Encoder class for the encoder in the context-conditioned policy.
inner_policy (garage.torch.policies.Policy) – Policy.
qf (torch.nn.Module) – Q-function.
vf (torch.nn.Module) – Value function.
num_train_tasks (int) – Number of tasks for training.
latent_dim (int) – Size of latent context vector.
encoder_hidden_sizes (list[int]) – Output dimension of dense layer(s) of the context encoder.
test_env_sampler (garage.experiment.SetTaskSampler) – Sampler for test tasks.
policy_lr (float) – Policy learning rate.
qf_lr (float) – Q-function learning rate.
vf_lr (float) – Value function learning rate.
context_lr (float) – Inference network learning rate.
policy_mean_reg_coeff (float) – Policy mean regularization weight.
policy_std_reg_coeff (float) – Policy std regularization weight.
policy_pre_activation_coeff (float) – Policy pre-activation regularization weight.
soft_target_tau (float) – Interpolation parameter for doing the soft target update.
kl_lambda (float) – Weight of the KL divergence between the posterior and the prior over the latent context.
optimizer_class (type) – Type of optimizer for training networks.
use_information_bottleneck (bool) – If False, the latent context is deterministic.
use_next_obs_in_context (bool) – Whether or not to use next observation in distinguishing between tasks.
meta_batch_size (int) – Meta batch size.
num_steps_per_epoch (int) – Number of iterations per epoch.
num_initial_steps (int) – Number of transitions obtained per task before training.
num_tasks_sample (int) – Number of random tasks to obtain data for each iteration.
num_steps_prior (int) – Number of transitions to obtain per task with z ~ prior.
num_steps_posterior (int) – Number of transitions to obtain per task with z ~ posterior.
num_extra_rl_steps_posterior (int) – Number of additional transitions to obtain per task with z ~ posterior that are only used to train the policy and NOT the encoder.
batch_size (int) – Number of transitions in RL batch.
embedding_batch_size (int) – Number of transitions in context batch.
embedding_mini_batch_size (int) – Number of transitions in the mini context batch; should be the same as embedding_batch_size for a non-recurrent encoder.
discount (float) – RL discount factor.
replay_buffer_size (int) – Maximum samples in replay buffer.
reward_scale (int) – Reward scale.
update_post_train (int) – How often to resample context when obtaining data during training (in episodes).
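The soft_target_tau parameter above sets the interpolation weight of a soft (Polyak) target-network update, as used in SAC-style learners. The sketch below shows the update rule on plain Python lists standing in for network parameters; treating the parameters as a flat list, and the function name, are illustrative assumptions rather than garage's code.

```python
from typing import List, Sequence


def soft_update(target: Sequence[float], source: Sequence[float],
                tau: float) -> List[float]:
    """Return tau * source + (1 - tau) * target, element by element."""
    return [tau * s + (1.0 - tau) * t for t, s in zip(target, source)]


# With tau = 0.005 the target moves 0.5% of the way toward the source per call.
target = [0.0]
source = [1.0]
for _ in range(1000):
    target = soft_update(target, source, tau=0.005)
```

A small tau makes the target change slowly, which stabilizes the bootstrapped value targets at the cost of propagating new estimates more gradually.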
-
train(self, trainer)
Obtain samples, train, and evaluate for each epoch.
- Parameters
trainer (Trainer) – Gives the algorithm access to Trainer.step_epochs(), which provides services such as snapshotting and sampler control.
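The sampling parameters combine into a per-epoch data budget. The sketch below computes the nominal transition counts implied by the parameter descriptions, using the constructor defaults. The function names are hypothetical, and the number of transitions the sampler actually collects may differ slightly (for example, if it gathers data in whole episodes).

```python
def initial_transitions(num_train_tasks: int, num_initial_steps: int = 100) -> int:
    """Transitions collected once, across all training tasks, before training."""
    return num_train_tasks * num_initial_steps


def transitions_per_epoch(num_tasks_sample: int = 100,
                          num_steps_prior: int = 100,
                          num_steps_posterior: int = 0,
                          num_extra_rl_steps_posterior: int = 100) -> dict:
    """Nominal transitions gathered per epoch, split by which buffers they feed."""
    # Prior and posterior samples are used for both the encoder and RL.
    encoder_per_task = num_steps_prior + num_steps_posterior
    # Extra posterior samples train the policy only, not the encoder.
    rl_per_task = encoder_per_task + num_extra_rl_steps_posterior
    return {
        'encoder': num_tasks_sample * encoder_per_task,
        'rl': num_tasks_sample * rl_per_task,
    }


budget = transitions_per_epoch()
warmup = initial_transitions(num_train_tasks=50)
```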
-
property policy(self)
Return the policy within the model.
- Returns
Policy within the model.
- Return type
-
property networks(self)
Return all the networks within the model.
- Returns
A list of networks.
- Return type
-
get_exploration_policy(self)
Return a policy used before adaptation to a specific task.
Each time it is retrieved, this policy should only be evaluated in one task.
- Returns
The policy used to obtain samples that are later used for meta-RL adaptation.
- Return type
-
adapt_policy(self, exploration_policy, exploration_episodes)
Produce a policy adapted for a task.
- Parameters
exploration_policy (Policy) – A policy which was returned from get_exploration_policy(), and which generated exploration_episodes by interacting with an environment. The caller may not use this object after passing it into this method.
exploration_episodes (EpisodeBatch) – Episodes to which to adapt, generated by exploration_policy exploring the environment.
- Returns
A policy adapted to the task represented by the exploration_episodes.
- Return type
-
to(self, device=None)
Move all the networks within the model to device.
- Parameters
device (str) – ID of GPU or CPU.
-
classmethod augment_env_spec(cls, env_spec, latent_dim)
Augment the environment spec by the size of the latent dimension.
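To condition the policy on the latent context, its input is the observation extended by latent_dim entries. A minimal sketch of the resulting input size, assuming the observation is flattened and \(z\) is appended to it (the helper name is hypothetical, not part of garage):

```python
import math
from typing import Sequence


def augmented_obs_dim(obs_shape: Sequence[int], latent_dim: int) -> int:
    """Size of a flattened observation with the latent vector z appended."""
    return math.prod(obs_shape) + latent_dim


# A (3, 4) observation plus a 5-dimensional z gives a 17-dimensional input.
policy_input_dim = augmented_obs_dim((3, 4), latent_dim=5)
```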
-
classmethod get_env_spec(cls, env_spec, latent_dim, module)
Get the environment spec of the encoder, given the latent dimension.
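For the context encoder, the oyster code linked above sizes the input as one transition (observation, action, and a scalar reward, plus the next observation when use_next_obs_in_context is True) and the output as 2 * latent_dim (a mean and a variance per dimension) when use_information_bottleneck is True. Assuming get_env_spec follows the same convention, the dimensions work out as in this sketch; the function name is hypothetical.

```python
from typing import Tuple


def encoder_io_dims(obs_dim: int, action_dim: int, latent_dim: int,
                    use_next_obs_in_context: bool = False,
                    use_information_bottleneck: bool = True) -> Tuple[int, int]:
    """Return (input_dim, output_dim) of a context encoder (assumed layout)."""
    reward_dim = 1
    in_dim = obs_dim + action_dim + reward_dim
    if use_next_obs_in_context:
        in_dim += obs_dim  # the next observation is part of each transition
    # A stochastic latent needs a mean and a variance per dimension.
    out_dim = 2 * latent_dim if use_information_bottleneck else latent_dim
    return in_dim, out_dim


dims = encoder_io_dims(obs_dim=20, action_dim=6, latent_dim=5)
```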
-
class PEARLWorker(*, seed, max_episode_length, worker_number, deterministic=False, accum_context=False)
Bases: garage.sampler.DefaultWorker
A worker class used in sampling for PEARL.
It stores context and resamples the belief in the policy at every step.
- Parameters
seed (int) – The seed to use to initialize random number generators.
max_episode_length (int or float) – The maximum length of episodes which will be sampled. Can be (floating point) infinity.
worker_number (int) – The number of the worker where this update is occurring. This argument is used to set a different seed for each worker.
deterministic (bool) – If True, use the mean action returned by the stochastic policy instead of sampling from the returned action distribution.
accum_context (bool) – If True, update the context of the agent.
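The constructor flags can be illustrated with a toy worker. The sketch assumes that the per-worker seed is formed as seed + worker_number and uses a scalar Gaussian action distribution; ToyContextWorker and its methods are hypothetical and are not part of garage.

```python
import random


class ToyContextWorker:
    """Hypothetical stand-in showing how the PEARLWorker flags behave."""

    def __init__(self, *, seed: int, worker_number: int,
                 deterministic: bool = False, accum_context: bool = False):
        # Assumed scheme: offset the seed so each worker gets its own stream.
        self.rng = random.Random(seed + worker_number)
        self.deterministic = deterministic
        self.accum_context = accum_context
        self.context = []

    def act(self, mean: float, std: float) -> float:
        # deterministic=True uses the distribution's mean instead of sampling.
        return mean if self.deterministic else self.rng.gauss(mean, std)

    def record(self, obs, action, reward, next_obs) -> None:
        # The context only grows when accum_context is enabled.
        if self.accum_context:
            self.context.append((obs, action, reward, next_obs))


worker = ToyContextWorker(seed=1, worker_number=0, accum_context=True)
worker.record(0.0, worker.act(0.0, 1.0), 1.0, 0.5)
```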
-
env
The worker’s environment.
- Type
Environment or None
-
start_episode(self)
Begin a new episode.
-
step_episode(self)
Take a single time-step in the current episode.
- Returns
True iff the episode is done, either due to the environment indicating termination or due to reaching max_episode_length.
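The termination rule for step_episode can be written as a one-line predicate. In this sketch (the helper name is hypothetical), passing math.inf as max_episode_length means an episode ends only when the environment signals termination.

```python
import math


def episode_done(env_terminal: bool, steps_taken: int,
                 max_episode_length: float) -> bool:
    """True iff the env signalled termination or the step limit was reached."""
    return env_terminal or steps_taken >= max_episode_length


# With an infinite limit, only the environment can end the episode.
never_truncated = episode_done(False, 10**6, math.inf)
```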
- Return type
-
rollout(self)
Sample a single episode of the agent in the environment.
- Returns
The collected episode.
- Return type
-
worker_init(self)
Initialize a worker.
-
update_agent(self, agent_update)
Update an agent, assuming it implements Policy.
- Parameters
agent_update (np.ndarray or dict or Policy) – If a tuple, dict, or np.ndarray, these should be parameters to agent, which should have been generated by calling Policy.get_param_values. Alternatively, a policy itself. Note that other implementations of Worker may take different types for this parameter.
-
update_env(self, env_update)
Use any non-None env_update as a new environment.
A simple env update function. If env_update is not None, it should be the complete new environment.
This allows changing environments by passing the new environment as env_update into obtain_samples.
- Parameters
env_update (Environment or EnvUpdate or None) – The environment to replace the existing env with. Note that other implementations of Worker may take different types for this parameter.
- Raises
TypeError – If env_update is not one of the documented types.
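The dispatch described for update_env can be sketched with toy stand-ins. ToyEnv, ToyEnvUpdate, ToySetTask, and apply_env_update below are hypothetical; the sketch assumes an EnvUpdate is a callable that receives the old environment and returns the configured one.

```python
class ToyEnv:
    """Hypothetical stand-in for an Environment."""

    def __init__(self, task=None):
        self.task = task


class ToyEnvUpdate:
    """Hypothetical stand-in for an EnvUpdate: a callable that configures an env."""

    def __call__(self, old_env):
        raise NotImplementedError


class ToySetTask(ToyEnvUpdate):
    def __init__(self, task):
        self.task = task

    def __call__(self, old_env):
        # Reconfigure the existing environment in place with a new task.
        old_env.task = self.task
        return old_env


def apply_env_update(old_env, env_update):
    if env_update is None:
        return old_env  # keep the current environment
    if isinstance(env_update, ToyEnvUpdate):
        return env_update(old_env)  # let the update configure the old env
    if isinstance(env_update, ToyEnv):
        return env_update  # a complete new environment replaces the old one
    raise TypeError(f'Unknown environment update type: {type(env_update).__name__}')
```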
-
collect_episode(self)
Collect the current episode, clearing the internal buffer.
- Returns
A batch of the episodes completed since the last call to collect_episode().
- Return type
-
shutdown(self)
Close the worker’s environment.