Source code for garage.envs.wrappers.grayscale

"""Grayscale wrapper for gym.Env."""
import warnings

import gym
import gym.spaces
import numpy as np
from skimage import color
from skimage import img_as_ubyte


[docs]class Grayscale(gym.Wrapper): """Grayscale wrapper for gym.Env, converting frames to grayscale. Only works with gym.spaces.Box environment with 2D RGB frames. The last dimension (RGB) of environment observation space will be removed. Example: env = gym.make('Env') # env.observation_space = (100, 100, 3) env_wrapped = Grayscale(gym.make('Env')) # env.observation_space = (100, 100) Args: env: gym.Env to wrap. Raises: ValueError: If observation space shape is not 3 or environment is not gym.spaces.Box. """ def __init__(self, env): if not isinstance(env.observation_space, gym.spaces.Box): raise ValueError( 'Grayscale only works with gym.spaces.Box environment.') if len(env.observation_space.shape) != 3: raise ValueError('Grayscale only works with 2D RGB images') super().__init__(env) _low = env.observation_space.low.flatten()[0] _high = env.observation_space.high.flatten()[0] assert _low == 0 assert _high == 255 self._observation_space = gym.spaces.Box( _low, _high, shape=env.observation_space.shape[:-1], dtype=np.uint8) @property def observation_space(self): """gym.Env observation space.""" return self._observation_space @observation_space.setter def observation_space(self, observation_space): self._observation_space = observation_space def _observation(self, obs): with warnings.catch_warnings(): """ Suppressing warning for possible precision loss when converting from float64 to uint8 """ warnings.simplefilter('ignore') return img_as_ubyte(color.rgb2gray((obs)))
[docs] def reset(self): """gym.Env reset function.""" return self._observation(self.env.reset())
[docs] def step(self, action): """gym.Env step function.""" obs, reward, done, info = self.env.step(action) return self._observation(obs), reward, done, info