Source code for garage.sampler.utils

"""Utility functions related to sampling."""

import time

import numpy as np

from garage.misc import tensor_utils


def rollout(env,
            agent,
            *,
            max_path_length=np.inf,
            animated=False,
            speedup=1,
            deterministic=False):
    """Sample a single rollout of the agent in the environment.

    Args:
        env (gym.Env): Environment to perform actions in.
        agent (Policy): Agent used to select actions.
        max_path_length (int): If the rollout reaches this many timesteps, it
            is terminated.
        animated (bool): If true, render the environment after each step.
        speedup (float): Factor by which to decrease the wait time between
            rendered steps. Only relevant if animated is true.
        deterministic (bool): If true, use the mean action returned by the
            stochastic policy instead of sampling from the returned action
            distribution.

    Returns:
        dict[str, np.ndarray or dict]: Dictionary with keys:
            * observations (np.array): Non-flattened array of observations.
            * actions (np.array): Non-flattened array of actions.
            * rewards (np.array): Array of rewards of shape (timesteps,).
            * agent_infos (dict[str, np.ndarray]): Dictionary of stacked,
                non-flattened `agent_info`s.
            * env_infos (dict[str, np.ndarray]): Dictionary of stacked,
                non-flattened `env_info`s.

    """
    observations = []
    actions = []
    rewards = []
    agent_infos = []
    env_infos = []
    o = env.reset()
    agent.reset()
    path_length = 0
    if animated:
        env.render()
    while path_length < max_path_length:
        a, agent_info = agent.get_action(o)
        if deterministic:
            a = agent_info['mean']
        next_o, r, d, env_info = env.step(a)
        observations.append(o)
        rewards.append(r)
        actions.append(a)
        agent_infos.append(agent_info)
        env_infos.append(env_info)
        path_length += 1
        if d:
            break
        o = next_o
        if animated:
            env.render()
            timestep = 0.05
            time.sleep(timestep / speedup)

    return dict(
        observations=np.array(observations),
        actions=np.array(actions),
        rewards=np.array(rewards),
        agent_infos=tensor_utils.stack_tensor_dict_list(agent_infos),
        env_infos=tensor_utils.stack_tensor_dict_list(env_infos),
    )
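
For illustration, a minimal usage sketch (not part of the module itself): it assumes the classic 4-tuple gym.Env step API that this function targets, and the RandomPolicy class below is a hypothetical stand-in for whatever garage Policy the caller actually has.

    import gym

    class RandomPolicy:
        """Hypothetical stand-in for a garage Policy: samples random actions."""

        def __init__(self, action_space):
            self._action_space = action_space

        def reset(self):
            pass

        def get_action(self, observation):
            # garage policies return (action, agent_info); agent_info may be empty.
            return self._action_space.sample(), dict()

    env = gym.make('CartPole-v1')
    path = rollout(env, RandomPolicy(env.action_space), max_path_length=100)
    print(path['observations'].shape)  # (T, 4) for CartPole
    print(path['rewards'].shape)       # (T,)

Since deterministic defaults to False, agent_info does not need a 'mean' entry here; a stochastic garage policy would supply one if deterministic=True were requested.
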
def truncate_paths(paths, max_samples):
    """Truncate the paths so that the total number of samples is max_samples.

    This is done by removing extra paths at the end of the list, and making
    the last path shorter if necessary.

    Args:
        paths (list[dict[str, np.ndarray]]): Samples, items with keys:
            * observations (np.ndarray): Environment observations
            * actions (np.ndarray): Agent actions
            * rewards (np.ndarray): Environment rewards
            * env_infos (dict): Environment state information
            * agent_infos (dict): Agent state information
        max_samples (int): Maximum number of samples allowed.

    Returns:
        list[dict[str, np.ndarray]]: A list of paths, truncated so that the
            number of samples adds up to max_samples.

    Raises:
        ValueError: If a key other than 'observations', 'actions', 'rewards',
            'env_infos' or 'agent_infos' is found.

    """
    # chop samples collected by extra paths
    # make a copy
    valid_keys = {
        'observations', 'actions', 'rewards', 'env_infos', 'agent_infos'
    }
    paths = list(paths)
    total_n_samples = sum(len(path['rewards']) for path in paths)
    while paths and total_n_samples - len(paths[-1]['rewards']) >= max_samples:
        total_n_samples -= len(paths.pop(-1)['rewards'])
    if paths:
        last_path = paths.pop(-1)
        truncated_last_path = dict()
        truncated_len = len(
            last_path['rewards']) - (total_n_samples - max_samples)
        for k, v in last_path.items():
            if k in ['observations', 'actions', 'rewards']:
                truncated_last_path[k] = v[:truncated_len]
            elif k in ['env_infos', 'agent_infos']:
                truncated_last_path[k] = tensor_utils.truncate_tensor_dict(
                    v, truncated_len)
            else:
                raise ValueError(
                    'Unexpected key {} found in path. Valid keys: {}'.format(
                        k, valid_keys))
        paths.append(truncated_last_path)
    return paths
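
As an illustration of the truncation behavior, a small sketch with synthetic paths (the array contents below are made up; only their lengths matter): two 5-step paths give 10 samples, and asking for 7 keeps the first path whole and shortens the second to 2 steps.

    # Synthetic 5-step path; env_infos/agent_infos left empty for brevity.
    fake_path = dict(
        observations=np.zeros((5, 3)),
        actions=np.zeros((5, 1)),
        rewards=np.zeros(5),
        env_infos=dict(),
        agent_infos=dict(),
    )
    truncated = truncate_paths([fake_path, dict(fake_path)], max_samples=7)

    print(len(truncated))                              # 2 paths kept
    print(sum(len(p['rewards']) for p in truncated))   # 7 samples total
    print(len(truncated[-1]['rewards']))               # last path cut to 2 steps

Note that truncation only ever drops or shortens paths at the end of the list; earlier paths are returned unchanged.
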