Source code for garage.sampler.on_policy_vectorized_sampler

"""BatchSampler which uses VecEnvExecutor to run multiple environments."""
import itertools
import pickle
import time

from dowel import logger, tabular
import numpy as np

from garage.experiment import deterministic
from garage.misc import tensor_utils
from garage.misc.prog_bar_counter import ProgBarCounter
from garage.sampler.batch_sampler import BatchSampler
from garage.sampler.stateful_pool import singleton_pool
from garage.sampler.utils import truncate_paths
from garage.sampler.vec_env_executor import VecEnvExecutor


class OnPolicyVectorizedSampler(BatchSampler):
    """BatchSampler which uses VecEnvExecutor to run multiple environments."""

    def __init__(self, algo, env, n_envs=None):
        if n_envs is None:
            n_envs = singleton_pool.n_parallel * 4
        super().__init__(algo, env)
        self.n_envs = n_envs
    def start_worker(self):
        """Start workers."""
        n_envs = self.n_envs
        envs = [pickle.loads(pickle.dumps(self.env)) for _ in range(n_envs)]

        # Deterministically set environment seeds based on the global seed.
        for (i, e) in enumerate(envs):
            e.seed(deterministic.get_seed() + i)

        self.vec_env = VecEnvExecutor(
            envs=envs, max_path_length=self.algo.max_path_length)
        self.env_spec = self.env.spec
    def shutdown_worker(self):
        """Shutdown workers."""
        self.vec_env.close()
    def obtain_samples(self, itr, batch_size=None, whole_paths=True):
        """Obtain samples."""
        logger.log('Obtaining samples for iteration %d...' % itr)

        if not batch_size:
            batch_size = self.algo.max_path_length * self.n_envs

        paths = []
        n_samples = 0
        obses = self.vec_env.reset()
        dones = np.asarray([True] * self.vec_env.num_envs)
        running_paths = [None] * self.vec_env.num_envs

        pbar = ProgBarCounter(batch_size)
        policy_time = 0
        env_time = 0
        process_time = 0

        policy = self.algo.policy

        # Step all environments in lockstep until enough samples are gathered.
        while n_samples < batch_size:
            t = time.time()
            policy.reset(dones)
            actions, agent_infos = policy.get_actions(obses)
            policy_time += time.time() - t

            t = time.time()
            next_obses, rewards, dones, env_infos = self.vec_env.step(actions)
            env_time += time.time() - t

            t = time.time()
            agent_infos = tensor_utils.split_tensor_dict_list(agent_infos)
            env_infos = tensor_utils.split_tensor_dict_list(env_infos)
            if env_infos is None:
                env_infos = [dict() for _ in range(self.vec_env.num_envs)]
            if agent_infos is None:
                agent_infos = [dict() for _ in range(self.vec_env.num_envs)]

            # Append this step's transition to the running path of each env.
            for idx, observation, action, reward, env_info, agent_info, done in zip(  # noqa: E501
                    itertools.count(), obses, actions, rewards, env_infos,
                    agent_infos, dones):
                if running_paths[idx] is None:
                    running_paths[idx] = dict(
                        observations=[],
                        actions=[],
                        rewards=[],
                        env_infos=[],
                        agent_infos=[],
                    )
                running_paths[idx]['observations'].append(observation)
                running_paths[idx]['actions'].append(action)
                running_paths[idx]['rewards'].append(reward)
                running_paths[idx]['env_infos'].append(env_info)
                running_paths[idx]['agent_infos'].append(agent_info)
                if done:
                    # Episode finished in this env: stack the accumulated
                    # lists into arrays and store the completed path.
                    obs = np.asarray(running_paths[idx]['observations'])
                    actions = np.asarray(running_paths[idx]['actions'])
                    paths.append(
                        dict(observations=obs,
                             actions=actions,
                             rewards=tensor_utils.stack_tensor_list(
                                 running_paths[idx]['rewards']),
                             env_infos=tensor_utils.stack_tensor_dict_list(
                                 running_paths[idx]['env_infos']),
                             agent_infos=tensor_utils.stack_tensor_dict_list(
                                 running_paths[idx]['agent_infos'])))
                    n_samples += len(running_paths[idx]['rewards'])
                    running_paths[idx] = None
            process_time += time.time() - t
            pbar.inc(len(obses))
            obses = next_obses

        pbar.stop()

        tabular.record('PolicyExecTime', policy_time)
        tabular.record('EnvExecTime', env_time)
        tabular.record('ProcessExecTime', process_time)

        if whole_paths:
            return paths
        else:
            paths_truncated = truncate_paths(paths, batch_size)
            return paths_truncated
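
For context, a minimal usage sketch of how this sampler is typically driven. It is not part of this module: `algo` and `env` are hypothetical placeholders, where `algo` is assumed to expose the `policy` and `max_path_length` attributes referenced above and `env` is assumed to be a pickleable garage environment; the `train_once` call stands in for whatever update step the algorithm performs.

# Illustrative sketch, not part of this module. `algo`, `env` and
# `algo.train_once` are hypothetical placeholders (see note above).
from garage.experiment import deterministic
from garage.sampler.on_policy_vectorized_sampler import OnPolicyVectorizedSampler

deterministic.set_seed(1)  # start_worker() derives per-env seeds from this
sampler = OnPolicyVectorizedSampler(algo, env, n_envs=8)
sampler.start_worker()  # pickles env into n_envs copies inside a VecEnvExecutor
try:
    for itr in range(100):
        # Whole trajectories containing at least batch_size transitions.
        paths = sampler.obtain_samples(itr, batch_size=4000)
        algo.train_once(itr, paths)  # hypothetical policy update
finally:
    sampler.shutdown_worker()  # closes the vectorized environments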