# Source code for garage.sampler.default_worker

"""Default Worker class."""
from collections import defaultdict

import numpy as np

from garage import EpisodeBatch, StepType
from garage.experiment import deterministic
from garage.sampler import _apply_env_update
from garage.sampler.worker import Worker


class DefaultWorker(Worker):
    """Worker that samples episodes one time-step at a time in one env.

    Args:
        seed (int): The seed to use to initialize random number generators.
        max_episode_length (int or float): The maximum length of episodes
            which will be sampled. Can be (floating point) infinity.
        worker_number (int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for each
            worker.

    Attributes:
        agent (Policy or None): The worker's agent.
        env (Environment or None): The worker's environment.

    """

    def __init__(
            self,
            *,  # Require passing by keyword, since everything's an int.
            seed,
            max_episode_length,
            worker_number):
        super().__init__(seed=seed,
                         max_episode_length=max_episode_length,
                         worker_number=worker_number)
        self.agent = None
        self.env = None
        # Buffers accumulated by step_episode() and drained by
        # collect_episode(); they may hold several finished episodes.
        self._env_steps = []
        self._observations = []
        self._last_observations = []
        self._agent_infos = defaultdict(list)
        self._lengths = []
        # Observation the agent acts on next, and the number of steps taken
        # so far in the current episode.
        self._prev_obs = None
        self._eps_length = 0
        self.worker_init()

    def worker_init(self):
        """Initialize a worker.

        If a seed was given, seeds via ``deterministic.set_seed`` using
        ``seed + worker_number`` so each worker gets a distinct seed.
        Does nothing when the seed is None.

        """
        if self._seed is not None:
            deterministic.set_seed(self._seed + self._worker_number)

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements :class:`~Policy`.

        Args:
            agent_update (np.ndarray or dict or Policy): If a tuple, dict, or
                np.ndarray, these should be parameters to agent, which should
                have been generated by calling `Policy.get_param_values`.
                Alternatively, a policy itself. Note that other implementations
                of `Worker` may take different types for this parameter.

        """
        if isinstance(agent_update, (dict, tuple, np.ndarray)):
            self.agent.set_param_values(agent_update)
        elif agent_update is not None:
            self.agent = agent_update

    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        A simple env update function. If env_update is not None, it should be
        the complete new environment.

        This allows changing environments by passing the new environment as
        `env_update` into `obtain_samples`.

        Args:
            env_update(Environment or EnvUpdate or None): The environment to
                replace the existing env with. Note that other implementations
                of `Worker` may take different types for this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.

        """
        self.env, _ = _apply_env_update(self.env, env_update)

    def start_episode(self):
        """Begin a new episode by resetting the environment and the agent."""
        self._eps_length = 0
        self._prev_obs, _ = self.env.reset()
        self.agent.reset()

    def step_episode(self):
        """Take a single time-step in the current episode.

        Returns:
            bool: True iff the episode is done, either due to the environment
            indicating the end of the episode (termination or timeout) or due
            to reaching `max_episode_length`.

        """
        if self._eps_length < self._max_episode_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            es = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._env_steps.append(es)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            self._eps_length += 1
            # Always advance to the newest observation. If the episode ends
            # on this step, the post-action observation is the episode's
            # last observation; keeping the pre-action one would drop it.
            self._prev_obs = es.observation
            # Stop on any end signalled by the env (terminal or timeout) so
            # we never step an env whose episode is already over.
            if not es.last and self._eps_length < self._max_episode_length:
                return False
        # Reached either right after the final step, or immediately when
        # max_episode_length allows no steps at all.
        self._lengths.append(self._eps_length)
        self._last_observations.append(self._prev_obs)
        return True

    def collect_episode(self):
        """Collect the current episode, clearing the internal buffer.

        Returns:
            EpisodeBatch: A batch of the episodes completed since the last call
                to collect_episode().

        """
        # Take ownership of every buffer and reset it for the next episode.
        observations = self._observations
        self._observations = []
        last_observations = self._last_observations
        self._last_observations = []
        env_steps = self._env_steps
        self._env_steps = []
        agent_infos = self._agent_infos
        self._agent_infos = defaultdict(list)
        lengths = self._lengths
        self._lengths = []

        actions = [es.action for es in env_steps]
        rewards = [es.reward for es in env_steps]
        step_types = [es.step_type for es in env_steps]
        env_infos = defaultdict(list)
        for es in env_steps:
            for k, v in es.env_info.items():
                env_infos[k].append(v)

        return EpisodeBatch(
            env_spec=self.env.spec,
            observations=np.asarray(observations),
            last_observations=np.asarray(last_observations),
            actions=np.asarray(actions),
            rewards=np.asarray(rewards),
            step_types=np.asarray(step_types, dtype=StepType),
            env_infos={k: np.asarray(v) for k, v in env_infos.items()},
            agent_infos={k: np.asarray(v) for k, v in agent_infos.items()},
            lengths=np.asarray(lengths, dtype='i'))

    def rollout(self):
        """Sample a single episode of the agent in the environment.

        Returns:
            EpisodeBatch: The collected episode.

        """
        self.start_episode()
        while not self.step_episode():
            pass
        return self.collect_episode()

    def shutdown(self):
        """Close the worker's environment, if one has been set."""
        if self.env is not None:
            self.env.close()