# Source code for garage.tf.q_functions.base

"""Q-function base classes without Parameterized."""
import abc


class QFunction(abc.ABC):
    """Q-function base class without Parameterized.

    Args:
        name (str): Name of the Q-function, also the variable scope.
            When a falsy value is given, the class name is used instead.

    """

    def __init__(self, name):
        self.name = name or type(self).__name__
        # Left unset here. The variable accessors below dereference it, so
        # it must be assigned before any of them is called (presumably by
        # the derived class when it builds its network -- confirm).
        self._variable_scope = None

    def get_qval_sym(self, *input_phs):
        """Symbolic graph for q-network.

        All derived classes should implement this function.

        Args:
            input_phs (list[tf.Tensor]): Recommended to be positional
                arguments, e.g. def get_qval_sym(self, state_input,
                action_input).

        """

    def clone(self, name):
        """Return a clone of the Q-function.

        It should only copy the configuration of the Q-function,
        not the parameters.

        Args:
            name (str): Name of the newly created q-function.

        """

    def get_trainable_vars(self):
        """Get all trainable variables under the QFunction scope.

        Returns:
            The result of ``trainable_variables()`` on the variable scope.

        """
        return self._variable_scope.trainable_variables()

    def get_global_vars(self):
        """Get all global variables under the QFunction scope.

        Returns:
            The result of ``global_variables()`` on the variable scope.

        """
        return self._variable_scope.global_variables()

    def get_regularizable_vars(self):
        """Get all network weight variables under the QFunction scope.

        Returns:
            list: The scope's global variables whose name contains both
                'hidden' and 'kernel'.

        """
        # NOTE(review): candidates are taken from global_variables(), not
        # trainable_variables(), even though the previous local name was
        # `trainable`. Behavior is kept as-is; confirm which collection is
        # intended before changing it.
        scope_vars = self._variable_scope.global_variables()
        return [
            var for var in scope_vars
            if 'hidden' in var.name and 'kernel' in var.name
        ]

    def log_diagnostics(self, paths):
        """Log extra information per iteration based on the collected paths.

        The base implementation does nothing.

        Args:
            paths: The collected paths; unused by the base implementation.

        """