Source code for garage.tf.optimizers.utils

"""Utilities for TensorFlow optimizers."""
import numpy as np


def sliced_fun(f, n_slices):
    """Divide function f's inputs into several slices.

    Evaluate f on those slices, and then average the result. It is useful
    when memory is not enough to process all data at once.

    Assume:
    1. each of f's inputs is iterable and composed of multiple "samples"
    2. outputs can be averaged over "samples"

    Args:
        f (callable): Function to evaluate slice by slice. It is called as
            `f(*sliced_chunks, *non_sliced_inputs)` and may return a single
            value, a tuple, or a list of values.
        n_slices (int): Target number of slices. The slice size is
            `max(1, n_samples // n_slices)`, so `f` may be called more than
            `n_slices` times when the sample count is not evenly divisible.

    Returns:
        callable: A function `_sliced_f(sliced_inputs, non_sliced_inputs)`
            that returns the sample-weighted average of `f`'s outputs over
            all slices, in the same container kind (single value, tuple, or
            list) that `f` returned.

    """

    def _sliced_f(sliced_inputs, non_sliced_inputs=None):
        """Evaluate `f` over slices of the inputs and average the results.

        Args:
            sliced_inputs (list or tuple): Sliceable inputs; the length of
                the first one defines the number of samples.
            non_sliced_inputs (iterable or None): Extra positional
                arguments passed unchanged to every call of `f`.

        Returns:
            object: Sample-weighted average of the outputs of `f`.

        Raises:
            ValueError: If the first sliced input contains no samples.

        """
        if non_sliced_inputs is None:
            non_sliced_inputs = []

        n_paths = len(sliced_inputs[0])
        if n_paths == 0:
            # Without this guard the loop below never runs and the averaging
            # step fails with an opaque "NoneType is not iterable" error.
            raise ValueError(
                'sliced_inputs must contain at least one sample to average '
                'over.')

        slice_size = max(1, n_paths // n_slices)
        ret_vals = None
        slice_ret_vals = None
        for start in range(0, n_paths, slice_size):
            inputs_slice = [v[start:start + slice_size] for v in sliced_inputs]
            # Unpacking accepts any iterable of extra arguments, not only
            # lists and tuples.
            slice_ret_vals = f(*inputs_slice, *non_sliced_inputs)
            if isinstance(slice_ret_vals, (tuple, list)):
                slice_ret_vals_as_list = slice_ret_vals
            else:
                slice_ret_vals_as_list = [slice_ret_vals]
            # Weight each output by the slice's sample count so the final
            # division by n_paths yields a correct mean even when the last
            # slice is shorter than the others.
            n_in_slice = len(inputs_slice[0])
            scaled_ret_vals = [
                np.asarray(v) * n_in_slice for v in slice_ret_vals_as_list
            ]
            if ret_vals is None:
                ret_vals = scaled_ret_vals
            else:
                ret_vals = [x + y for x, y in zip(ret_vals, scaled_ret_vals)]

        ret_vals = [v / n_paths for v in ret_vals]
        # Hand back the same container kind that f produced.
        if not isinstance(slice_ret_vals, (tuple, list)):
            return ret_vals[0]
        if isinstance(slice_ret_vals, tuple):
            return tuple(ret_vals)
        return ret_vals

    return _sliced_f
class LazyDict:
    """A lazily-evaluated dict.

    Each value is stored as a zero-argument callable. The callable is
    invoked the first time its key is read, and the result is cached for
    later reads. Assigning a key again replaces its callable and discards
    any cached result.

    Args:
        **kwargs: Mapping from key to a zero-argument callable that
            produces the value for that key.

    """

    def __init__(self, **kwargs):
        # Callables that produce values on demand.
        self._lazy_dict = kwargs
        # Cache of values that have already been evaluated.
        self._dict = {}

    def __getitem__(self, key):
        """Implement `object.__getitem__`.

        Args:
            key (object): Key to look up.

        Returns:
            object: The result of calling the callable stored under `key`,
                evaluated at most once per assignment.

        Raises:
            KeyError: If no callable is stored under `key`.

        """
        if key not in self._dict:
            self._dict[key] = self._lazy_dict[key]()
        return self._dict[key]

    def __setitem__(self, i, y):
        """Implement `object.__setitem__`.

        Args:
            i (object): Key to assign.
            y (callable): Zero-argument callable producing the value.

        """
        self.set(i, y)

    def get(self, key, default=None):
        """Implement `dict.get`.

        Args:
            key (object): Key to look up.
            default (object): Value returned when `key` is absent.

        Returns:
            object: The evaluated value for `key`, or `default` if no
                callable is stored under `key`.

        """
        if key in self._lazy_dict:
            return self[key]
        return default

    def set(self, key, value):
        """Store a lazily-evaluated value under a key.

        Args:
            key (object): Key to assign.
            value (callable): Zero-argument callable producing the value.

        """
        self._lazy_dict[key] = value
        # Drop any previously evaluated result so the next read uses the
        # new callable instead of returning a stale cached value.
        self._dict.pop(key, None)