Source code for garage.np.optimizers.minibatch_dataset

import numpy as np


class BatchDataset:
    """Iterate over a group of parallel arrays in shuffled mini-batches.

    Every entry of ``inputs`` is indexed with the same set of row ids, so
    rows stay aligned across inputs within a batch. The number of rows is
    taken from ``inputs[0].shape[0]``; the other inputs are assumed to have
    the same leading dimension (this is not checked).

    Args:
        inputs: Iterable of array-like objects exposing ``.shape`` and
            supporting indexing by an integer index array (presumably
            numpy arrays -- confirm against callers).
        batch_size: Number of rows per batch, or ``None`` to disable
            batching and always yield the full inputs as a single batch.
        extra_inputs: Optional sequence of values appended unchanged to
            every yielded batch. Defaults to an empty list.
    """

    def __init__(self, inputs, batch_size, extra_inputs=None):
        self._inputs = list(inputs)
        self._extra_inputs = [] if extra_inputs is None else extra_inputs
        self._batch_size = batch_size
        # Always define _ids so update() is safe to call in full-batch mode,
        # where no index permutation is needed.
        self._ids = None
        if batch_size is not None:
            self._ids = np.arange(self._inputs[0].shape[0])
            self.update()

    @property
    def number_batches(self):
        """int: Number of batches produced by one pass of :meth:`iterate`.

        Always 1 when ``batch_size`` is ``None``; otherwise the row count of
        the first input divided by the batch size, rounded up.
        """
        if self._batch_size is None:
            return 1
        num_rows = self._inputs[0].shape[0]
        return int(np.ceil(num_rows / self._batch_size))

    def iterate(self, update=True):
        """Yield one list per batch: the sliced inputs plus the extra inputs.

        Args:
            update: If True, reshuffle the row order after the final batch
                has been yielded, so the next pass sees a new permutation.
                Has no effect when ``batch_size`` is ``None``.

        Yields:
            list: The batch rows of each input (in input order) followed by
            every element of ``extra_inputs``. The last batch may contain
            fewer than ``batch_size`` rows.
        """
        if self._batch_size is None:
            yield list(self._inputs) + list(self._extra_inputs)
            return

        for batch_index in range(self.number_batches):
            start = batch_index * self._batch_size
            stop = start + self._batch_size
            # Slicing past the end of _ids is fine: the tail batch is short.
            batch_ids = self._ids[start:stop]
            batch = [data[batch_ids] for data in self._inputs]
            yield batch + list(self._extra_inputs)

        # Only reached once the generator has been fully consumed.
        if update:
            self.update()

    def update(self):
        """Shuffle the row order in place using numpy's global RNG.

        Does nothing when the dataset was created with ``batch_size=None``,
        since there is no row permutation to shuffle in that mode.
        """
        if self._ids is None:
            return
        np.random.shuffle(self._ids)