Source code for garage.sampler.is_sampler

import copy
from math import exp
from math import log
import random

from numpy import var

from garage.sampler import BatchSampler
from garage.sampler.utils import truncate_paths


class ISSampler(BatchSampler):
    """Sampler which alternates between live sampling iterations using
    BatchSampler and importance sampling iterations.

    Args:
        algo (garage.np.algos.RLAlgorithm): An algorithm instance.
        env (garage.envs.GarageEnv): An environment instance.
        n_backtrack (str or int): Number of past policies to update from,
            or 'all' to use the entire history.
        n_is_pretrain (int): Number of importance sampling iterations to
            perform at the beginning of training.
        init_is (bool): Set the initial iteration (after pretrain) to be an
            importance sampling iteration.
        skip_is_itrs (bool): Do not do any importance sampling iterations
            (after pretrain).
        hist_variance_penalty (float): Penalize the variance of the
            historical policy.
        max_is_ratio (float): Maximum allowed importance sampling ratio.
        ess_threshold (float): Minimum effective sample size required.
        randomize_draw (bool): Whether to randomize the draw of importance
            samples.

    """

    def __init__(self,
                 algo,
                 env,
                 n_backtrack='all',
                 n_is_pretrain=0,
                 init_is=0,
                 skip_is_itrs=False,
                 hist_variance_penalty=0.0,
                 max_is_ratio=0,
                 ess_threshold=0,
                 randomize_draw=False):
        self.n_backtrack = n_backtrack
        self.n_is_pretrain = n_is_pretrain
        self.skip_is_itrs = skip_is_itrs
        self.hist_variance_penalty = hist_variance_penalty
        self.max_is_ratio = max_is_ratio
        self.ess_threshold = ess_threshold
        self.randomize_draw = randomize_draw

        self._hist = []
        self._is_itr = init_is

        super().__init__(algo, env)

    @property
    def history(self):
        """History of policies that have interacted with the environment,
        and the data from their interaction episode(s).
        """
        return self._hist
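    # A minimal construction sketch (hypothetical `algo` and `env` objects):
    # assuming `algo` is a garage RLAlgorithm with a stochastic policy and
    # `env` is a GarageEnv, a sampler that reuses the three most recent
    # historical policies and pretrains with importance sampling for two
    # iterations might look like:
    #
    #     sampler = ISSampler(algo, env, n_backtrack=3, n_is_pretrain=2)
    #
    # n_backtrack='all' (the default) reuses every stored policy instead.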
    def add_history(self, policy_distribution, paths):
        """Store a policy distribution and its paths in the history."""
        self._hist.append((policy_distribution, paths))
    def get_history_list(self, n_past='all'):
        """Get a list of (distribution, data) tuples from the history."""
        if n_past == 'all':
            return self._hist
        return self._hist[-min(n_past, len(self._hist)):]
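    # For example, with four (distribution, paths) entries stored,
    # get_history_list(n_past=2) returns the two most recent entries, and
    # any n_past larger than len(self._hist) safely returns the whole
    # history.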
    def obtain_samples(self, itr, batch_size=None, whole_paths=True):
        # Importance sampling for the first self.n_is_pretrain iterations
        if itr < self.n_is_pretrain:
            paths = self._obtain_is_samples(itr, batch_size, whole_paths)
            return paths

        # Alternate between importance sampling and live sampling
        if self._is_itr and not self.skip_is_itrs:
            paths = self._obtain_is_samples(itr, batch_size, whole_paths)
        else:
            paths = super().obtain_samples(itr, batch_size, whole_paths)
            if not self.skip_is_itrs:
                self.add_history(self.algo.policy.distribution, paths)

        self._is_itr = (self._is_itr + 1) % 2
        return paths
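    # Sketch of the resulting schedule, assuming init_is=0, n_is_pretrain=0
    # and skip_is_itrs=False: itr 0 samples live and records the policy in
    # history, itr 1 reweights historical paths via importance sampling,
    # itr 2 samples live again, and so on; self._is_itr toggles between
    # 0 and 1 through the modulo update above.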
    def _obtain_is_samples(self, itr, batch_size=None, whole_paths=True):
        if batch_size is None:
            batch_size = self.algo.max_path_length

        paths = []
        for hist_policy_distribution, hist_paths in self.get_history_list(
                self.n_backtrack):
            h_paths = self._sample_isweighted_paths(
                policy=self.algo.policy,
                hist_policy_distribution=hist_policy_distribution,
                max_samples=batch_size,
                paths=hist_paths,
                hist_variance_penalty=self.hist_variance_penalty,
                max_is_ratio=self.max_is_ratio,
                ess_threshold=self.ess_threshold,
            )
            paths.extend(h_paths)

        if len(paths) > batch_size:
            paths = random.sample(paths, batch_size)

        return paths if whole_paths else truncate_paths(paths, batch_size)

    def _sample_isweighted_paths(self,
                                 policy,
                                 hist_policy_distribution,
                                 max_samples,
                                 paths=None,
                                 hist_variance_penalty=0.0,
                                 max_is_ratio=10,
                                 ess_threshold=0):
        if not paths:
            return []

        n_samples = min(len(paths), max_samples)

        samples = None
        if self.randomize_draw:
            samples = random.sample(paths, n_samples)
        elif paths:
            if n_samples == len(paths):
                samples = paths
            else:
                start = random.randint(0, len(paths) - n_samples)
                samples = paths[start:start + n_samples]

        # Make a duplicate of samples so we don't permanently alter the
        # historical data
        samples = copy.deepcopy(samples)

        if ess_threshold > 0:
            is_weights = []

        dist1 = policy.distribution
        dist2 = hist_policy_distribution
        for path in samples:
            _, agent_infos = policy.get_actions(path['observations'])
            hist_agent_infos = path['agent_infos']
            if hist_variance_penalty > 0:
                # Penalize the variance of the historical policy by
                # inflating its log standard deviation
                hist_agent_infos['log_std'] += log(1.0 +
                                                   hist_variance_penalty)
            path['agent_infos'] = agent_infos

            # Compute importance sampling weight
            loglike_p = dist1.log_likelihood(path['actions'], agent_infos)
            loglike_hp = dist2.log_likelihood(path['actions'],
                                              hist_agent_infos)
            is_ratio = exp(sum(loglike_p) - sum(loglike_hp))

            # Thresholding knobs
            if max_is_ratio > 0:
                is_ratio = min(is_ratio, max_is_ratio)
            if ess_threshold > 0:
                is_weights.append(is_ratio)

            # Apply importance sampling weight
            path['rewards'] *= is_ratio

        if ess_threshold:
            # Effective sample size estimate.
            # Kong, Augustine. "A note on importance sampling using
            # standardized weights." University of Chicago, Dept.
            # of Statistics, Tech. Rep 348 (1992).
            if len(is_weights) / (1 + var(is_weights)) < ess_threshold:
                return []

        return samples
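
# The two formulas used by _sample_isweighted_paths, demonstrated standalone.
# This is a sketch (not part of garage): the per-path importance ratio
# exp(sum(log p) - sum(log q)) for a diagonal Gaussian action distribution,
# followed by the effective sample size estimate len(w) / (1 + var(w)) from
# Kong (1992). All policy parameters below are made-up numbers.
if __name__ == '__main__':
    import numpy as np

    rng = np.random.RandomState(0)
    actions = rng.normal(size=(5, 2))  # 5 time steps of 2-D actions

    def diag_gaussian_loglik(x, mean, log_std):
        # Per-step log-likelihood of actions under a diagonal Gaussian
        z = (x - mean) / np.exp(log_std)
        return np.sum(-0.5 * z**2 - log_std - 0.5 * np.log(2 * np.pi),
                      axis=-1)

    # Current policy p vs. historical policy q (hypothetical parameters)
    loglike_p = diag_gaussian_loglik(actions, mean=0.1, log_std=-0.5)
    loglike_hp = diag_gaussian_loglik(actions, mean=0.0, log_std=-0.5)
    is_ratio = exp(np.sum(loglike_p) - np.sum(loglike_hp))

    # ESS over a batch of such weights; _sample_isweighted_paths discards
    # the batch when this value falls below ess_threshold
    is_weights = [is_ratio, 0.8, 1.3, 0.9]
    ess = len(is_weights) / (1 + var(is_weights))
    print('IS ratio: %.3f, ESS: %.2f' % (is_ratio, ess))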