Source code for garage.torch.algos.sac

"""This modules creates a sac model in PyTorch."""
from collections import deque
import copy

from dowel import tabular
import numpy as np
import torch
import torch.nn.functional as F

from garage import log_performance
from garage.np import obtain_evaluation_samples
from garage.np.algos import RLAlgorithm
from garage.sampler import OffPolicyVectorizedSampler
from garage.torch import dict_np_to_torch, global_device


class SAC(RLAlgorithm):
    """A SAC Model in Torch.

    Based on Soft Actor-Critic and Applications:
        https://arxiv.org/abs/1812.05905

    Soft Actor-Critic (SAC) is an algorithm which optimizes a stochastic
    policy in an off-policy way, forming a bridge between stochastic policy
    optimization and DDPG-style approaches.
    A central feature of SAC is entropy regularization. The policy is trained
    to maximize a trade-off between expected return and entropy, a measure of
    randomness in the policy. This has a close connection to the
    exploration-exploitation trade-off: increasing entropy results in more
    exploration, which can accelerate learning later on. It can also prevent
    the policy from prematurely converging to a bad local optimum.

    Args:
        policy (garage.torch.policy.Policy): Policy/Actor/Agent that is being
            optimized by SAC.
        qf1 (garage.torch.q_function.ContinuousMLPQFunction): QFunction/Critic
            used for actor/policy optimization. See Soft Actor-Critic and
            Applications.
        qf2 (garage.torch.q_function.ContinuousMLPQFunction): QFunction/Critic
            used for actor/policy optimization. See Soft Actor-Critic and
            Applications.
        replay_buffer (garage.replay_buffer.ReplayBuffer): Stores transitions
            that are previously collected by the sampler.
        env_spec (garage.envs.env_spec.EnvSpec): The env_spec attribute of the
            environment that the agent is being trained in. Usually accessible
            by calling env.spec.
        max_path_length (int): Max path length of the environment.
        max_eval_path_length (int or None): Maximum length of paths used for
            off-policy evaluation. If None, defaults to `max_path_length`.
        gradient_steps_per_itr (int): Number of optimization steps that should
            occur before the training step is over and a new batch of
            transitions is collected by the sampler.
        fixed_alpha (float): The entropy/temperature to be used if temperature
            is not supposed to be learned.
        target_entropy (float): target entropy to be used during
            entropy/temperature optimization. If None, the default heuristic
            from Soft Actor-Critic Algorithms and Applications is used.
        initial_log_entropy (float): initial entropy/temperature coefficient
            to be used if a fixed_alpha is not being used (fixed_alpha=None),
            and the entropy/temperature coefficient is being learned.
        discount (float): Discount factor to be used during sampling and
            critic/q_function optimization.
        buffer_batch_size (int): The number of transitions sampled from the
            replay buffer that are used during a single optimization step.
        min_buffer_size (int): The minimum number of transitions that need to
            be in the replay buffer before training can begin.
        target_update_tau (float): coefficient that controls the rate at which
            the target q_functions update over optimization iterations.
        policy_lr (float): learning rate for policy optimizers.
        qf_lr (float): learning rate for q_function optimizers.
        reward_scale (float): reward scale. Changing this hyperparameter
            changes the effect that the reward from a transition will have
            during optimization.
        optimizer (torch.optim.Optimizer): optimizer to be used for
            policy/actor, q_functions/critics, and temperature/entropy
            optimizations.
        steps_per_epoch (int): Number of train_once calls per epoch.
        num_evaluation_trajectories (int): The number of evaluation
            trajectories used for computing eval stats at the end of every
            epoch.
        eval_env (garage.envs.GarageEnv): environment used for collecting
            evaluation trajectories. If None, a copy of the train env is used.

    """

    def __init__(
            self,
            env_spec,
            policy,
            qf1,
            qf2,
            replay_buffer,
            *,  # Everything after this is numbers.
            max_path_length,
            max_eval_path_length=None,
            gradient_steps_per_itr,
            fixed_alpha=None,
            target_entropy=None,
            initial_log_entropy=0.,
            discount=0.99,
            buffer_batch_size=64,
            min_buffer_size=int(1e4),
            target_update_tau=5e-3,
            policy_lr=3e-4,
            qf_lr=3e-4,
            reward_scale=1.0,
            optimizer=torch.optim.Adam,
            steps_per_epoch=1,
            num_evaluation_trajectories=10,
            eval_env=None):

        self._qf1 = qf1
        self._qf2 = qf2
        self.replay_buffer = replay_buffer
        self._tau = target_update_tau
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._initial_log_entropy = initial_log_entropy
        self._gradient_steps = gradient_steps_per_itr
        self._optimizer = optimizer
        self._num_evaluation_trajectories = num_evaluation_trajectories
        self._eval_env = eval_env

        self._min_buffer_size = min_buffer_size
        self._steps_per_epoch = steps_per_epoch
        self._buffer_batch_size = buffer_batch_size
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_path_length = max_path_length
        self._max_eval_path_length = (max_eval_path_length
                                      or max_path_length)

        # used by OffPolicyVectorizedSampler
        self.policy = policy
        self.env_spec = env_spec
        self.exploration_policy = None
        self.sampler_cls = OffPolicyVectorizedSampler

        # use 2 target q networks
        self._target_qf1 = copy.deepcopy(self._qf1)
        self._target_qf2 = copy.deepcopy(self._qf2)
        self._policy_optimizer = self._optimizer(self.policy.parameters(),
                                                 lr=self._policy_lr)
        self._qf1_optimizer = self._optimizer(self._qf1.parameters(),
                                              lr=self._qf_lr)
        self._qf2_optimizer = self._optimizer(self._qf2.parameters(),
                                              lr=self._qf_lr)
        # automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            if target_entropy:
                self._target_entropy = target_entropy
            else:
                self._target_entropy = -np.prod(
                    self.env_spec.action_space.shape).item()
            self._log_alpha = torch.Tensor([self._initial_log_entropy
                                            ]).requires_grad_()
            self._alpha_optimizer = optimizer([self._log_alpha],
                                              lr=self._policy_lr)
        else:
            self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
        self.episode_rewards = deque(maxlen=30)
    def train(self, runner):
        """Obtain samplers and start actual training for each epoch.

        Args:
            runner (LocalRunner): LocalRunner is passed to give algorithm
                the access to runner.step_epochs(), which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._eval_env:
            self._eval_env = runner.get_env_copy()
        last_return = None
        for _ in runner.step_epochs():
            for _ in range(self._steps_per_epoch):
                if not (self.replay_buffer.n_transitions_stored >=
                        self._min_buffer_size):
                    batch_size = int(self._min_buffer_size)
                else:
                    batch_size = None
                runner.step_path = runner.obtain_samples(
                    runner.step_itr, batch_size)
                path_returns = []
                for path in runner.step_path:
                    self.replay_buffer.add_path(
                        dict(observation=path['observations'],
                             action=path['actions'],
                             reward=path['rewards'].reshape(-1, 1),
                             next_observation=path['next_observations'],
                             terminal=path['dones'].reshape(-1, 1)))
                    path_returns.append(sum(path['rewards']))
                assert len(path_returns) == len(runner.step_path)
                self.episode_rewards.append(np.mean(path_returns))
                for _ in range(self._gradient_steps):
                    policy_loss, qf1_loss, qf2_loss = self.train_once()
            last_return = self._evaluate_policy(runner.step_itr)
            self._log_statistics(policy_loss, qf1_loss, qf2_loss)
            tabular.record('TotalEnvSteps', runner.total_env_steps)
            runner.step_itr += 1

        return np.mean(last_return)
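    # train above warms up the replay buffer: while fewer than
    # min_buffer_size transitions are stored it requests min_buffer_size
    # environment steps in a single obtain_samples call; afterwards it passes
    # batch_size=None so the sampler uses the batch size configured on the
    # runner.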
    def train_once(self, itr=None, paths=None):
        """Complete 1 training iteration of SAC.

        Args:
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        del itr
        del paths
        if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
            samples = self.replay_buffer.sample_transitions(
                self._buffer_batch_size)
            samples = dict_np_to_torch(samples)
            policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
            self._update_targets()

        return policy_loss, qf1_loss, qf2_loss
    def _get_log_alpha(self, samples_data):
        """Return the value of log_alpha.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        This function exists in case there are versions of SAC that need
        access to a modified log_alpha, such as multi-task SAC.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: log_alpha

        """
        del samples_data
        log_alpha = self._log_alpha
        return log_alpha

    def _temperature_objective(self, log_pi, samples_data):
        """Compute the temperature/alpha coefficient loss.

        Args:
            log_pi (torch.Tensor): log probability of actions that are sampled
                from the replay buffer. Shape is (1, buffer_batch_size).
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: the temperature/alpha coefficient loss.

        """
        alpha_loss = 0
        if self._use_automatic_entropy_tuning:
            alpha_loss = (-(self._get_log_alpha(samples_data)) *
                          (log_pi.detach() + self._target_entropy)).mean()
        return alpha_loss

    def _actor_objective(self, samples_data, new_actions,
                         log_pi_new_actions):
        """Compute the Policy/Actor loss.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
            new_actions (torch.Tensor): Actions resampled from the policy
                based on the observations, obs, which were sampled from the
                replay buffer. Shape is (action_dim, buffer_batch_size).
            log_pi_new_actions (torch.Tensor): Log probability of the new
                actions on the TanhNormal distributions that they were
                sampled from. Shape is (1, buffer_batch_size).

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from the Policy/Actor.

        """
        obs = samples_data['observation']
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()
        min_q_new_actions = torch.min(self._qf1(obs, new_actions),
                                      self._qf2(obs, new_actions))
        policy_objective = ((alpha * log_pi_new_actions) -
                            min_q_new_actions.flatten()).mean()
        return policy_objective

    def _critic_objective(self, samples_data):
        """Compute the Q-function/critic loss.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data['observation']
        actions = samples_data['action']
        rewards = samples_data['reward'].flatten()
        terminals = samples_data['terminal'].flatten()
        next_obs = samples_data['next_observation']
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()

        q1_pred = self._qf1(obs, actions)
        q2_pred = self._qf2(obs, actions)

        new_next_actions_dist = self.policy(next_obs)[0]
        new_next_actions_pre_tanh, new_next_actions = (
            new_next_actions_dist.rsample_with_pre_tanh_value())
        new_log_pi = new_next_actions_dist.log_prob(
            value=new_next_actions,
            pre_tanh_value=new_next_actions_pre_tanh)

        target_q_values = torch.min(
            self._target_qf1(next_obs, new_next_actions),
            self._target_qf2(
                next_obs,
                new_next_actions)).flatten() - (alpha * new_log_pi)
        with torch.no_grad():
            q_target = rewards * self._reward_scale + (
                1. - terminals) * self._discount * target_q_values
        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

        return qf1_loss, qf2_loss

    def _update_targets(self):
        """Update parameters in the target q-functions."""
        target_qfs = [self._target_qf1, self._target_qf2]
        qfs = [self._qf1, self._qf2]
        for target_qf, qf in zip(target_qfs, qfs):
            for t_param, param in zip(target_qf.parameters(),
                                      qf.parameters()):
                t_param.data.copy_(t_param.data * (1.0 - self._tau) +
                                   param.data * self._tau)
    def optimize_policy(self, samples_data):
        """Optimize the policy, q_functions, and temperature coefficient.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data['observation']
        qf1_loss, qf2_loss = self._critic_objective(samples_data)

        self._qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self._qf1_optimizer.step()

        self._qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self._qf2_optimizer.step()

        action_dists = self.policy(obs)[0]
        new_actions_pre_tanh, new_actions = (
            action_dists.rsample_with_pre_tanh_value())
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh)

        policy_loss = self._actor_objective(samples_data, new_actions,
                                            log_pi_new_actions)
        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(log_pi_new_actions,
                                                     samples_data)
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()

        return policy_loss, qf1_loss, qf2_loss
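    # optimize_policy above updates the networks in sequence: both critics on
    # the soft Bellman target from _critic_objective, then the actor on
    # reparameterized (rsample) actions against min(Q1, Q2), and finally, if
    # alpha is learned, the temperature on the same log-probabilities.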
    def _evaluate_policy(self, epoch):
        """Evaluate the performance of the policy via deterministic rollouts.

        Statistics such as (average) discounted return and success rate are
        recorded.

        Args:
            epoch (int): The current training epoch.

        Returns:
            float: The average return across
                self._num_evaluation_trajectories trajectories.

        """
        eval_trajectories = obtain_evaluation_samples(
            self.policy,
            self._eval_env,
            max_path_length=self._max_eval_path_length,
            num_trajs=self._num_evaluation_trajectories)
        last_return = log_performance(epoch,
                                      eval_trajectories,
                                      discount=self._discount)
        return last_return

    def _log_statistics(self, policy_loss, qf1_loss, qf2_loss):
        """Record training statistics to dowel such as losses and returns.

        Args:
            policy_loss (torch.Tensor): loss from actor/policy network.
            qf1_loss (torch.Tensor): loss from 1st qf/critic network.
            qf2_loss (torch.Tensor): loss from 2nd qf/critic network.

        """
        with torch.no_grad():
            tabular.record('AlphaTemperature/mean',
                           self._log_alpha.exp().mean().item())
        tabular.record('Policy/Loss', policy_loss.item())
        tabular.record('QF/Qf1Loss', float(qf1_loss))
        tabular.record('QF/Qf2Loss', float(qf2_loss))
        tabular.record('ReplayBuffer/buffer_size',
                       self.replay_buffer.n_transitions_stored)
        tabular.record('Average/TrainAverageReturn',
                       np.mean(self.episode_rewards))

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        return [
            self.policy, self._qf1, self._qf2, self._target_qf1,
            self._target_qf2
        ]
    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        if device is None:
            device = global_device()
        for net in self.networks:
            net.to(device)
        if not self._use_automatic_entropy_tuning:
            self._log_alpha = torch.Tensor([self._fixed_alpha
                                            ]).log().to(device)
        else:
            self._log_alpha = torch.Tensor([self._initial_log_entropy
                                            ]).to(device).requires_grad_()
            self._alpha_optimizer = self._optimizer([self._log_alpha],
                                                    lr=self._policy_lr)
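
For context, a typical training script wires this SAC class into a LocalRunner with a TanhGaussianMLPPolicy, two ContinuousMLPQFunction critics, and a PathBuffer replay buffer. The sketch below is a minimal, hedged version of that setup, not part of this module: the constructor arguments, module paths, and hyperparameters (hidden sizes, path length, gradient steps) are assumptions that should be checked against the garage release in use.

    # Hedged usage sketch; class names and paths assumed to match this
    # garage release.
    import gym
    import numpy as np
    import torch
    from torch import nn
    from torch.nn import functional as F

    from garage import wrap_experiment
    from garage.envs import GarageEnv, normalize
    from garage.experiment import deterministic, LocalRunner
    from garage.replay_buffer import PathBuffer
    from garage.torch import set_gpu_mode
    from garage.torch.algos import SAC
    from garage.torch.policies import TanhGaussianMLPPolicy
    from garage.torch.q_functions import ContinuousMLPQFunction


    @wrap_experiment(snapshot_mode='none')
    def sac_half_cheetah(ctxt=None, seed=1):
        """Train SAC on HalfCheetah-v2 (assumed hyperparameters)."""
        deterministic.set_seed(seed)
        runner = LocalRunner(snapshot_config=ctxt)
        env = GarageEnv(normalize(gym.make('HalfCheetah-v2')))

        policy = TanhGaussianMLPPolicy(env_spec=env.spec,
                                       hidden_sizes=[256, 256],
                                       hidden_nonlinearity=nn.ReLU,
                                       output_nonlinearity=None,
                                       min_std=np.exp(-20.),
                                       max_std=np.exp(2.))
        qf1 = ContinuousMLPQFunction(env_spec=env.spec,
                                     hidden_sizes=[256, 256],
                                     hidden_nonlinearity=F.relu)
        qf2 = ContinuousMLPQFunction(env_spec=env.spec,
                                     hidden_sizes=[256, 256],
                                     hidden_nonlinearity=F.relu)
        replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))

        sac = SAC(env_spec=env.spec,
                  policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  replay_buffer=replay_buffer,
                  max_path_length=500,
                  gradient_steps_per_itr=1000,
                  buffer_batch_size=256)

        # Move networks and log_alpha to GPU if one is available.
        set_gpu_mode(torch.cuda.is_available())
        sac.to()
        # SAC sets its own sampler_cls (OffPolicyVectorizedSampler) in
        # __init__, so no sampler_cls is passed here.
        runner.setup(algo=sac, env=env)
        runner.train(n_epochs=1000, batch_size=1000)


    sac_half_cheetah(seed=1)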