garage.sampler.utils module

Utility functions related to sampling.

rollout(env, agent, *, max_path_length=inf, animated=False, speedup=1, deterministic=False)

Sample a single rollout of the agent in the environment.

Parameters:
  • env (gym.Env) – Environment to perform actions in.
  • agent (Policy) – Agent used to select actions.
  • max_path_length (int) – If the rollout reaches this many timesteps, it is terminated.
  • animated (bool) – If True, render the environment after each step.
  • speedup (float) – Factor by which to decrease the wait time between rendered steps. Only relevant if animated is True.
  • deterministic (bool) – If True, use the mean action returned by the stochastic policy instead of sampling from the returned action distribution.
Returns:

Dictionary with keys:
  • observations (np.ndarray): Non-flattened array of observations.
  • actions (np.ndarray): Non-flattened array of actions.
  • rewards (np.ndarray): Array of rewards of shape (timesteps, 1).
  • agent_infos (dict[str, np.ndarray]): Dictionary of stacked, non-flattened agent_info dicts.
  • env_infos (dict[str, np.ndarray]): Dictionary of stacked, non-flattened env_info dicts.

Return type:

dict[str, np.ndarray or dict]
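The loop that rollout performs can be sketched as follows. This is a minimal, simplified sketch, not garage's implementation. It assumes the classic Gym API (reset() returns an observation and step(action) returns (obs, reward, done, info)), an agent exposing reset() and get_action(obs) that returns (action, agent_info), and a stochastic policy that reports its mean action under the key 'mean' in agent_info. The names rollout_sketch, stack_info_dicts, CountingEnv and ConstantAgent, as well as the 0.05 s base frame delay, are hypothetical, and info dicts are assumed to be flat.

```python
import time

import numpy as np


def stack_info_dicts(infos):
    """Turn a list of flat info dicts into one dict of arrays, keyed like the first dict."""
    if not infos:
        return {}
    return {key: np.array([info[key] for info in infos]) for key in infos[0]}


def rollout_sketch(env, agent, *, max_path_length=np.inf, animated=False,
                   speedup=1, deterministic=False):
    """Collect a single episode (simplified sketch, not garage's code)."""
    observations, actions, rewards = [], [], []
    agent_infos, env_infos = [], []
    frame_delay = 0.05  # hypothetical base delay between rendered frames, in seconds

    obs = env.reset()
    agent.reset()
    if animated:
        env.render()
    while len(rewards) < max_path_length:
        action, agent_info = agent.get_action(obs)
        if deterministic and 'mean' in agent_info:
            action = agent_info['mean']  # assumed key for the policy's mean action
        next_obs, reward, done, env_info = env.step(action)
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        agent_infos.append(agent_info)
        env_infos.append(env_info)
        if done:
            break
        obs = next_obs
        if animated:
            env.render()
            time.sleep(frame_delay / speedup)

    return dict(
        observations=np.array(observations),
        actions=np.array(actions),
        # Reshaped to the (timesteps, 1) layout given in the Returns section.
        rewards=np.array(rewards).reshape(-1, 1),
        agent_infos=stack_info_dicts(agent_infos),
        env_infos=stack_info_dicts(env_infos),
    )


class CountingEnv:
    """Hypothetical toy env: observes the step count and ends after 3 steps."""

    def reset(self):
        self.t = 0
        return np.array([0.0])

    def step(self, action):
        self.t += 1
        return np.array([float(self.t)]), 1.0, self.t >= 3, {'t': self.t}

    def render(self):
        pass


class ConstantAgent:
    """Hypothetical agent: samples 0.5 but reports a mean action of 0.0."""

    def reset(self):
        pass

    def get_action(self, obs):
        return np.array([0.5]), {'mean': np.array([0.0])}


path = rollout_sketch(CountingEnv(), ConstantAgent(), max_path_length=10)
print(path['rewards'].shape)   # (3, 1)
print(path['env_infos']['t'])  # [1 2 3]
```

Stacking each info key across timesteps is what gives agent_infos and env_infos their "dict of arrays" layout: in the example, path['env_infos']['t'] holds one entry per step. Passing deterministic=True to the sketch makes every action equal to the reported mean (0.0) instead of the sampled 0.5.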

truncate_paths(paths, max_samples)

Truncate the paths so that the total number of samples is at most max_samples.

This is done by removing extra paths from the end of the list and shortening the last remaining path if necessary.

Parameters:
  • paths (list[dict[str, np.ndarray]]) – Samples; each item is a dict with keys:
      • observations (np.ndarray): Environment observations
      • actions (np.ndarray): Agent actions
      • rewards (np.ndarray): Environment rewards
      • env_infos (dict): Environment state information
      • agent_infos (dict): Agent state information
  • max_samples (int) – Maximum number of samples allowed.
Returns:

A list of paths, truncated so that the total number of samples is at most max_samples.

Return type:

list[dict[str, np.ndarray]]

Raises:

ValueError – If a key other than ‘observations’, ‘actions’, ‘rewards’, ‘env_infos’ or ‘agent_infos’ is found.
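The truncation procedure described above can be sketched as follows. This is a minimal sketch under stated assumptions, not garage's implementation: it counts the samples in a path with len(path['rewards']), assumes env_infos and agent_infos are flat dicts whose values are arrays indexed by timestep, and only checks the keys of the path that gets shortened. The names truncate_paths_sketch and make_path are hypothetical.

```python
import numpy as np

ARRAY_KEYS = ('observations', 'actions', 'rewards')
INFO_KEYS = ('env_infos', 'agent_infos')


def truncate_paths_sketch(paths, max_samples):
    """Keep at most max_samples samples, dropping and shortening trailing paths."""
    paths = list(paths)  # shallow copy, so the caller's list is not modified
    total = sum(len(path['rewards']) for path in paths)
    # Drop whole paths from the end while the remaining ones still reach max_samples.
    while paths and total - len(paths[-1]['rewards']) >= max_samples:
        total -= len(paths.pop()['rewards'])
    if not paths:
        return paths
    # Shorten the last remaining path by the surplus (if any).
    last = paths.pop()
    keep = len(last['rewards']) - max(total - max_samples, 0)
    truncated = {}
    for key, value in last.items():
        if key in ARRAY_KEYS:
            truncated[key] = value[:keep]
        elif key in INFO_KEYS:
            # Assumes flat info dicts whose values are arrays indexed by timestep.
            truncated[key] = {k: v[:keep] for k, v in value.items()}
        else:
            raise ValueError(f'Unexpected key in path: {key!r}')
    paths.append(truncated)
    return paths


def make_path(n):
    """Hypothetical helper that builds a path with n samples."""
    return {
        'observations': np.zeros((n, 2)),
        'actions': np.zeros((n, 1)),
        'rewards': np.ones(n),
        'env_infos': {'step': np.arange(n)},
        'agent_infos': {},
    }


out = truncate_paths_sketch([make_path(3), make_path(4), make_path(5)], 6)
print([len(p['rewards']) for p in out])  # [3, 3]
```

With paths of length 3, 4 and 5 and max_samples=6, the last path is dropped entirely and the middle one is cut to 3 samples, giving 3 + 3 = 6. If the paths hold fewer than max_samples samples in total, the sketch returns them at their original lengths.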