garage.envs.multi_env_wrapper module

A wrapper env that handles multiple tasks from different envs.

Useful for training multi-task reinforcement learning algorithms. Observations are augmented with a one-hot representation of the active task.

class MultiEnvWrapper(envs, sample_strategy=<function uniform_random_strategy>)

Bases: gym.core.Wrapper

A wrapper class to handle multiple gym environments.

Parameters:
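  • envs (list(gym.Env)) – A list of objects implementing gym.Env.
  • sample_strategy (function(int, int)) – Sample strategy used to choose a new task on each reset. It is called with the total number of tasks and the index of the previously active task, and returns the index of the next task.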
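Any callable with the same shape as the two strategies documented at the end of this module can serve as sample_strategy: it receives the number of tasks and the previous task index and returns the next index. The sketch below builds a hypothetical weighted strategy; the factory name make_weighted_strategy, its weights, and its seed parameter are illustrative assumptions, not part of garage.

```python
import random


def make_weighted_strategy(weights, seed=None):
    """Build a hypothetical sample_strategy that favors some tasks.

    The returned function follows the (num_tasks, last_task) -> int shape
    used by round_robin_strategy and uniform_random_strategy.
    """
    rng = random.Random(seed)  # private RNG so draws are reproducible

    def strategy(num_tasks, last_task):
        del last_task  # this strategy ignores the previously sampled task
        if len(weights) != num_tasks:
            raise ValueError('need exactly one weight per task')
        # random.choices draws one index from range(num_tasks) using the weights
        return rng.choices(range(num_tasks), weights=weights, k=1)[0]

    return strategy


# Task 0 should be sampled roughly three times as often as task 1.
strategy = make_weighted_strategy([3.0, 1.0], seed=0)
print([strategy(2, None) for _ in range(8)])
```

Assuming envs holds exactly two environments, such a function could then be passed as MultiEnvWrapper(envs, sample_strategy=strategy).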
active_task_index

Index of the active task env.

Returns: Index of the active task.
Return type: int
active_task_one_hot

One-hot representation of the active task.

Returns: One-hot representation of the active task.
Return type: numpy.ndarray
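For task i out of n tasks, a one-hot representation is a length-n vector that is 1 at position i and 0 everywhere else. The sketch below uses a plain Python list to stay dependency-free, whereas the property itself returns a numpy.ndarray; the helper name one_hot is hypothetical.

```python
def one_hot(task_index, num_tasks):
    """Return a one-hot vector (as a plain list) for task_index."""
    if not 0 <= task_index < num_tasks:
        raise IndexError('task_index out of range')
    vec = [0.0] * num_tasks
    vec[task_index] = 1.0  # mark the active task
    return vec


print(one_hot(2, 4))  # [0.0, 0.0, 1.0, 0.0]
```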
close()

Close all task envs.

num_tasks

Total number of tasks.

Returns: Number of tasks.
Return type: int
observation_space

Observation space.

Returns: Observation space.
Return type: akro.Box
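Because every observation carries the one-hot task vector, the augmented observation space has num_tasks extra dimensions. The sketch below computes such bounds under two assumptions this reference does not spell out: the task part comes before the environment part (mirroring the order given for reset's return value), and each one-hot entry is bounded by [0, 1]. Plain lists stand in for akro.Box bounds, and augmented_bounds is a hypothetical helper.

```python
def augmented_bounds(env_low, env_high, num_tasks):
    """Concatenate one-hot task bounds with an env's observation bounds.

    Assumes the task part comes first and that each one-hot entry lies
    in [0, 1]; returns (low, high) as plain lists.
    """
    task_low = [0.0] * num_tasks
    task_high = [1.0] * num_tasks
    return task_low + list(env_low), task_high + list(env_high)


low, high = augmented_bounds([-1.0, -2.0], [1.0, 2.0], num_tasks=3)
print(low)   # [0.0, 0.0, 0.0, -1.0, -2.0]
print(high)  # [1.0, 1.0, 1.0, 1.0, 2.0]
```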
reset(**kwargs)

Sample a new task and call reset on the new task env.

Parameters: kwargs (dict) – Keyword arguments to be passed to gym.Env.reset.
Returns: The one-hot representation of the active task concatenated with the observation from the new task env.
Return type: numpy.ndarray
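The reset behaviour described above comes down to three steps: ask the sample strategy for the next task index, reset that task's env, and concatenate the one-hot task vector with the returned observation. The sketch below mimics this with a toy environment and plain lists; ToyEnv, reset_multi, and always_second are hypothetical names, and observations are assumed to be flat sequences.

```python
class ToyEnv:
    """Hypothetical stand-in for a gym.Env with a fixed reset observation."""

    def __init__(self, obs):
        self.obs = list(obs)

    def reset(self, **kwargs):
        # A real env may accept keyword arguments; this toy ignores them.
        return list(self.obs)


def reset_multi(envs, sample_strategy, last_task, **kwargs):
    """Sample a task, reset its env, and prepend the task's one-hot vector.

    Returns (new_task_index, augmented_observation).
    """
    index = sample_strategy(len(envs), last_task)
    obs = envs[index].reset(**kwargs)
    task_vec = [0.0] * len(envs)
    task_vec[index] = 1.0
    return index, task_vec + list(obs)


def always_second(num_tasks, last_task):
    """Hypothetical strategy that always selects task 1."""
    return 1


envs = [ToyEnv([5.0, 6.0]), ToyEnv([7.0, 8.0])]
print(reset_multi(envs, always_second, None))  # (1, [0.0, 1.0, 7.0, 8.0])
```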
step(action)

Call gym.Env.step on the active task env.

Parameters: action (object) – Action to be passed to gym.Env.step(action).
Returns: A tuple containing:
  • object: the agent’s observation of the current environment.
  • float: amount of reward returned after the previous action.
  • bool: whether the episode has ended.
  • dict: auxiliary diagnostic information.
Return type: tuple
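The sketch below steps a single active env and returns the usual gym-style 4-tuple. It assumes, as the module description suggests, that the observation from step is augmented with the one-hot task vector in the same way as the one from reset. CountingEnv and step_multi are hypothetical names, and plain lists replace numpy arrays.

```python
class CountingEnv:
    """Hypothetical gym.Env stand-in: observes a step counter, done after 3 steps."""

    def __init__(self):
        self.t = 0

    def step(self, action):
        self.t += 1
        obs = [float(self.t)]
        reward = float(action)  # toy reward: echo the action back
        done = self.t >= 3
        return obs, reward, done, {}


def step_multi(env, task_index, num_tasks, action):
    """Step the active env and prepend its one-hot task vector.

    Returns the gym-style 4-tuple (observation, reward, done, info).
    """
    obs, reward, done, info = env.step(action)
    task_vec = [0.0] * num_tasks
    task_vec[task_index] = 1.0
    return task_vec + list(obs), reward, done, info


env = CountingEnv()
for _ in range(3):
    print(step_multi(env, task_index=0, num_tasks=2, action=0.5))
# last line printed: ([1.0, 0.0, 3.0], 0.5, True, {})
```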
task_space

Task space.

Returns: Task space.
Return type: akro.Box
round_robin_strategy(num_tasks, last_task=None)

A function for sampling tasks in round-robin fashion.

Parameters:
  • num_tasks (int) – Total number of tasks.
  • last_task (int) – Previously sampled task.
Returns: Task ID.
Return type: int
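Round-robin sampling visits tasks in index order and wraps back to the first task after the last one. The sketch below is a minimal reimplementation that assumes sampling starts at task 0 when no previous task is given; it is named round_robin (a hypothetical name) to avoid implying it is the library's own source.

```python
def round_robin(num_tasks, last_task=None):
    """Return the task after last_task, wrapping around; start at task 0."""
    if last_task is None:
        return 0
    return (last_task + 1) % num_tasks  # wrap from the last task back to 0


task = None
order = []
for _ in range(5):
    task = round_robin(3, task)
    order.append(task)
print(order)  # [0, 1, 2, 0, 1]
```

Feeding each result back in as last_task is what makes the sequence cycle, which matches how a strategy receives the previously active task index.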

uniform_random_strategy(num_tasks, _)

A function for sampling tasks uniformly at random.

Parameters:
  • num_tasks (int) – Total number of tasks.
  • _ (object) – Ignored by this sampling strategy.
Returns: Task ID.
Return type: int
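Uniform random sampling picks each of the num_tasks indices with equal probability and ignores the previous task. The sketch below uses random.randrange; the function name uniform_random and its extra rng parameter (added so the example is reproducible) are hypothetical and not part of the documented signature.

```python
import random
from collections import Counter


def uniform_random(num_tasks, _=None, rng=random):
    """Return a task index drawn uniformly from range(num_tasks)."""
    # randrange(n) yields each of 0, ..., n - 1 with equal probability
    return rng.randrange(num_tasks)


rng = random.Random(42)
counts = Counter(uniform_random(4, None, rng) for _ in range(10_000))
print(sorted(counts))  # [0, 1, 2, 3]
```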