Source code for garage.sampler.env_update

"""A callable that "updates" an environment."""
import warnings


class EnvUpdate:
    """Interface for callables that produce an updated environment.

    Instances are invoked with an environment and hand back the environment
    that should be used from then on; the caller must discard the one it
    passed in. Because the argument may be omitted, an ``EnvUpdate`` can
    also act as a factory for brand-new environments.
    """

    # pylint: disable=too-few-public-methods

    def __call__(self, old_env=None):
        """Return the environment to use after this update.

        This base implementation is a no-op: it hands back exactly the object
        it was given (``None`` when called without an argument).

        Args:
            old_env (Environment or None): Previous environment. Callers
                must not use it after passing it in, and must not close it.

        Returns:
            Environment: The new, updated environment.

        """
        updated_env = old_env
        return updated_env
class NewEnvUpdate(EnvUpdate):
    """:class:`~EnvUpdate` that creates a new environment every update.

    Args:
        env_constructor (Callable[Environment]): Callable that constructs an
            environment.

    """

    # pylint: disable=too-few-public-methods

    def __init__(self, env_constructor):
        self._env_constructor = env_constructor

    def __call__(self, old_env=None):
        """Update an environment.

        Any previous environment is closed, then a fresh one is built by
        calling the stored constructor with no arguments.

        Args:
            old_env (Environment or None): Previous environment. Should not be
                used after being passed in, and should not be closed.

        Returns:
            Environment: The new, updated environment.

        """
        # Compare against None explicitly: an environment object that defines
        # __len__/__bool__ may evaluate falsy, and a truthiness test would
        # then skip closing it and leak its resources.
        if old_env is not None:
            old_env.close()
        return self._env_constructor()
class SetTaskUpdate(EnvUpdate):
    """:class:`~EnvUpdate` that calls set_task with the provided task.

    Args:
        env_type (type): Type of environment.
        task (object): Opaque task type.
        wrapper_constructor (Callable[garage.Env, garage.Env] or None):
            Callable that wraps constructed environments. It is called as
            ``wrapper_constructor(env, task)``. Defaults to None (no
            wrapping).

    Raises:
        ValueError: If env_type is not a type.

    """

    # pylint: disable=too-few-public-methods

    def __init__(self, env_type, task, wrapper_constructor=None):
        if not isinstance(env_type, type):
            raise ValueError('env_type should be a type, not '
                             f'{type(env_type)!r}')
        self._env_type = env_type
        self._task = task
        self._wrapper_cons = wrapper_constructor

    def _make_env(self):
        """Construct the environment, wrapping if necessary.

        The environment is built with no constructor arguments, its task is
        set, and then (optionally) it is wrapped.

        Returns:
            garage.Env: The (possibly wrapped) environment.

        """
        env = self._env_type()
        env.set_task(self._task)
        if self._wrapper_cons is not None:
            # The wrapper constructor receives the task as well as the env.
            env = self._wrapper_cons(env, self._task)
        return env

    def __call__(self, old_env=None):
        """Update an environment.

        With no previous environment, a new one is constructed. If the
        previous environment (looked up through its ``unwrapped`` attribute
        when present) is an instance of env_type, its task is switched in
        place via ``set_task`` and it is reused. Otherwise it is closed, a
        warning is emitted, and a new environment is constructed.

        Args:
            old_env (Environment or None): Previous environment. Should not be
                used after being passed in, and should not be closed.

        Returns:
            Environment: The new, updated environment.

        """
        if old_env is None:
            return self._make_env()
        # Fall back to old_env itself when it has no `unwrapped` attribute.
        if not isinstance(getattr(old_env, 'unwrapped', old_env),
                          self._env_type):
            # stacklevel=2 attributes the warning to the caller, not here.
            warnings.warn('SetTaskUpdate is closing an environment. This '
                          'may indicate a very slow TaskSampler setup.',
                          stacklevel=2)
            old_env.close()
            return self._make_env()
        # NOTE(review): set_task is called on old_env as given (possibly a
        # wrapper); wrapper_constructor is not re-applied for the new task.
        old_env.set_task(self._task)
        return old_env
class ExistingEnvUpdate(EnvUpdate):
    """:class:`~EnvUpdate` that hands out an environment built elsewhere.

    Args:
        env (Environment): The environment returned on every update.

    """

    def __init__(self, env):
        self._env = env

    def __call__(self, old_env=None):
        """Return the stored environment.

        The argument is ignored; in particular, this implementation does not
        close it.

        Args:
            old_env (Environment or None): Previous environment. Should not be
                used after being passed in, and should not be closed.

        Returns:
            Environment: The environment supplied at construction time.

        """
        stored_env = self._env
        return stored_env

    def __getstate__(self):
        """Produce the state used when pickling.

        A warning is issued because shipping a whole environment this way is
        usually a costly way to move it to another process.

        Returns:
            dict: The instance ``__dict__``, used as the pickled state.

        """
        message = ('ExistingEnvUpdate is generally not the most efficient '
                   'method of transmitting environments to other '
                   'processes.')
        warnings.warn(message)
        return self.__dict__