Source code for garage.replay_buffer.base

This module implements a replay buffer memory.

A replay buffer is an important technique in reinforcement learning. It
stores transitions in a memory buffer of fixed size. When the buffer is
full, the oldest memories are discarded. At each step, a batch of memories
is sampled from the buffer to update the agent's parameters. In short, a
replay buffer breaks temporal correlations and thus benefits RL algorithms.

import abc
from abc import abstractmethod

import numpy as np

class ReplayBuffer(metaclass=abc.ABCMeta):
    """Abstract base class for an episode-organized replay buffer.

    Transitions are accumulated one time step at a time in a temporary
    episode buffer. Once ``time_horizon`` steps have been collected, they
    are written into the main buffer as complete episodes. The main buffer
    holds ``size_in_transitions // time_horizon`` episodes; when it is full,
    new episodes wrap around and overwrite previously stored ones.

    Subclasses must implement :meth:`sample`.

    Args:
        env_spec (garage.envs.EnvSpec): Environment specification. It is
            accepted for interface compatibility and not used by this base
            class.
        size_in_transitions (int): Total capacity of the buffer, counted in
            transitions.
        time_horizon (int): Time horizon of a rollout, i.e. the number of
            time steps stored per episode slot.

    """

    def __init__(self, env_spec, size_in_transitions, time_horizon):
        # Number of episode slots currently occupied (at most self._size).
        self._current_size = 0
        # Next slot to overwrite once every slot is occupied.
        self._current_ptr = 0
        self._n_transitions_stored = 0
        self._time_horizon = time_horizon
        self._size_in_transitions = size_in_transitions
        # Capacity expressed in episode slots rather than transitions.
        self._size = size_in_transitions // time_horizon
        self._initialized_buffer = False
        self._buffer = {}
        self._episode_buffer = {}

    def store_episode(self):
        """Move the pending episode(s) into the main buffer."""
        episode_buffer = self._convert_episode_to_batch_major()
        rollout_batch_size = len(episode_buffer['observation'])
        idx = self._get_storage_idx(rollout_batch_size)

        for key in self._buffer:
            self._buffer[key][idx] = episode_buffer[key]

        # The stored-transition counter saturates at the buffer capacity.
        self._n_transitions_stored = min(
            self._size_in_transitions,
            self._n_transitions_stored
            + self._time_horizon * rollout_batch_size)

    @abstractmethod
    def sample(self, batch_size):
        """Sample a batch of transitions.

        Args:
            batch_size (int): Number of transitions to sample.

        """
        raise NotImplementedError

    def add_transition(self, **kwargs):
        """Add one transition into the replay buffer.

        Each keyword value is wrapped in a single-element list and forwarded
        to :meth:`add_transitions`.

        Args:
            kwargs (dict(str, numpy.ndarray)): Entries of the transition,
                keyed by entry name (e.g. ``observation``, ``action``).

        """
        transition = {k: [v] for k, v in kwargs.items()}
        self.add_transitions(**transition)

    def add_transitions(self, **kwargs):
        """Add multiple transitions into the replay buffer.

        A transition contains one or multiple entries, e.g. observation,
        action, reward, terminal and next_observation. The same entry of
        all the transitions are stacked, e.g.
        ``{'observation': [obs1, obs2, obs3]}`` where ``obs1`` is one
        ``numpy.ndarray`` observation from the environment.

        Each call appends one time step to the pending episode. When
        ``time_horizon`` steps have been collected, the episode is stored
        and the pending episode is cleared. An ``observation`` entry is
        required, because it is used to count the collected time steps.

        Args:
            kwargs (dict(str, [numpy.ndarray])): Dictionary that holds the
                transitions.

        """
        if not self._initialized_buffer:
            self._initialize_buffer(**kwargs)

        for key, value in kwargs.items():
            self._episode_buffer[key].append(value)

        if len(self._episode_buffer['observation']) == self._time_horizon:
            self.store_episode()
            for pending in self._episode_buffer.values():
                pending.clear()

    def _initialize_buffer(self, **kwargs):
        """Allocate storage based on the first batch of transitions.

        For every entry, an empty pending-episode list is created and a
        zero-filled array of shape ``(size, time_horizon, *entry_shape)``
        is allocated, where ``entry_shape`` is the shape of the given value
        without its leading (batch) axis.

        Args:
            kwargs (dict(str, [numpy.ndarray])): The first transitions passed
                to :meth:`add_transitions`.

        """
        for key, value in kwargs.items():
            self._episode_buffer[key] = []
            values = np.array(value)
            self._buffer[key] = np.zeros(
                [self._size, self._time_horizon, *values.shape[1:]],
                dtype=values.dtype)
        self._initialized_buffer = True

    def _get_storage_idx(self, size_increment=1):
        """Get the storage index for the episode to add into the buffer.

        Slots are handed out sequentially until the buffer is full; after
        that, slots are reused starting from ``self._current_ptr`` and wrap
        around at the end of the buffer.

        Args:
            size_increment (int): Number of episode slots requested.

        Returns:
            numpy.ndarray or numpy.integer: The slot indices to write to. A
                single scalar index is returned when ``size_increment`` is 1.

        """
        if self._current_size + size_increment <= self._size:
            # Enough free slots remain: append after the stored episodes.
            idx = np.arange(self._current_size,
                            self._current_size + size_increment)
        elif self._current_size < self._size:
            # Partially free: fill the tail, then wrap to the front.
            overflow = size_increment - (self._size - self._current_size)
            idx_a = np.arange(self._current_size, self._size)
            idx_b = np.arange(0, overflow)
            idx = np.concatenate([idx_a, idx_b])
            self._current_ptr = overflow
        elif self._current_ptr + size_increment <= self._size:
            # Full, and the overwrite window fits before the end.
            idx = np.arange(self._current_ptr,
                            self._current_ptr + size_increment)
            self._current_ptr += size_increment
        else:
            # Full, and the overwrite window wraps around. The room left
            # before the end is measured from the write pointer; measuring
            # it from _current_size (which equals _size here) would yield
            # more indices than requested.
            overflow = size_increment - (self._size - self._current_ptr)
            idx_a = np.arange(self._current_ptr, self._size)
            idx_b = np.arange(0, overflow)
            idx = np.concatenate([idx_a, idx_b])
            self._current_ptr = overflow

        # Update replay size
        self._current_size = min(self._size,
                                 self._current_size + size_increment)

        if size_increment == 1:
            idx = idx[0]
        return idx

    def _convert_episode_to_batch_major(self):
        """Convert the pending episode from time-major to batch-major.

        episode_buffer: {time_horizon, algo.rollout_batch_size, flat_dim}.
        buffer: {size, time_horizon, flat_dim}.

        Returns:
            dict(str, numpy.ndarray): The pending episode entries with their
                first two axes swapped.

        """
        transitions = {}
        for key, steps in self._episode_buffer.items():
            transitions[key] = np.array(steps).swapaxes(0, 1)
        return transitions

    @property
    def full(self):
        """Whether the buffer is full.

        Returns:
            bool: True if every episode slot is occupied.

        """
        return self._current_size == self._size

    @property
    def n_transitions_stored(self):
        """Return the number of transitions currently stored.

        Returns:
            int: Number of stored transitions, capped at the buffer
                capacity ``size_in_transitions``.

        """
        return self._n_transitions_stored