garage.trainer

Provides algorithms with access to most of garage’s features.

tf = False
TFWorkerClassWrapper = False
class ExperimentStats(total_epoch, total_itr, total_env_steps, last_episode)

Statistics of an experiment.

Parameters
  • total_epoch (int) – Total epochs.

  • total_itr (int) – Total iterations.

  • total_env_steps (int) – Total environment steps collected.

  • last_episode (list[dict]) – Last sampled episodes.

class SetupArgs(sampler_cls, sampler_args, seed)

Arguments to setup a trainer.

Parameters
  • sampler_cls (Sampler) – A sampler class.

  • sampler_args (dict) – Arguments to be passed to sampler constructor.

  • seed (int) – Random seed.

class TrainArgs(n_epochs, batch_size, plot, store_episodes, pause_for_plot, start_epoch)

Arguments to call train() or resume().

Parameters
  • n_epochs (int) – Number of epochs.

  • batch_size (int) – Number of environment steps in one batch.

  • plot (bool) – Visualize an episode of the policy after each epoch.

  • store_episodes (bool) – Save episodes in snapshot.

  • pause_for_plot (bool) – Pause for plot.

  • start_epoch (int) – The starting epoch. Used for resume().

class Trainer(snapshot_config)

Base class of trainer.

Use trainer.setup(algo, env) to set up the algorithm and environment for the trainer, and trainer.train() to start training.

Parameters

snapshot_config (garage.experiment.SnapshotConfig) – The snapshot configuration used by Trainer to create the snapshotter. If None, it will create one with default settings.

Note

When using TensorFlow environments, policies, or algorithms, use TFTrainer() instead.

Examples

# to train
trainer = Trainer(snapshot_config)
env = Env(…)
policy = Policy(…)
algo = Algo(
    env=env,
    policy=policy,
    …)
trainer.setup(algo, env)
trainer.train(n_epochs=100, batch_size=4000)

# to resume immediately.
trainer = Trainer(snapshot_config)
trainer.restore(resume_from_dir)
trainer.resume()

# to resume with modified training arguments.
trainer = Trainer(snapshot_config)
trainer.restore(resume_from_dir)
trainer.resume(n_epochs=20)

make_sampler(self, sampler_cls, *, seed=None, n_workers=psutil.cpu_count(logical=False), max_episode_length=None, worker_class=None, sampler_args=None, worker_args=None)

Construct a Sampler from a Sampler class.

Parameters
  • sampler_cls (type) – The type of sampler to construct.

  • seed (int) – Seed to use in sampler workers.

  • max_episode_length (int) – Maximum episode length to be sampled by the sampler. Episodes longer than this will be truncated.

  • n_workers (int) – The number of workers the sampler should use.

  • worker_class (type) – Type of worker the sampler should use.

  • sampler_args (dict or None) – Additional arguments that should be passed to the sampler.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.

Raises

ValueError – If max_episode_length isn’t passed and the algorithm doesn’t contain a max_episode_length field, or if the algorithm doesn’t have a policy field.

Returns

An instance of the sampler class.

Return type

sampler_cls
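The ValueError conditions listed for make_sampler() amount to two lookups on the algorithm. The following is a minimal sketch that assumes the algorithm exposes optional `max_episode_length` and `policy` attributes. `resolve_sampler_inputs` is a hypothetical helper, not part of garage, and the real method may consult other fields as well.

```python
from types import SimpleNamespace


# Hypothetical helper (not a garage API): reproduces the documented
# ValueError conditions of make_sampler() using plain attribute lookups.
def resolve_sampler_inputs(algo, max_episode_length=None):
    """Return (policy, max_episode_length) or raise ValueError."""
    if max_episode_length is None:
        # Fall back to the algorithm's own field, if it defines one.
        max_episode_length = getattr(algo, 'max_episode_length', None)
    if max_episode_length is None:
        raise ValueError('max_episode_length must be passed or set on the '
                         'algorithm')
    policy = getattr(algo, 'policy', None)
    if policy is None:
        raise ValueError('the algorithm must have a policy field')
    return policy, max_episode_length


algo = SimpleNamespace(policy='my_policy', max_episode_length=200)
print(resolve_sampler_inputs(algo))      # ('my_policy', 200)
print(resolve_sampler_inputs(algo, 50))  # ('my_policy', 50)
```

An explicit max_episode_length argument takes precedence over the algorithm's field. The algorithm's value is only used as a fallback.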

setup(self, algo, env, sampler_cls=None, sampler_args=None, n_workers=psutil.cpu_count(logical=False), worker_class=None, worker_args=None)

Set up trainer for algorithm and environment.

This method saves algo and env within trainer and creates a sampler.

Note

After setup() is called, all variables in the session should have been initialized. setup() respects existing values in the session, so policy weights can be loaded before setup().

Parameters
  • algo (RLAlgorithm) – An algorithm instance.

  • env (Environment) – An environment instance.

  • sampler_cls (type) – A class which implements Sampler.

  • sampler_args (dict) – Arguments to be passed to sampler constructor.

  • n_workers (int) – The number of workers the sampler should use.

  • worker_class (type) – Type of worker the sampler should use.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.

Raises

ValueError – If sampler_cls is passed and the algorithm doesn’t contain a max_episode_length field.

obtain_episodes(self, itr, batch_size=None, agent_update=None, env_update=None)

Obtain one batch of episodes.

Parameters
  • itr (int) – Index of iteration (epoch).

  • batch_size (int) – Number of steps in batch. This is a hint that the sampler may or may not respect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.
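The rule for agent_update and env_update can be read as "broadcast a single value, or distribute a list one entry per worker". Here is a minimal sketch of that rule. `spread_update` is a hypothetical helper and does not claim to be how garage's worker factory is implemented.

```python
# Hypothetical helper (not a garage API): turns one update argument into a
# per-worker list, following the rule described for agent_update/env_update.
def spread_update(update, n_workers):
    """Return a list containing one update value per worker."""
    if isinstance(update, list):
        # A list must supply exactly one entry per worker.
        if len(update) != n_workers:
            raise ValueError(
                f'expected {n_workers} updates, got {len(update)}')
        return list(update)
    # Any non-list value (including None) is shared by every worker.
    return [update] * n_workers


print(spread_update({'w': 1.0}, 3))  # same dict for all three workers
print(spread_update([10, 20], 2))    # [10, 20]
```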

Raises

ValueError – If the trainer was initialized without a sampler, or batch_size wasn’t provided here or to train.

Returns

Batch of episodes.

Return type

EpisodeBatch

obtain_samples(self, itr, batch_size=None, agent_update=None, env_update=None)

Obtain one batch of samples.

Parameters
  • itr (int) – Index of iteration (epoch).

  • batch_size (int) – Number of steps in batch. This is a hint that the sampler may or may not respect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Raises

ValueError – Raised if the trainer was initialized without a sampler, or batch_size wasn’t provided here or to train.
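The two ValueError conditions above are checked in a fixed order: first whether a sampler exists, then whether a batch size is available either directly or from train(). This is a minimal sketch of that check. It assumes the trainer remembers the batch_size passed to train(). `SamplingGuard` is a hypothetical class used only for illustration.

```python
# Hypothetical class (not a garage API): models the preconditions that
# obtain_samples()/obtain_episodes() document before sampling starts.
class SamplingGuard:
    def __init__(self, sampler=None, train_batch_size=None):
        self.sampler = sampler
        self.train_batch_size = train_batch_size  # value given to train()

    def resolve_batch_size(self, batch_size=None):
        """Return the batch size to request, or raise ValueError."""
        if self.sampler is None:
            raise ValueError('trainer was initialized without a sampler')
        if batch_size is None:
            # Fall back to the batch_size that was passed to train().
            batch_size = self.train_batch_size
        if batch_size is None:
            raise ValueError('batch_size must be given here or to train()')
        return batch_size


guard = SamplingGuard(sampler=object(), train_batch_size=4000)
print(guard.resolve_batch_size())      # 4000
print(guard.resolve_batch_size(1000))  # 1000
```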

Returns

One batch of samples.

Return type

list[dict]

save(self, epoch)

Save snapshot of current batch.

Parameters

epoch (int) – Epoch.

Raises

NotSetupError – If save() is called before the trainer is set up.

restore(self, from_dir, from_epoch='last')

Restore experiment from snapshot.

Parameters
  • from_dir (str) – Directory of the pickle file to resume experiment from.

  • from_epoch (str or int) – The epoch to restore from. Can be 'first', 'last' or a number. Not applicable when snapshot_mode='last'.
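The from_epoch argument accepts three forms. Here is a minimal sketch of how those forms could be mapped onto saved snapshots. It assumes one snapshot is kept per saved epoch and identified by its epoch number. `resolve_from_epoch` is hypothetical, and garage's snapshotter may locate files differently.

```python
# Hypothetical helper (not a garage API): maps 'first', 'last', or a number
# onto one of the epochs for which a snapshot exists.
def resolve_from_epoch(saved_epochs, from_epoch='last'):
    """Return the epoch number whose snapshot should be loaded."""
    if not saved_epochs:
        raise ValueError('no snapshots found')
    if from_epoch == 'first':
        return min(saved_epochs)
    if from_epoch == 'last':
        return max(saved_epochs)
    epoch = int(from_epoch)  # accept 5 as well as '5'
    if epoch not in saved_epochs:
        raise ValueError(f'no snapshot saved for epoch {epoch}')
    return epoch


print(resolve_from_epoch([0, 5, 10]))           # 10
print(resolve_from_epoch([0, 5, 10], 'first'))  # 0
```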

Returns

Arguments for train().

Return type

TrainArgs

log_diagnostics(self, pause_for_plot=False)

Log diagnostics.

Parameters

pause_for_plot (bool) – Pause for plot.

train(self, n_epochs, batch_size=None, plot=False, store_episodes=False, pause_for_plot=False)

Start training.

Parameters
  • n_epochs (int) – Number of epochs.

  • batch_size (int or None) – Number of environment steps in one batch.

  • plot (bool) – Visualize an episode from the policy after each epoch.

  • store_episodes (bool) – Save episodes in snapshot.

  • pause_for_plot (bool) – Pause for plot.

Raises

NotSetupError – If train() is called before setup().

Returns

The average return in the last epoch cycle.

Return type

float
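The return value is computed by the algorithm, and this page does not define it precisely. The following sketch shows one common reading, namely undiscounted episode returns averaged over the episodes of one epoch. That interpretation is an assumption, and `average_return` is a hypothetical helper.

```python
from statistics import mean


# Hypothetical helper (not a garage API): average of undiscounted
# per-episode returns, one common meaning of "average return".
def average_return(episode_rewards):
    """episode_rewards: a list containing one reward list per episode."""
    return mean(sum(rewards) for rewards in episode_rewards)


# Returns are 2.0 and 4.0, so the average is 3.0.
print(average_return([[1.0, 1.0], [0.0, 2.0, 2.0]]))
```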

step_epochs(self)

Step through each epoch.

This method returns a generator. When iterated through, the generator automatically performs services such as snapshotting and log management. It is used inside each algorithm's train() method.

The generator initializes two variables: self.step_itr and self.step_episode. To use the generator, these two have to be updated manually in each epoch, as the example shows below.

Yields

int – The next training epoch.

Examples

for epoch in trainer.step_epochs():
    trainer.step_episode = trainer.obtain_samples(…)
    self.train_once(…)
    trainer.step_itr += 1
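The services are easier to see with a toy generator that runs its bookkeeping after the caller's loop body finishes. The following is a minimal sketch without garage. `MiniTrainer` and `_save` are hypothetical stand-ins for the trainer and its snapshotting, and the real generator also handles logging and timing.

```python
# Hypothetical stand-in (not garage's Trainer): illustrates the
# step_epochs() protocol with a snapshot record and a simple log.
class MiniTrainer:
    def __init__(self, n_epochs, start_epoch=0):
        self.n_epochs = n_epochs
        self.start_epoch = start_epoch
        self.saved = []  # epochs for which a "snapshot" was taken
        self.log = []    # (epoch, step_itr) pairs recorded after each epoch

    def _save(self, epoch):
        # Stand-in for writing a snapshot of the current epoch.
        self.saved.append(epoch)

    def step_epochs(self):
        # The caller is expected to update both attributes in its loop.
        self.step_itr = 0
        self.step_episode = None
        for epoch in range(self.start_epoch, self.n_epochs):
            yield epoch  # control passes to the algorithm's loop body
            # Resumes here once the loop body for this epoch has finished.
            self._save(epoch)
            self.log.append((epoch, self.step_itr))


trainer = MiniTrainer(n_epochs=3)
for epoch in trainer.step_epochs():
    trainer.step_episode = [{'rewards': [1.0]}]  # would come from sampling
    trainer.step_itr += 1
print(trainer.saved)  # [0, 1, 2]
```

Because the snapshot is taken after `yield` returns, each epoch is saved only once the algorithm has finished updating step_itr and step_episode for that epoch.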

resume(self, n_epochs=None, batch_size=None, plot=None, store_episodes=None, pause_for_plot=None)

Resume from restored experiment.

This method provides the same interface as train().

If not specified, an argument will default to the saved arguments from the last call to train().

Parameters
  • n_epochs (int) – Number of epochs.

  • batch_size (int) – Number of environment steps in one batch.

  • plot (bool) – Visualize an episode from the policy after each epoch.

  • store_episodes (bool) – Save episodes in snapshot.

  • pause_for_plot (bool) – Pause for plot.

Raises

NotSetupError – If resume() is called before restore().

Returns

The average return in the last epoch cycle.

Return type

float
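The fallback rule for resume() (each argument left as None keeps its saved value) can be expressed as a merge over TrainArgs. This sketch models TrainArgs as a frozen dataclass with the fields listed for garage.trainer.TrainArgs. That modelling and the `merge_resume_args` helper are assumptions for illustration only.

```python
from dataclasses import dataclass, replace


# Assumed model of TrainArgs: same fields as documented, not garage's class.
@dataclass(frozen=True)
class TrainArgs:
    n_epochs: int
    batch_size: int
    plot: bool
    store_episodes: bool
    pause_for_plot: bool
    start_epoch: int


# Hypothetical helper: overrides that are None keep the saved value.
def merge_resume_args(saved, **overrides):
    changes = {k: v for k, v in overrides.items() if v is not None}
    return replace(saved, **changes)


saved = TrainArgs(n_epochs=100, batch_size=4000, plot=False,
                  store_episodes=False, pause_for_plot=False, start_epoch=37)
resumed = merge_resume_args(saved, n_epochs=20, batch_size=None)
print(resumed.n_epochs, resumed.batch_size)  # 20 4000
```

Using None as the "not specified" marker means an explicit False, such as plot=False, still overrides a saved True.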

get_env_copy(self)

Get a copy of the environment.

Returns

An environment instance.

Return type

Environment
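A copy that shares no mutable state with the trainer's environment can be produced with a serialization round-trip. This sketch uses the standard-library `pickle` module and a hypothetical `ToyEnv` class. garage's own implementation may use a different serializer or copying strategy.

```python
import pickle


class ToyEnv:
    """Hypothetical environment holding some mutable state."""

    def __init__(self):
        self.state = [0]

    def step(self):
        self.state[0] += 1


def get_env_copy(env):
    # Serialize and deserialize so the copy is fully independent.
    return pickle.loads(pickle.dumps(env))


env = ToyEnv()
env_copy = get_env_copy(env)
env_copy.step()
print(env.state, env_copy.state)  # [0] [1]
```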

property total_env_steps(self)

Total environment steps collected.

Returns

Total environment steps collected.

Return type

int

exception NotSetupError

Bases: Exception


Raised when an experiment is about to run without setup.

args

with_traceback(tb)

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.
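The following sketch shows the kind of guard that raises NotSetupError. It assumes a simple "has setup run" flag. Both the local `NotSetupError` stand-in and `GuardedTrainer` are hypothetical, so the snippet runs without garage installed.

```python
class NotSetupError(Exception):
    """Local stand-in for garage.trainer.NotSetupError."""


# Hypothetical trainer (not garage's): refuses to train before setup().
class GuardedTrainer:
    def __init__(self):
        self._has_setup = False

    def setup(self, algo, env):
        self._algo, self._env = algo, env
        self._has_setup = True

    def train(self, n_epochs):
        if not self._has_setup:
            raise NotSetupError('Call setup() before train().')
        return list(range(n_epochs))  # placeholder: epochs that would run


trainer = GuardedTrainer()
try:
    trainer.train(n_epochs=3)
except NotSetupError as err:
    print('refused:', err)
trainer.setup(algo=object(), env=object())
print(trainer.train(n_epochs=3))  # [0, 1, 2]
```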

class TFTrainer(snapshot_config, sess=None)

Bases: garage.trainer.Trainer


This class implements a trainer for TensorFlow algorithms.

TFTrainer provides a default TensorFlow session through a Python context manager (a with block). This is useful for experiment components (e.g. policies) that require a TensorFlow session during construction.

Use trainer.setup(algo, env) to set up the algorithm and environment for the trainer, and trainer.train() to start training.

Parameters
  • snapshot_config (garage.experiment.SnapshotConfig) – The snapshot configuration used by Trainer to create the snapshotter. If None, it will create one with default settings.

  • sess (tf.Session) – An optional TensorFlow session. A new session will be created immediately if not provided.
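The "default session via a with block" idea can be illustrated without TensorFlow. In this sketch, `FakeSession` and `get_default_session` are hypothetical stand-ins for a TensorFlow session and TensorFlow's default-session lookup. This page does not say whether TFTrainer closes a caller-supplied session, so the sketch closes only sessions it created itself.

```python
# Hypothetical stand-ins: no TensorFlow is imported here.
_default_sessions = []  # stack of sessions entered via ``with``


def get_default_session():
    """Return the innermost session entered via a trainer, or None."""
    return _default_sessions[-1] if _default_sessions else None


class FakeSession:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


class SessionTrainer:
    def __init__(self, sess=None):
        # Create a session immediately when none is provided.
        self._owns_sess = sess is None
        self.sess = FakeSession() if sess is None else sess

    def __enter__(self):
        _default_sessions.append(self.sess)  # make it the default
        return self

    def __exit__(self, exc_type, exc, tb):
        _default_sessions.pop()
        if self._owns_sess:
            self.sess.close()  # only close sessions this trainer created
        return False  # never suppress exceptions from the with block


with SessionTrainer() as trainer:
    # Components built here could pick up the session implicitly.
    print(get_default_session() is trainer.sess)  # True
```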

Note

When resuming via the command line, new snapshots will be saved into the SAME directory unless another directory is specified.

When resuming programmatically, the snapshot directory should be specified manually or through the @wrap_experiment interface.

Examples

# to train
with TFTrainer(snapshot_config) as trainer:
    env = gym.make('CartPole-v1')
    policy = CategoricalMLPPolicy(
        env_spec=env.spec, hidden_sizes=(32, 32))
    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        max_episode_length=100,
        discount=0.99,
        max_kl_step=0.01)
    trainer.setup(algo, env)
    trainer.train(n_epochs=100, batch_size=4000)

# to resume immediately.
with TFTrainer(snapshot_config) as trainer:
    trainer.restore(resume_from_dir)
    trainer.resume()

# to resume with modified training arguments.
with TFTrainer(snapshot_config) as trainer:
    trainer.restore(resume_from_dir)
    trainer.resume(n_epochs=20)

make_sampler(self, sampler_cls, *, seed=None, n_workers=psutil.cpu_count(logical=False), max_episode_length=None, worker_class=None, sampler_args=None, worker_args=None)

Construct a Sampler from a Sampler class.

Parameters
  • sampler_cls (type) – The type of sampler to construct.

  • seed (int) – Seed to use in sampler workers.

  • max_episode_length (int) – Maximum episode length to be sampled by the sampler. Episodes longer than this will be truncated.

  • n_workers (int) – The number of workers the sampler should use.

  • worker_class (type) – Type of worker the sampler should use.

  • sampler_args (dict or None) – Additional arguments that should be passed to the sampler.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.

Returns

An instance of the sampler class.

Return type

sampler_cls

setup(self, algo, env, sampler_cls=None, sampler_args=None, n_workers=psutil.cpu_count(logical=False), worker_class=None, worker_args=None)

Set up trainer and sessions for algorithm and environment.

This method saves algo and env within trainer and creates a sampler, and initializes all uninitialized variables in session.

Note

After setup() is called, all variables in the session should have been initialized. setup() respects existing values in the session, so policy weights can be loaded before setup().

Parameters
  • algo (RLAlgorithm) – An algorithm instance.

  • env (Environment) – An environment instance.

  • sampler_cls (type) – A class which implements Sampler.

  • sampler_args (dict) – Arguments to be passed to sampler constructor.

  • n_workers (int) – The number of workers the sampler should use.

  • worker_class (type) – Type of worker the sampler should use.

  • worker_args (dict or None) – Additional arguments that should be passed to the worker.

initialize_tf_vars(self)

Initialize all uninitialized variables in session.
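The Note for setup() states that existing values are respected, which means only variables without a value are initialized. This sketch shows that rule. A plain dict stands in for the session's variables, with None meaning "uninitialized". Both the data model and `initialize_uninitialized` are hypothetical.

```python
# Hypothetical helper (not a garage/TensorFlow API): initializes only the
# variables that do not have a value yet, leaving loaded weights intact.
def initialize_uninitialized(variables, initializers):
    """Fill in missing values and return the names that were initialized."""
    initialized = []
    for name, value in variables.items():
        if value is None:
            variables[name] = initializers[name]()
            initialized.append(name)
    return initialized


variables = {'policy/w': [0.5], 'baseline/b': None}  # policy/w was loaded
initializers = {'policy/w': lambda: [0.0], 'baseline/b': lambda: [0.0]}
print(initialize_uninitialized(variables, initializers))  # ['baseline/b']
print(variables['policy/w'])  # [0.5], so the loaded weights are kept
```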

obtain_episodes(self, itr, batch_size=None, agent_update=None, env_update=None)

Obtain one batch of episodes.

Parameters
  • itr (int) – Index of iteration (epoch).

  • batch_size (int) – Number of steps in batch. This is a hint that the sampler may or may not respect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Raises

ValueError – If the trainer was initialized without a sampler, or batch_size wasn’t provided here or to train.

Returns

Batch of episodes.

Return type

EpisodeBatch

obtain_samples(self, itr, batch_size=None, agent_update=None, env_update=None)

Obtain one batch of samples.

Parameters
  • itr (int) – Index of iteration (epoch).

  • batch_size (int) – Number of steps in batch. This is a hint that the sampler may or may not respect.

  • agent_update (object) – Value which will be passed into the agent_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

  • env_update (object) – Value which will be passed into the env_update_fn before sampling episodes. If a list is passed in, it must have length exactly factory.n_workers, and will be spread across the workers.

Raises

ValueError – Raised if the trainer was initialized without a sampler, or batch_size wasn’t provided here or to train.

Returns

One batch of samples.

Return type

list[dict]

save(self, epoch)

Save snapshot of current batch.

Parameters

epoch (int) – Epoch.

Raises

NotSetupError – If save() is called before the trainer is set up.

restore(self, from_dir, from_epoch='last')

Restore experiment from snapshot.

Parameters
  • from_dir (str) – Directory of the pickle file to resume experiment from.

  • from_epoch (str or int) – The epoch to restore from. Can be 'first', 'last' or a number. Not applicable when snapshot_mode='last'.

Returns

Arguments for train().

Return type

TrainArgs

log_diagnostics(self, pause_for_plot=False)

Log diagnostics.

Parameters

pause_for_plot (bool) – Pause for plot.

train(self, n_epochs, batch_size=None, plot=False, store_episodes=False, pause_for_plot=False)

Start training.

Parameters
  • n_epochs (int) – Number of epochs.

  • batch_size (int or None) – Number of environment steps in one batch.

  • plot (bool) – Visualize an episode from the policy after each epoch.

  • store_episodes (bool) – Save episodes in snapshot.

  • pause_for_plot (bool) – Pause for plot.

Raises

NotSetupError – If train() is called before setup().

Returns

The average return in the last epoch cycle.

Return type

float

step_epochs(self)

Step through each epoch.

This method returns a generator. When iterated through, the generator automatically performs services such as snapshotting and log management. It is used inside each algorithm's train() method.

The generator initializes two variables: self.step_itr and self.step_episode. To use the generator, these two have to be updated manually in each epoch, as the example shows below.

Yields

int – The next training epoch.

Examples

for epoch in trainer.step_epochs():
    trainer.step_episode = trainer.obtain_samples(…)
    self.train_once(…)
    trainer.step_itr += 1

resume(self, n_epochs=None, batch_size=None, plot=None, store_episodes=None, pause_for_plot=None)

Resume from restored experiment.

This method provides the same interface as train().

If not specified, an argument will default to the saved arguments from the last call to train().

Parameters
  • n_epochs (int) – Number of epochs.

  • batch_size (int) – Number of environment steps in one batch.

  • plot (bool) – Visualize an episode from the policy after each epoch.

  • store_episodes (bool) – Save episodes in snapshot.

  • pause_for_plot (bool) – Pause for plot.

Raises

NotSetupError – If resume() is called before restore().

Returns

The average return in the last epoch cycle.

Return type

float

get_env_copy(self)

Get a copy of the environment.

Returns

An environment instance.

Return type

Environment

property total_env_steps(self)

Total environment steps collected.

Returns

Total environment steps collected.

Return type

int
