Source code for garage.torch.algos.pearl

"""PEARL and PEARLWorker in Pytorch.

Code is adapted from https://github.com/katerakelly/oyster.
"""

import copy

import akro
from dowel import logger
import numpy as np
import torch

from garage import InOutSpec, TimeStep
from garage.envs import EnvSpec
from garage.experiment import MetaEvaluator
from garage.np.algos import MetaRLAlgorithm
from garage.replay_buffer import PathBuffer
from garage.sampler import DefaultWorker
from garage.torch import global_device
from garage.torch.embeddings import MLPEncoder
from garage.torch.policies import ContextConditionedPolicy


class PEARL(MetaRLAlgorithm):
    r"""A PEARL model based on https://arxiv.org/abs/1903.08254.

    PEARL, which stands for Probabilistic Embeddings for Actor-Critic
    Reinforcement Learning, is an off-policy meta-RL algorithm. It is built
    on top of SAC using two Q-functions and a value function with the
    addition of an inference network that estimates the posterior
    :math:`q(z \| c)`. The policy is conditioned on the latent variable Z in
    order to adapt its behavior to specific tasks.

    Args:
        env (list[GarageEnv]): Batch of sampled environment updates
            (EnvUpdate), which, when invoked on environments, will configure
            them with new tasks.
        policy_class (garage.torch.policies.Policy): Context-conditioned
            policy class.
        encoder_class (garage.torch.embeddings.ContextEncoder): Encoder class
            for the encoder in context-conditioned policy.
        inner_policy (garage.torch.policies.Policy): Policy.
        qf (torch.nn.Module): Q-function.
        vf (torch.nn.Module): Value function.
        num_train_tasks (int): Number of tasks for training.
        num_test_tasks (int): Number of tasks for testing.
        latent_dim (int): Size of latent context vector.
        encoder_hidden_sizes (list[int]): Output dimension of dense layer(s)
            of the context encoder.
        test_env_sampler (garage.experiment.SetTaskSampler): Sampler for test
            tasks.
        policy_lr (float): Policy learning rate.
        qf_lr (float): Q-function learning rate.
        vf_lr (float): Value function learning rate.
        context_lr (float): Inference network learning rate.
        policy_mean_reg_coeff (float): Policy mean regularization weight.
        policy_std_reg_coeff (float): Policy std regularization weight.
        policy_pre_activation_coeff (float): Policy pre-activation weight.
        soft_target_tau (float): Interpolation parameter for doing the soft
            target update.
        kl_lambda (float): KL lambda value.
        optimizer_class (callable): Type of optimizer for training networks.
        use_information_bottleneck (bool): False means latent context is
            deterministic.
        use_next_obs_in_context (bool): Whether or not to use next
            observation in distinguishing between tasks.
        meta_batch_size (int): Meta batch size.
        num_steps_per_epoch (int): Number of iterations per epoch.
        num_initial_steps (int): Number of transitions obtained per task
            before training.
        num_tasks_sample (int): Number of random tasks to obtain data for
            each iteration.
        num_steps_prior (int): Number of transitions to obtain per task with
            z ~ prior.
        num_steps_posterior (int): Number of transitions to obtain per task
            with z ~ posterior.
        num_extra_rl_steps_posterior (int): Number of additional transitions
            to obtain per task with z ~ posterior that are only used to train
            the policy and NOT the encoder.
        batch_size (int): Number of transitions in RL batch.
        embedding_batch_size (int): Number of transitions in context batch.
        embedding_mini_batch_size (int): Number of transitions in mini
            context batch; should be same as embedding_batch_size for
            non-recurrent encoder.
        max_path_length (int): Maximum path length.
        discount (float): RL discount factor.
        replay_buffer_size (int): Maximum samples in replay buffer.
        reward_scale (int): Reward scale.
        update_post_train (int): How often to resample context when obtaining
            data during training (in trajectories).

    """

    # pylint: disable=too-many-statements
    def __init__(self,
                 env,
                 inner_policy,
                 qf,
                 vf,
                 num_train_tasks,
                 num_test_tasks,
                 latent_dim,
                 encoder_hidden_sizes,
                 test_env_sampler,
                 policy_class=ContextConditionedPolicy,
                 encoder_class=MLPEncoder,
                 policy_lr=3E-4,
                 qf_lr=3E-4,
                 vf_lr=3E-4,
                 context_lr=3E-4,
                 policy_mean_reg_coeff=1E-3,
                 policy_std_reg_coeff=1E-3,
                 policy_pre_activation_coeff=0.,
                 soft_target_tau=0.005,
                 kl_lambda=.1,
                 optimizer_class=torch.optim.Adam,
                 use_information_bottleneck=True,
                 use_next_obs_in_context=False,
                 meta_batch_size=64,
                 num_steps_per_epoch=1000,
                 num_initial_steps=100,
                 num_tasks_sample=100,
                 num_steps_prior=100,
                 num_steps_posterior=0,
                 num_extra_rl_steps_posterior=100,
                 batch_size=1024,
                 embedding_batch_size=1024,
                 embedding_mini_batch_size=1024,
                 max_path_length=1000,
                 discount=0.99,
                 replay_buffer_size=1000000,
                 reward_scale=1,
                 update_post_train=1):

        self._env = env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim

        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_information_bottleneck = use_information_bottleneck
        self._use_next_obs_in_context = use_next_obs_in_context

        self._meta_batch_size = meta_batch_size
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._reward_scale = reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._is_resuming = False

        worker_args = dict(deterministic=True, accum_context=True)
        self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                        max_path_length=max_path_length,
                                        worker_class=PEARLWorker,
                                        worker_args=worker_args,
                                        n_test_tasks=num_test_tasks)

        encoder_spec = self.get_env_spec(env[0](), latent_dim, 'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        context_encoder = encoder_class(input_dim=encoder_in_dim,
                                        output_dim=encoder_out_dim,
                                        hidden_sizes=encoder_hidden_sizes)

        self._policy = policy_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            policy=inner_policy,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        # buffer for training RL update
        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._policy_optimizer = optimizer_class(
            self._policy.networks[1].parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._policy.networks[0].parameters(),
            lr=context_lr,
        )

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: The state to be pickled for the instance.

        """
        data = self.__dict__.copy()
        del data['_replay_buffers']
        del data['_context_replay_buffers']
        return data

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): Unpickled state.

        """
        self.__dict__.update(state)
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }

        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._is_resuming = True

    def train(self, runner):
        """Obtain samples, train, and evaluate for each epoch.

        Args:
            runner (LocalRunner): LocalRunner is passed to give algorithm
                the access to runner.step_epochs(), which provides services
                such as snapshotting and sampler control.

        """
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            # obtain initial set of samples from all train tasks
            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_samples(runner, epoch,
                                         self._num_initial_steps, np.inf)
                    self._is_resuming = False

            # obtain samples from random tasks
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._context_replay_buffers[idx].clear()
                # obtain samples with z ~ prior
                if self._num_steps_prior > 0:
                    self._obtain_samples(runner, epoch, self._num_steps_prior,
                                         np.inf)
                # obtain samples with z ~ posterior
                if self._num_steps_posterior > 0:
                    self._obtain_samples(runner, epoch,
                                         self._num_steps_posterior,
                                         self._update_post_train)
                # obtain extra samples for RL training but not the encoder
                if self._num_extra_rl_steps_posterior > 0:
                    self._obtain_samples(runner,
                                         epoch,
                                         self._num_extra_rl_steps_posterior,
                                         self._update_post_train,
                                         add_to_enc_buffer=False)

            logger.log('Training...')
            # sample train tasks and optimize networks
            self._train_once()
            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._policy.reset_belief()
            self._evaluator.evaluate(self)

    def _train_once(self):
        """Perform one iteration of training."""
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            self._optimize_policy(indices)

    def _optimize_policy(self, indices):
        """Perform algorithm optimization.

        Args:
            indices (list): Tasks used for training.

        """
        num_tasks = len(indices)
        context = self._sample_context(indices)
        # clear context and reset belief of policy
        self._policy.reset_belief(num_tasks=num_tasks)

        # data shape is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self._sample_data(indices)
        policy_outputs, task_z = self._policy(obs, context)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

        # flatten out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # optimize qf and encoder networks
        q1_pred = self._qf1(torch.cat([obs, actions], dim=1), task_z)
        q2_pred = self._qf2(torch.cat([obs, actions], dim=1), task_z)
        v_pred = self._vf(obs, task_z.detach())

        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._policy.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()

        rewards_flat = rewards.view(self._batch_size * num_tasks, -1)
        rewards_flat = rewards_flat * self._reward_scale
        terms_flat = terms.view(self._batch_size * num_tasks, -1)
        q_target = rewards_flat + (
            1. - terms_flat) * self._discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()

        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([obs, new_actions], dim=1), task_z.detach())
        q2 = self._qf2(torch.cat([obs, new_actions], dim=1), task_z.detach())
        min_q = torch.min(q1, q2)

        # optimize vf
        v_target = min_q - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # optimize policy
        log_policy_target = min_q
        policy_loss = (log_pi - log_policy_target).mean()

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self._policy_pre_activation_coeff * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = (mean_reg_loss + std_reg_loss +
                           pre_activation_reg_loss)
        policy_loss = policy_loss + policy_reg_loss

        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

    def _obtain_samples(self,
                        runner,
                        itr,
                        num_samples,
                        update_posterior_rate,
                        add_to_enc_buffer=True):
        """Obtain samples.

        Args:
            runner (LocalRunner): LocalRunner.
            itr (int): Index of iteration (epoch).
            num_samples (int): Number of samples to obtain.
            update_posterior_rate (int): How often (in trajectories) to infer
                the posterior of the policy.
            add_to_enc_buffer (bool): Whether or not to add samples to the
                encoder buffer.

        """
        self._policy.reset_belief()
        total_samples = 0

        if update_posterior_rate != np.inf:
            num_samples_per_batch = (update_posterior_rate *
                                     self.max_path_length)
        else:
            num_samples_per_batch = num_samples

        while total_samples < num_samples:
            paths = runner.obtain_samples(itr, num_samples_per_batch,
                                          self._policy,
                                          self._env[self._task_idx])
            total_samples += sum([len(path['rewards']) for path in paths])

            for path in paths:
                p = {
                    'observations': path['observations'],
                    'actions': path['actions'],
                    'rewards': path['rewards'].reshape(-1, 1),
                    'next_observations': path['next_observations'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._replay_buffers[self._task_idx].add_path(p)

                if add_to_enc_buffer:
                    self._context_replay_buffers[self._task_idx].add_path(p)

            if update_posterior_rate != np.inf:
                context = self._sample_context(self._task_idx)
                self._policy.infer_posterior(context)

    def _sample_data(self, indices):
        """Sample batch of training data from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Observations, with shape :math:`(X, N, O^*)` where
                X is the number of tasks. N is batch size.
            torch.Tensor: Actions, with shape :math:`(X, N, A^*)`.
            torch.Tensor: Rewards, with shape :math:`(X, N, 1)`.
            torch.Tensor: Next observations, with shape :math:`(X, N, O^*)`.
            torch.Tensor: Dones, with shape :math:`(X, N, 1)`.

        """
        # transitions sampled randomly from replay buffer
        initialized = False
        for idx in indices:
            batch = self._replay_buffers[idx].sample_transitions(
                self._batch_size)
            if not initialized:
                o = batch['observations'][np.newaxis]
                a = batch['actions'][np.newaxis]
                r = batch['rewards'][np.newaxis]
                no = batch['next_observations'][np.newaxis]
                d = batch['dones'][np.newaxis]
                initialized = True
            else:
                o = np.vstack((o, batch['observations'][np.newaxis]))
                a = np.vstack((a, batch['actions'][np.newaxis]))
                r = np.vstack((r, batch['rewards'][np.newaxis]))
                no = np.vstack((no, batch['next_observations'][np.newaxis]))
                d = np.vstack((d, batch['dones'][np.newaxis]))

        o = torch.as_tensor(o, device=global_device()).float()
        a = torch.as_tensor(a, device=global_device()).float()
        r = torch.as_tensor(r, device=global_device()).float()
        no = torch.as_tensor(no, device=global_device()).float()
        d = torch.as_tensor(d, device=global_device()).float()

        return o, a, r, no, d

    def _sample_context(self, indices):
        """Sample batch of context from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is
                the number of tasks. N is batch size. C is the combined size
                of observation, action, reward, and next observation if next
                observation is used in context. Otherwise, C is the combined
                size of observation, action, and reward.

        """
        # make method work given a single task index
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            batch = self._context_replay_buffers[idx].sample_transitions(
                self._embedding_batch_size)
            o = batch['observations']
            a = batch['actions']
            r = batch['rewards']
            context = np.hstack((np.hstack((o, a)), r))
            if self._use_next_obs_in_context:
                context = np.hstack((context, batch['next_observations']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack((final_context,
                                           context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context

    def _update_target_network(self):
        """Update parameters in the target vf network."""
        for target_param, param in zip(self.target_vf.parameters(),
                                       self._vf.parameters()):
            target_param.data.copy_(target_param.data *
                                    (1.0 - self._soft_target_tau) +
                                    param.data * self._soft_target_tau)

    @property
    def policy(self):
        """Return the policy within the model.

        Returns:
            garage.torch.policies.Policy: Policy within the model.

        """
        return self._policy

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        return self._policy.networks + [self._policy] + [
            self._qf1, self._qf2, self._vf, self.target_vf
        ]

    def get_exploration_policy(self):
        """Return a policy used before adaptation to a specific task.

        Each time it is retrieved, this policy should only be evaluated in
        one task.

        Returns:
            garage.Policy: The policy used to obtain samples that are later
                used for meta-RL adaptation.

        """
        return self._policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        """Produce a policy adapted for a task.

        Args:
            exploration_policy (garage.Policy): A policy which was returned
                from get_exploration_policy(), and which generated
                exploration_trajectories by interacting with an environment.
                The caller may not use this object after passing it into this
                method.
            exploration_trajectories (garage.TrajectoryBatch): Trajectories
                to adapt to, generated by exploration_policy exploring the
                environment.

        Returns:
            garage.Policy: A policy adapted to the task represented by the
                exploration_trajectories.

        """
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.observations
        a = exploration_trajectories.actions
        r = exploration_trajectories.rewards.reshape(total_steps, 1)
        ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1)
        context = torch.as_tensor(ctxt, device=global_device()).float()
        self._policy.infer_posterior(context)

        return self._policy

    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        device = device or global_device()
        for net in self.networks:
            net.to(device)

    @classmethod
    def augment_env_spec(cls, env_spec, latent_dim):
        """Augment environment by a size of latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs to be
                augmented.
            latent_dim (int): Latent dimension.

        Returns:
            garage.envs.EnvSpec: Augmented environment specs.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        aug_obs = akro.Box(low=-1,
                           high=1,
                           shape=(obs_dim + latent_dim, ),
                           dtype=np.float32)
        aug_act = akro.Box(low=-1,
                           high=1,
                           shape=(action_dim, ),
                           dtype=np.float32)
        return EnvSpec(aug_obs, aug_act)

    @classmethod
    def get_env_spec(cls, env_spec, latent_dim, module):
        """Get environment specs of encoder with latent dimension.

        Args:
            env_spec (garage.envs.EnvSpec): Environment specs.
            latent_dim (int): Latent dimension.
            module (str): Module to get environment specs for.

        Returns:
            garage.envs.InOutSpec: Module environment specs with latent
                dimension.

        """
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        action_dim = int(np.prod(env_spec.action_space.shape))
        if module == 'encoder':
            in_dim = obs_dim + action_dim + 1
            out_dim = latent_dim * 2
        elif module == 'vf':
            in_dim = obs_dim
            out_dim = latent_dim
        in_space = akro.Box(low=-1,
                            high=1,
                            shape=(in_dim, ),
                            dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)
        if module == 'encoder':
            spec = InOutSpec(in_space, out_space)
        elif module == 'vf':
            spec = EnvSpec(in_space, out_space)

        return spec

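The pieces above fit together roughly as follows. This is a minimal construction sketch, not part of the module: it assumes garage's GarageEnv, normalize, SetTaskSampler, TanhGaussianMLPPolicy, ContinuousMLPQFunction, and HalfCheetahVelEnv are importable under these names and uses illustrative network sizes; only the PEARL arguments themselves follow the class docstring above.

# Construction sketch (assumed imports and sizes; not part of this module).
from garage.envs import GarageEnv, normalize
from garage.envs.mujoco import HalfCheetahVelEnv
from garage.experiment import SetTaskSampler
from garage.torch.policies import TanhGaussianMLPPolicy
from garage.torch.q_functions import ContinuousMLPQFunction

latent_size = 5
num_train_tasks, num_test_tasks = 100, 30

# One environment constructor (EnvUpdate) per task; PEARL indexes this list.
env_sampler = SetTaskSampler(
    lambda: GarageEnv(normalize(HalfCheetahVelEnv())))
env = env_sampler.sample(num_train_tasks)
test_env_sampler = SetTaskSampler(
    lambda: GarageEnv(normalize(HalfCheetahVelEnv())))

# The inner policy and Q-functions see observations augmented with z,
# while the value function uses the spec returned by get_env_spec.
augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
qf = ContinuousMLPQFunction(env_spec=augmented_env,
                            hidden_sizes=[300, 300, 300])
vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
vf = ContinuousMLPQFunction(env_spec=vf_env, hidden_sizes=[300, 300, 300])
inner_policy = TanhGaussianMLPPolicy(env_spec=augmented_env,
                                     hidden_sizes=[300, 300, 300])

pearl = PEARL(env=env,
              inner_policy=inner_policy,
              qf=qf,
              vf=vf,
              num_train_tasks=num_train_tasks,
              num_test_tasks=num_test_tasks,
              latent_dim=latent_size,
              encoder_hidden_sizes=[200, 200, 200],
              test_env_sampler=test_env_sampler)
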
class PEARLWorker(DefaultWorker):
    """A worker class used in sampling for PEARL.

    It stores context and resamples the belief of the policy at every step.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_path_length (int or float): The maximum length paths which will
            be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for each
            worker.
        deterministic (bool): If true, use the mean action returned by the
            stochastic policy instead of sampling from the returned action
            distribution.
        accum_context (bool): If true, update context of the agent.

    Attributes:
        agent (Policy or None): The worker's agent.
        env (gym.Env or None): The worker's environment.

    """

    def __init__(self,
                 *,
                 seed,
                 max_path_length,
                 worker_number,
                 deterministic=False,
                 accum_context=False):
        self._deterministic = deterministic
        self._accum_context = accum_context
        super().__init__(seed=seed,
                         max_path_length=max_path_length,
                         worker_number=worker_number)

    def start_rollout(self):
        """Begin a new rollout."""
        self._path_length = 0
        self._prev_obs = self.env.reset()

    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Returns:
            bool: True iff the path is done, either due to the environment
                indicating termination or due to reaching `max_path_length`.

        """
        if self._path_length < self._max_path_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            if self._deterministic:
                a = agent_info['mean']
            next_o, r, d, env_info = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._rewards.append(r)
            self._actions.append(a)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            for k, v in env_info.items():
                self._env_infos[k].append(v)
            self._path_length += 1
            self._terminals.append(d)
            if self._accum_context:
                s = TimeStep(env_spec=self.env,
                             observation=self._prev_obs,
                             next_observation=next_o,
                             action=a,
                             reward=float(r),
                             terminal=d,
                             env_info=env_info,
                             agent_info=agent_info)
                self.agent.update_context(s)
            if not d:
                self._prev_obs = next_o
                return False
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True

    def rollout(self):
        """Sample a single rollout of the agent in the environment.

        Returns:
            garage.TrajectoryBatch: The collected trajectory.

        """
        self.agent.sample_from_belief()
        self.start_rollout()
        while not self.step_rollout():
            pass
        self._agent_infos['context'] = [self.agent.z.detach().cpu().numpy()
                                        ] * self._max_path_length
        return self.collect_rollout()
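
For reference, the :math:`(X, N, C)` context layout that _sample_context assembles, and that get_env_spec sizes the encoder input for (C = obs_dim + action_dim + 1 when next observations are excluded), can be reproduced in isolation. The following is a standalone illustration with made-up dimensions, independent of garage:

# Standalone sketch of the (task, batch, C) context tensor built by
# _sample_context: one row per transition, columns are observation,
# action, and reward concatenated, then stacked across tasks.
import numpy as np
import torch

num_tasks, batch, obs_dim, act_dim = 3, 4, 6, 2
per_task = []
for _ in range(num_tasks):
    o = np.random.randn(batch, obs_dim)
    a = np.random.randn(batch, act_dim)
    r = np.random.randn(batch, 1)
    per_task.append(np.hstack((o, a, r))[np.newaxis])
context = torch.as_tensor(np.vstack(per_task)).float()
print(context.shape)  # torch.Size([3, 4, 9]) -> (X, N, C), C = 6 + 2 + 1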