garage.tf.misc.tensor_utils module

center_advs(advs, axes, eps, offset=0, scale=1, name=None)[source]

Normalize the advs tensor.
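A minimal NumPy sketch of what advantage centering typically does, assuming the standard standardize-then-rescale formula `(x - mean) / (std + eps) * scale + offset`; this is an illustration, not garage's actual TF graph code:

```python
import numpy as np

def center_advs_sketch(advs, axes=(0,), eps=1e-8, offset=0, scale=1):
    """Hypothetical reimplementation: standardize advs along the given
    axes, then rescale. The exact formula is an assumption."""
    mean = advs.mean(axis=axes, keepdims=True)
    std = advs.std(axis=axes, keepdims=True)
    return (advs - mean) / (std + eps) * scale + offset
```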

compile_function(inputs, outputs, log_name=None)[source]
compute_advantages(discount, gae_lambda, max_len, baselines, rewards, name=None)[source]
concat_tensor_dict_list(tensor_dict_list)[source]
concat_tensor_list(tensor_list)[source]
discounted_returns(discount, max_len, rewards, name=None)[source]
filter_valids(t, valid, name='filter_valids')[source]
filter_valids_dict(d, valid, name=None)[source]
flatten_batch(t, name='flatten_batch')[source]
flatten_batch_dict(d, name=None)[source]
flatten_inputs(deep)[source]
flatten_tensor_variables(ts)[source]
get_target_ops(variables, target_variables, tau=None)[source]

Get target variables update operations.

In RL algorithms we often update the target network every n steps. This function returns the tf.Operation for updating the target variables (denoted by target_var) from the source variables (denoted by var) with fraction tau. In other words, each update sets target_var to tau * var + (1 - tau) * target_var, keeping a fraction tau of the source and (1 - tau) of the old target.

Parameters:
  • variables (list[tf.Variable]) – Source variables for the update.
  • target_variables (list[tf.Variable]) – Target variables to be updated.
  • tau (float) – Fraction to update. Set to None for a hard update.
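The update rule above can be sketched in plain Python; this reimplements the arithmetic from the docstring on plain floats and returns the new target values rather than tf.Operations, so it is an illustration of the rule, not garage's TF implementation:

```python
def soft_update_sketch(variables, target_variables, tau=None):
    """Sketch of the target-update rule: soft update with fraction
    tau, or a hard copy when tau is None."""
    if tau is None:
        # Hard update: copy source values into the targets directly.
        return list(variables)
    # Soft update: keep tau of var and (1 - tau) of target_var.
    return [tau * v + (1 - tau) * t
            for v, t in zip(variables, target_variables)]
```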
graph_inputs(name, **kwargs)[source]
new_tensor(name, ndim, dtype)[source]
new_tensor_like(name, arr_like)[source]
pad_tensor(x, max_len)[source]
pad_tensor_dict(tensor_dict, max_len)[source]
pad_tensor_n(xs, max_len)[source]
positive_advs(advs, eps, name=None)[source]

Make all the values in the advs tensor positive.
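One common way to make advantages positive is to shift them so the minimum becomes eps; the shift-by-minimum formula in this NumPy sketch is an assumption, not necessarily what garage does:

```python
import numpy as np

def positive_advs_sketch(advs, eps=1e-8):
    """Hypothetical version: shift advs so its minimum equals eps,
    making every entry strictly positive. The formula is an
    assumption based only on the one-line docstring."""
    return advs - advs.min() + eps
```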

split_tensor_dict_list(tensor_dict)[source]
stack_tensor_dict_list(tensor_dict_list)[source]

Stack a list of dictionaries of {tensors or dictionaries of tensors}.

Parameters:
  • tensor_dict_list – a list of dictionaries of {tensors or dictionaries of tensors}.

Returns: a dictionary of {stacked tensors or dictionaries of stacked tensors}.
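The recursive stacking behavior described above can be sketched with NumPy; this is a reimplementation from the docstring for illustration, not garage's code:

```python
import numpy as np

def stack_tensor_dict_list_sketch(tensor_dict_list):
    """For each key, stack the values across the list of dicts;
    recurse when a value is itself a dictionary of tensors."""
    result = {}
    for k in tensor_dict_list[0]:
        values = [d[k] for d in tensor_dict_list]
        if isinstance(values[0], dict):
            result[k] = stack_tensor_dict_list_sketch(values)
        else:
            result[k] = np.stack(values)
    return result
```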
stack_tensor_list(tensor_list)[source]
unflatten_tensor_variables(flatarr, shapes, symb_arrs)[source]