Source code for garage.tf.optimizers.first_order_optimizer

"""First order optimizer."""
import time

import click
from dowel import logger
import tensorflow as tf

from garage import _Default, make_optimizer
from garage.np.optimizers import BatchDataset
from garage.tf.misc import tensor_utils
from garage.tf.optimizers.utils import LazyDict


class FirstOrderOptimizer:
    """First order optimizer.

    Performs (stochastic) gradient descent, possibly using fancier methods
    like ADAM etc.

    Args:
        optimizer (tf.Optimizer): Optimizer to be used.
        learning_rate (dict): Learning rate arguments. Learning rates are
            our main interest when tuning optimizers.
        max_epochs (int): Maximum number of epochs for update.
        tolerance (float): Tolerance for difference in loss during update.
        batch_size (int): Batch size for optimization.
        callback (callable): Function to call during each epoch. Default is
            None.
        verbose (bool): If true, intermediate log messages will be printed.
        name (str): Name scope of the optimizer.

    """

    def __init__(self,
                 optimizer=None,
                 learning_rate=None,
                 max_epochs=1000,
                 tolerance=1e-6,
                 batch_size=32,
                 callback=None,
                 verbose=False,
                 name='FirstOrderOptimizer'):
        self._opt_fun = None
        self._target = None
        self._callback = callback
        if optimizer is None:
            optimizer = tf.compat.v1.train.AdamOptimizer
            learning_rate = learning_rate or dict(learning_rate=_Default(1e-3))
        if not isinstance(learning_rate, dict):
            learning_rate = dict(learning_rate=learning_rate)
        self._tf_optimizer = optimizer
        self._learning_rate = learning_rate
        self._max_epochs = max_epochs
        self._tolerance = tolerance
        self._batch_size = batch_size
        self._verbose = verbose
        self._input_vars = None
        self._train_op = None
        self._name = name
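
    # Usage sketch (illustrative, not part of the module): constructing the
    # optimizer with an explicit learning rate. `learning_rate` may be given
    # as a plain float or as a dict of keyword arguments for the wrapped
    # tf.compat.v1.train optimizer.
    #
    #     optimizer = FirstOrderOptimizer(
    #         optimizer=tf.compat.v1.train.AdamOptimizer,
    #         learning_rate=dict(learning_rate=1e-3),
    #         max_epochs=100,
    #         batch_size=64)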

    def update_opt(self, loss, target, inputs, extra_inputs=None, **kwargs):
        """Construct operation graph for the optimizer.

        Args:
            loss (tf.Tensor): Loss objective to minimize.
            target (object): Target object to optimize. The object should
                implement `get_params()` and `get_param_values()`.
            inputs (list[tf.Tensor]): List of input placeholders.
            extra_inputs (list[tf.Tensor]): List of extra input placeholders.
            kwargs (dict): Extra unused keyword arguments. Some optimizers
                have extra input, e.g. KL constraint.

        """
        del kwargs
        with tf.name_scope(self._name):
            self._target = target
            tf_optimizer = make_optimizer(self._tf_optimizer,
                                          **self._learning_rate)
            self._train_op = tf_optimizer.minimize(
                loss, var_list=target.get_params())

            if extra_inputs is None:
                extra_inputs = list()
            self._input_vars = inputs + extra_inputs
            self._opt_fun = LazyDict(
                f_loss=lambda: tensor_utils.compile_function(
                    inputs + extra_inputs, loss),
            )
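
    # Usage sketch (illustrative, not part of the module): wiring the
    # optimizer to a loss tensor. `policy`, `obs_var` and `returns_var` are
    # hypothetical names; the target only needs to expose
    # get_params()/get_param_values().
    #
    #     optimizer.update_opt(loss=loss_tensor,
    #                          target=policy,
    #                          inputs=[obs_var, returns_var])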

    def loss(self, inputs, extra_inputs=None):
        """The loss.

        Args:
            inputs (list[numpy.ndarray]): List of input values.
            extra_inputs (list[numpy.ndarray]): List of extra input values.

        Returns:
            float: Loss.

        Raises:
            Exception: If loss function is None, i.e. not defined.

        """
        if self._opt_fun is None:
            raise Exception(
                'Use update_opt() to setup the loss function first.')
        if extra_inputs is None:
            extra_inputs = tuple()
        return self._opt_fun['f_loss'](*(tuple(inputs) + extra_inputs))
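
    # Usage sketch (illustrative, not part of the module): evaluating the
    # current loss on concrete numpy inputs after update_opt() has been
    # called. `obs_array` and `returns_array` are hypothetical arrays that
    # match the placeholders passed to update_opt().
    #
    #     current_loss = optimizer.loss([obs_array, returns_array])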

    # pylint: disable=too-many-branches
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Perform optimization.

        Args:
            inputs (list[numpy.ndarray]): List of input values.
            extra_inputs (list[numpy.ndarray]): List of extra input values.
            callback (callable): Function to call during each epoch. Default
                is None.

        Raises:
            NotImplementedError: If inputs are invalid.
            Exception: If loss function is None, i.e. not defined.

        """
        if not inputs:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError('No inputs are fed to optimizer.')
        if self._opt_fun is None:
            raise Exception(
                'Use update_opt() to setup the loss function first.')

        f_loss = self._opt_fun['f_loss']
        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))
        start_time = time.time()

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        sess = tf.compat.v1.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log('Epoch {}'.format(epoch))

            with click.progressbar(length=len(inputs[0]),
                                   label='Optimizing minibatches') as pbar:
                for batch in dataset.iterate(update=True):
                    sess.run(self._train_op,
                             dict(list(zip(self._input_vars, batch))))
                    pbar.update(len(batch[0]))

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._verbose:
                logger.log('Epoch: {} | Loss: {}'.format(epoch, new_loss))
            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values()
                    if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
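
    # Usage sketch (illustrative, not part of the module): running the
    # optimization loop with an epoch callback. Keyword callbacks receive the
    # dict built above, i.e. loss, params, itr and elapsed.
    #
    #     def log_epoch(loss, params, itr, elapsed):
    #         print('epoch {}: loss={:.4f} ({:.1f}s)'.format(itr, loss,
    #                                                         elapsed))
    #
    #     optimizer.optimize([obs_array, returns_array], callback=log_epoch)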

    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: The state to be pickled for the instance.

        """
        new_dict = self.__dict__.copy()
        del new_dict['_opt_fun']
        del new_dict['_tf_optimizer']
        del new_dict['_train_op']
        del new_dict['_input_vars']
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): Unpickled state.

        """
        obj = type(self)()
        self.__dict__.update(obj.__dict__)
        self.__dict__.update(state)
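
# Note (illustrative sketch, not part of the module): __getstate__ drops the
# graph-dependent members (_opt_fun, _tf_optimizer, _train_op, _input_vars),
# so an unpickled instance must have update_opt() called again before loss()
# or optimize() can be used. The names below are hypothetical.
#
#     import pickle
#
#     restored = pickle.loads(pickle.dumps(optimizer))
#     restored.update_opt(loss=loss_tensor, target=policy,
#                         inputs=[obs_var, returns_var])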