# Source code for garage.replay_buffer.simple_replay_buffer

"""This module implements a simple replay buffer."""
import numpy as np

from garage.replay_buffer.base import ReplayBuffer


class SimpleReplayBuffer(ReplayBuffer):
    """Replay buffer that hands out randomly chosen transitions.

    Transitions are picked at random from the stored data so that the
    samples within a batch are as weakly correlated as possible.
    """

    def sample(self, batch_size):
        """Draw ``batch_size`` random transitions from the buffer.

        Every entry of the internal buffer is indexed by
        ``(episode, time step)``. The range of both indices is taken from
        the shape of the ``'action'`` entry. Indices are drawn
        independently with ``np.random.randint``, so the same transition
        can appear more than once in a batch.

        Args:
            batch_size (int): Number of transitions to draw.

        Returns:
            dict: Maps every key of the internal buffer to an array whose
                first dimension has length ``batch_size``.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are
                currently stored.

        """
        assert self._n_transitions_stored >= batch_size

        # Restrict every field to the portion of the buffer that is in use.
        stored = {
            name: self._buffer[name][:self._current_size]
            for name in self._buffer
        }

        action_shape = stored['action'].shape
        horizon = action_shape[1]
        n_episodes = action_shape[0]

        # The episode draw must precede the time-step draw so the random
        # stream is consumed in the same order on every call.
        episode_ids = np.random.randint(n_episodes, size=batch_size)
        step_ids = np.random.randint(horizon, size=batch_size)

        transitions = {}
        for name, data in stored.items():
            picked = data[episode_ids, step_ids]
            transitions[name] = picked.reshape(
                (batch_size, ) + picked.shape[1:])

        assert transitions['action'].shape[0] == batch_size
        return transitions