Source code for garage.sampler.batch_sampler

"""Class with batch-based sampling."""
import warnings

from garage.sampler import parallel_sampler
from garage.sampler.sampler_deprecated import BaseSampler
from garage.sampler.utils import truncate_paths


class BatchSampler(BaseSampler):
    """Class with batch-based sampling.

    Deprecated: constructing an instance emits a ``DeprecationWarning``
    recommending a sampler that implements ``garage.sampler.Sampler``
    (such as ``LocalSampler``).

    All sampling work is delegated to the ``parallel_sampler`` module,
    with every call keyed by ``algo.scope``.

    Args:
        algo (garage.np.algos.RLAlgorithm): The algorithm. This class reads
            its ``policy``, ``scope`` and ``max_path_length`` attributes.
        env (gym.Env): The environment.

    """

    def __init__(self, algo, env):
        # NOTE(review): BaseSampler presumably stores ``algo`` and ``env`` as
        # ``self.algo`` / ``self.env`` (both are read below) -- confirm.
        super().__init__(algo, env)
        # stacklevel=2 attributes the warning to the caller that constructed
        # the sampler instead of to this module. Python's default filters only
        # display DeprecationWarning when it is attributed to __main__, so
        # without it the notice would normally be suppressed.
        warnings.warn(
            DeprecationWarning(
                'BatchSampler is deprecated, and will be removed in the next '
                'release. Please use one of the samplers which implements '
                'garage.sampler.Sampler, such as LocalSampler.'),
            stacklevel=2)

    def start_worker(self):
        """Start workers.

        Passes the environment and the algorithm's policy to
        ``parallel_sampler.populate_task`` under ``self.algo.scope``.
        """
        parallel_sampler.populate_task(self.env,
                                       self.algo.policy,
                                       scope=self.algo.scope)

    def shutdown_worker(self):
        """Shutdown workers.

        Calls ``parallel_sampler.terminate_task`` for ``self.algo.scope``.
        """
        parallel_sampler.terminate_task(scope=self.algo.scope)

    def obtain_samples(self, itr, batch_size=None, whole_paths=True):
        """Sample the policy for new trajectories.

        Args:
            itr (int): Number of iteration. Not used by this sampler; kept
                for interface compatibility.
            batch_size (int): Number of environment steps in one batch.
                If ``None`` (or any other falsy value, e.g. ``0``), it falls
                back to ``self.algo.max_path_length``.
            whole_paths (bool): If True, return the sampled paths unchanged;
                otherwise pass them through ``truncate_paths`` with
                ``batch_size`` (presumably trimming the total number of
                steps to ``batch_size`` -- confirm in ``garage.sampler.utils``).

        Returns:
            list[dict]: A list of paths.

        """
        if not batch_size:
            batch_size = self.algo.max_path_length
        # Snapshot the current policy parameters so sampling uses the latest
        # policy rather than whatever the workers last received.
        cur_params = self.algo.policy.get_param_values()
        paths = parallel_sampler.sample_paths(
            policy_params=cur_params,
            max_samples=batch_size,
            max_path_length=self.algo.max_path_length,
            scope=self.algo.scope,
        )
        return paths if whole_paths else truncate_paths(paths, batch_size)