Source code for garage.envs.wrappers.resize

"""Resize wrapper for gym.Env."""
import warnings

import gym
import gym.spaces
import numpy as np
from skimage import img_as_ubyte
from skimage.transform import resize

class Resize(gym.Wrapper):
    """gym.Env wrapper that resizes every observation frame.

    The wrapped environment must expose a gym.spaces.Box observation
    space whose shape has exactly two dimensions (a single channel
    frame). Observations returned by `reset` and `step` are resized to
    an array of shape (width, height), and `observation_space` reports a
    Box of that same shape.

    Example:
        | env = gym.make('Env')
        | # env.observation_space = (100, 100)

        | env_wrapped = Resize(gym.make('Env'), width=64, height=64)
        | # env.observation_space = (64, 64)

    Args:
        env: gym.Env to wrap.
        width: size of the first axis of the resized frame.
        height: size of the second axis of the resized frame.

    Raises:
        ValueError: If the observation space is not a gym.spaces.Box, or
            if its shape does not have exactly two dimensions.

    """

    def __init__(self, env, width, height):
        source_space = env.observation_space
        if not isinstance(source_space, gym.spaces.Box):
            raise ValueError('Resize only works with Box environment.')
        if len(source_space.shape) != 2:
            raise ValueError('Resize only works with 2D single channel image.')

        super().__init__(env)

        # The first element of each bound array is used as a scalar bound
        # for the whole resized space.
        low_bound = env.observation_space.low.flatten()[0]
        high_bound = env.observation_space.high.flatten()[0]

        self._width = width
        self._height = height
        self._dtype = env.observation_space.dtype
        self._observation_space = gym.spaces.Box(low_bound,
                                                 high_bound,
                                                 shape=[width, height],
                                                 dtype=self._dtype)

    @property
    def observation_space(self):
        """gym.Env observation space (the resized Box)."""
        return self._observation_space

    @observation_space.setter
    def observation_space(self, observation_space):
        self._observation_space = observation_space

    def _observation(self, obs):
        """Resize a single frame to (width, height).

        Args:
            obs: frame produced by the wrapped environment.

        Returns:
            The resized frame. It is converted with `img_as_ubyte` when
            the wrapped observation space dtype is np.uint8; otherwise
            the output of `resize` is returned unchanged.

        """
        target_shape = (self._width, self._height)
        with warnings.catch_warnings():
            # Suppressing warnings for
            # 1. possible precision loss when converting from float64 to
            #    uint8
            # 2. anti-aliasing will be enabled by default in skimage 0.15
            warnings.simplefilter('ignore')
            resized = resize(obs, target_shape)  # now it's float
            if self._dtype == np.uint8:
                resized = img_as_ubyte(resized)
        return resized

    def reset(self):
        """gym.Env reset function.

        Returns:
            The resized first observation from the wrapped environment.

        """
        first_frame = self.env.reset()
        return self._observation(first_frame)

    def step(self, action):
        """gym.Env step function.

        Args:
            action: action forwarded unchanged to the wrapped environment.

        Returns:
            Tuple of (resized observation, reward, done, info); the last
            three items are passed through from the wrapped environment.

        """
        frame, reward, done, info = self.env.step(action)
        return self._observation(frame), reward, done, info