"""This modules creates a sac model in PyTorch."""
from collections import deque
import copy
from dowel import tabular
import numpy as np
import torch
import torch.nn.functional as F
from garage import log_performance
from garage.np import obtain_evaluation_samples
from garage.np.algos import RLAlgorithm
from garage.sampler import OffPolicyVectorizedSampler
from garage.torch import dict_np_to_torch, global_device
class SAC(RLAlgorithm):
"""A SAC Model in Torch.
    Based on Soft Actor-Critic Algorithms and Applications:
https://arxiv.org/abs/1812.05905
Soft Actor-Critic (SAC) is an algorithm which optimizes a stochastic
policy in an off-policy way, forming a bridge between stochastic policy
optimization and DDPG-style approaches.
A central feature of SAC is entropy regularization. The policy is trained
to maximize a trade-off between expected return and entropy, a measure of
randomness in the policy. This has a close connection to the
exploration-exploitation trade-off: increasing entropy results in more
exploration, which can accelerate learning later on. It can also prevent
the policy from prematurely converging to a bad local optimum.
Args:
policy (garage.torch.policy.Policy): Policy/Actor/Agent that is being
optimized by SAC.
        qf1 (garage.torch.q_function.ContinuousMLPQFunction): QFunction/Critic
            used for actor/policy optimization. See Soft Actor-Critic
            Algorithms and Applications.
        qf2 (garage.torch.q_function.ContinuousMLPQFunction): QFunction/Critic
            used for actor/policy optimization. See Soft Actor-Critic
            Algorithms and Applications.
replay_buffer (garage.replay_buffer.ReplayBuffer): Stores transitions
that are previously collected by the sampler.
        env_spec (garage.envs.env_spec.EnvSpec): The env_spec attribute of the
            environment that the agent is being trained in. Usually accessible
            by calling env.spec.
max_path_length (int): Max path length of the environment.
max_eval_path_length (int or None): Maximum length of paths used for
off-policy evaluation. If None, defaults to `max_path_length`.
        gradient_steps_per_itr (int): Number of optimization steps that should
occur before the training step is over and a new batch of
transitions is collected by the sampler.
fixed_alpha (float): The entropy/temperature to be used if temperature
is not supposed to be learned.
target_entropy (float): target entropy to be used during
entropy/temperature optimization. If None, the default heuristic
from Soft Actor-Critic Algorithms and Applications is used.
initial_log_entropy (float): initial entropy/temperature coefficient
to be used if a fixed_alpha is not being used (fixed_alpha=None),
and the entropy/temperature coefficient is being learned.
discount (float): Discount factor to be used during sampling and
critic/q_function optimization.
buffer_batch_size (int): The number of transitions sampled from the
replay buffer that are used during a single optimization step.
min_buffer_size (int): The minimum number of transitions that need to
be in the replay buffer before training can begin.
target_update_tau (float): coefficient that controls the rate at which
the target q_functions update over optimization iterations.
policy_lr (float): learning rate for policy optimizers.
qf_lr (float): learning rate for q_function optimizers.
reward_scale (float): reward scale. Changing this hyperparameter
changes the effect that the reward from a transition will have
during optimization.
optimizer (torch.optim.Optimizer): optimizer to be used for
policy/actor, q_functions/critics, and temperature/entropy
optimizations.
steps_per_epoch (int): Number of train_once calls per epoch.
num_evaluation_trajectories (int): The number of evaluation
trajectories used for computing eval stats at the end of every
epoch.
eval_env (garage.envs.GarageEnv): environment used for collecting
evaluation trajectories. If None, a copy of the train env is used.
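
    Example:
        A minimal usage sketch (assuming an already-configured environment
        ``env``, networks ``policy``, ``qf1`` and ``qf2``, a
        ``replay_buffer``, and a ``LocalRunner`` named ``runner``; the
        hyperparameter values below are illustrative only)::

            sac = SAC(env_spec=env.spec,
                      policy=policy,
                      qf1=qf1,
                      qf2=qf2,
                      replay_buffer=replay_buffer,
                      max_path_length=500,
                      gradient_steps_per_itr=1000)
            runner.setup(algo=sac, env=env)
            runner.train(n_epochs=1000, batch_size=1000)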
"""
def __init__(
self,
env_spec,
policy,
qf1,
qf2,
replay_buffer,
            *,  # Everything after this is keyword-only.
max_path_length,
max_eval_path_length=None,
gradient_steps_per_itr,
fixed_alpha=None,
target_entropy=None,
initial_log_entropy=0.,
discount=0.99,
buffer_batch_size=64,
min_buffer_size=int(1e4),
target_update_tau=5e-3,
policy_lr=3e-4,
qf_lr=3e-4,
reward_scale=1.0,
optimizer=torch.optim.Adam,
steps_per_epoch=1,
num_evaluation_trajectories=10,
eval_env=None):
self._qf1 = qf1
self._qf2 = qf2
self.replay_buffer = replay_buffer
self._tau = target_update_tau
self._policy_lr = policy_lr
self._qf_lr = qf_lr
self._initial_log_entropy = initial_log_entropy
self._gradient_steps = gradient_steps_per_itr
self._optimizer = optimizer
self._num_evaluation_trajectories = num_evaluation_trajectories
self._eval_env = eval_env
self._min_buffer_size = min_buffer_size
self._steps_per_epoch = steps_per_epoch
self._buffer_batch_size = buffer_batch_size
self._discount = discount
self._reward_scale = reward_scale
self.max_path_length = max_path_length
self._max_eval_path_length = (max_eval_path_length or max_path_length)
# used by OffPolicyVectorizedSampler
self.policy = policy
self.env_spec = env_spec
self.exploration_policy = None
self.sampler_cls = OffPolicyVectorizedSampler
# use 2 target q networks
self._target_qf1 = copy.deepcopy(self._qf1)
self._target_qf2 = copy.deepcopy(self._qf2)
self._policy_optimizer = self._optimizer(self.policy.parameters(),
lr=self._policy_lr)
self._qf1_optimizer = self._optimizer(self._qf1.parameters(),
lr=self._qf_lr)
self._qf2_optimizer = self._optimizer(self._qf2.parameters(),
lr=self._qf_lr)
# automatic entropy coefficient tuning
self._use_automatic_entropy_tuning = fixed_alpha is None
self._fixed_alpha = fixed_alpha
if self._use_automatic_entropy_tuning:
if target_entropy:
self._target_entropy = target_entropy
else:
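                # Default heuristic from Soft Actor-Critic Algorithms and
                # Applications: target entropy = -dim(action space).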
self._target_entropy = -np.prod(
self.env_spec.action_space.shape).item()
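            # Optimize log(alpha) instead of alpha so the temperature stays
            # positive during gradient descent.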
self._log_alpha = torch.Tensor([self._initial_log_entropy
]).requires_grad_()
self._alpha_optimizer = optimizer([self._log_alpha],
lr=self._policy_lr)
else:
self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
self.episode_rewards = deque(maxlen=30)
    def train(self, runner):
"""Obtain samplers and start actual training for each epoch.
Args:
runner (LocalRunner): LocalRunner is passed to give algorithm
the access to runner.step_epochs(), which provides services
such as snapshotting and sampler control.
Returns:
float: The average return in last epoch cycle.
"""
if not self._eval_env:
self._eval_env = runner.get_env_copy()
last_return = None
for _ in runner.step_epochs():
for _ in range(self._steps_per_epoch):
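                # Fill the buffer up to min_buffer_size before the first
                # optimization step; after that, batch_size=None lets the
                # sampler use its default batch size.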
if not (self.replay_buffer.n_transitions_stored >=
self._min_buffer_size):
batch_size = int(self._min_buffer_size)
else:
batch_size = None
runner.step_path = runner.obtain_samples(
runner.step_itr, batch_size)
path_returns = []
for path in runner.step_path:
self.replay_buffer.add_path(
dict(observation=path['observations'],
action=path['actions'],
reward=path['rewards'].reshape(-1, 1),
next_observation=path['next_observations'],
terminal=path['dones'].reshape(-1, 1)))
path_returns.append(sum(path['rewards']))
assert len(path_returns) == len(runner.step_path)
self.episode_rewards.append(np.mean(path_returns))
for _ in range(self._gradient_steps):
policy_loss, qf1_loss, qf2_loss = self.train_once()
last_return = self._evaluate_policy(runner.step_itr)
self._log_statistics(policy_loss, qf1_loss, qf2_loss)
tabular.record('TotalEnvSteps', runner.total_env_steps)
runner.step_itr += 1
return np.mean(last_return)
    def train_once(self, itr=None, paths=None):
"""Complete 1 training iteration of SAC.
Args:
itr (int): Iteration number. This argument is deprecated.
paths (list[dict]): A list of collected paths.
This argument is deprecated.
Returns:
torch.Tensor: loss from actor/policy network after optimization.
torch.Tensor: loss from 1st q-function after optimization.
torch.Tensor: loss from 2nd q-function after optimization.
"""
del itr
del paths
if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
samples = self.replay_buffer.sample_transitions(
self._buffer_batch_size)
samples = dict_np_to_torch(samples)
policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
self._update_targets()
return policy_loss, qf1_loss, qf2_loss
def _get_log_alpha(self, samples_data):
"""Return the value of log_alpha.
Args:
            samples_data (dict): Transitions (S, A, R, S') that are sampled
                from the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
        This function exists in case there are versions of SAC that need
        access to a modified log_alpha, such as multi-task SAC.
Note:
samples_data's entries should be torch.Tensor's with the following
shapes:
observation: :math:`(N, O^*)`
action: :math:`(N, A^*)`
reward: :math:`(N, 1)`
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`
Returns:
torch.Tensor: log_alpha
"""
del samples_data
log_alpha = self._log_alpha
return log_alpha
def _temperature_objective(self, log_pi, samples_data):
"""Compute the temperature/alpha coefficient loss.
Args:
            log_pi (torch.Tensor): log probability of actions that are
                sampled from the replay buffer. Shape is
                (1, buffer_batch_size).
            samples_data (dict): Transitions (S, A, R, S') that are sampled
                from the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
Note:
samples_data's entries should be torch.Tensor's with the following
shapes:
observation: :math:`(N, O^*)`
action: :math:`(N, A^*)`
reward: :math:`(N, 1)`
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`
Returns:
torch.Tensor: the temperature/alpha coefficient loss.
"""
alpha_loss = 0
if self._use_automatic_entropy_tuning:
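            # Temperature loss from Soft Actor-Critic Algorithms and
            # Applications: J(alpha) = E[-alpha * (log pi(a|s) + H_target)].
            # Its gradient raises alpha when the policy's entropy drops
            # below the target, encouraging more exploration.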
alpha_loss = (-(self._get_log_alpha(samples_data)) *
(log_pi.detach() + self._target_entropy)).mean()
return alpha_loss
def _actor_objective(self, samples_data, new_actions, log_pi_new_actions):
"""Compute the Policy/Actor loss.
Args:
            samples_data (dict): Transitions (S, A, R, S') that are sampled
                from the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
            new_actions (torch.Tensor): Actions resampled from the policy
                based on the observations sampled from the replay buffer.
                Shape is (action_dim, buffer_batch_size).
            log_pi_new_actions (torch.Tensor): Log probability of the new
                actions under the TanhNormal distributions they were sampled
                from. Shape is (1, buffer_batch_size).
Note:
samples_data's entries should be torch.Tensor's with the following
shapes:
observation: :math:`(N, O^*)`
action: :math:`(N, A^*)`
reward: :math:`(N, 1)`
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`
Returns:
torch.Tensor: loss from the Policy/Actor.
"""
obs = samples_data['observation']
with torch.no_grad():
alpha = self._get_log_alpha(samples_data).exp()
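        # Clipped double-Q: take the elementwise minimum of the two critics
        # to reduce overestimation bias in the policy objective.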
min_q_new_actions = torch.min(self._qf1(obs, new_actions),
self._qf2(obs, new_actions))
policy_objective = ((alpha * log_pi_new_actions) -
min_q_new_actions.flatten()).mean()
return policy_objective
def _critic_objective(self, samples_data):
"""Compute the Q-function/critic loss.
Args:
            samples_data (dict): Transitions (S, A, R, S') that are sampled
                from the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
Note:
samples_data's entries should be torch.Tensor's with the following
shapes:
observation: :math:`(N, O^*)`
action: :math:`(N, A^*)`
reward: :math:`(N, 1)`
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`
Returns:
torch.Tensor: loss from 1st q-function after optimization.
torch.Tensor: loss from 2nd q-function after optimization.
"""
obs = samples_data['observation']
actions = samples_data['action']
rewards = samples_data['reward'].flatten()
terminals = samples_data['terminal'].flatten()
next_obs = samples_data['next_observation']
with torch.no_grad():
alpha = self._get_log_alpha(samples_data).exp()
q1_pred = self._qf1(obs, actions)
q2_pred = self._qf2(obs, actions)
new_next_actions_dist = self.policy(next_obs)[0]
new_next_actions_pre_tanh, new_next_actions = (
new_next_actions_dist.rsample_with_pre_tanh_value())
new_log_pi = new_next_actions_dist.log_prob(
value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh)
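        # Soft Bellman backup: subtract the entropy term alpha * log pi(a'|s')
        # from the clipped double-Q value of the next state-action pair.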
target_q_values = torch.min(
self._target_qf1(next_obs, new_next_actions),
self._target_qf2(
next_obs, new_next_actions)).flatten() - (alpha * new_log_pi)
with torch.no_grad():
q_target = rewards * self._reward_scale + (
1. - terminals) * self._discount * target_q_values
qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)
return qf1_loss, qf2_loss
def _update_targets(self):
"""Update parameters in the target q-functions."""
target_qfs = [self._target_qf1, self._target_qf2]
qfs = [self._qf1, self._qf2]
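        # Polyak averaging:
        # theta_target <- (1 - tau) * theta_target + tau * theta,
        # with tau << 1 so the targets move slowly.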
for target_qf, qf in zip(target_qfs, qfs):
for t_param, param in zip(target_qf.parameters(), qf.parameters()):
t_param.data.copy_(t_param.data * (1.0 - self._tau) +
param.data * self._tau)
    def optimize_policy(self, samples_data):
"""Optimize the policy q_functions, and temperature coefficient.
Args:
samples_data (dict): Transitions(S,A,R,S') that are sampled from
the replay buffer. It should have the keys 'observation',
'action', 'reward', 'terminal', and 'next_observations'.
Note:
samples_data's entries should be torch.Tensor's with the following
shapes:
observation: :math:`(N, O^*)`
action: :math:`(N, A^*)`
reward: :math:`(N, 1)`
terminal: :math:`(N, 1)`
next_observation: :math:`(N, O^*)`
Returns:
torch.Tensor: loss from actor/policy network after optimization.
torch.Tensor: loss from 1st q-function after optimization.
torch.Tensor: loss from 2nd q-function after optimization.
"""
obs = samples_data['observation']
qf1_loss, qf2_loss = self._critic_objective(samples_data)
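        # One gradient step per critic, each with its own optimizer.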
self._qf1_optimizer.zero_grad()
qf1_loss.backward()
self._qf1_optimizer.step()
self._qf2_optimizer.zero_grad()
qf2_loss.backward()
self._qf2_optimizer.step()
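        # Use rsample so actions are reparameterized and the actor loss can
        # backpropagate through the sampling step.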
action_dists = self.policy(obs)[0]
new_actions_pre_tanh, new_actions = (
action_dists.rsample_with_pre_tanh_value())
log_pi_new_actions = action_dists.log_prob(
value=new_actions, pre_tanh_value=new_actions_pre_tanh)
policy_loss = self._actor_objective(samples_data, new_actions,
log_pi_new_actions)
self._policy_optimizer.zero_grad()
policy_loss.backward()
self._policy_optimizer.step()
if self._use_automatic_entropy_tuning:
alpha_loss = self._temperature_objective(log_pi_new_actions,
samples_data)
self._alpha_optimizer.zero_grad()
alpha_loss.backward()
self._alpha_optimizer.step()
return policy_loss, qf1_loss, qf2_loss
def _evaluate_policy(self, epoch):
"""Evaluate the performance of the policy via deterministic rollouts.
Statistics such as (average) discounted return and success rate are
recorded.
        Args:
            epoch (int): The current training epoch.
        Returns:
            float: The average return across self._num_evaluation_trajectories
                trajectories.
"""
eval_trajectories = obtain_evaluation_samples(
self.policy,
self._eval_env,
max_path_length=self._max_eval_path_length,
num_trajs=self._num_evaluation_trajectories)
last_return = log_performance(epoch,
eval_trajectories,
discount=self._discount)
return last_return
def _log_statistics(self, policy_loss, qf1_loss, qf2_loss):
"""Record training statistics to dowel such as losses and returns.
Args:
            policy_loss (torch.Tensor): loss from actor/policy network.
            qf1_loss (torch.Tensor): loss from 1st qf/critic network.
            qf2_loss (torch.Tensor): loss from 2nd qf/critic network.
"""
with torch.no_grad():
tabular.record('AlphaTemperature/mean',
self._log_alpha.exp().mean().item())
tabular.record('Policy/Loss', policy_loss.item())
        tabular.record('QF/Qf1Loss', float(qf1_loss))
        tabular.record('QF/Qf2Loss', float(qf2_loss))
tabular.record('ReplayBuffer/buffer_size',
self.replay_buffer.n_transitions_stored)
tabular.record('Average/TrainAverageReturn',
np.mean(self.episode_rewards))
@property
def networks(self):
"""Return all the networks within the model.
Returns:
list: A list of networks.
"""
return [
self.policy, self._qf1, self._qf2, self._target_qf1,
self._target_qf2
]
    def to(self, device=None):
"""Put all the networks within the model on device.
Args:
device (str): ID of GPU or CPU.
"""
if device is None:
device = global_device()
for net in self.networks:
net.to(device)
if not self._use_automatic_entropy_tuning:
self._log_alpha = torch.Tensor([self._fixed_alpha
]).log().to(device)
else:
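            # .to(device) on a fresh tensor creates a new leaf, so the
            # alpha optimizer must be rebuilt to track the moved parameter.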
self._log_alpha = torch.Tensor([self._initial_log_entropy
]).to(device).requires_grad_()
self._alpha_optimizer = self._optimizer([self._log_alpha],
lr=self._policy_lr)