Load and Use a Trained Policy
In this section you will learn how to extract a trained policy from an experiment snapshot, as well as how to evaluate that policy in an environment.
Obtaining an experiment snapshot
Please refer to this page for information on how to save an experiment snapshot. The snapshot contains data such as:
Number of epochs
The algorithm (which includes the policy we want to evaluate)
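Snapshots live on disk as pickle files inside the experiment's snapshot directory. As a rough sketch (the `itr_<N>.pkl` naming below is illustrative and depends on the configured snapshot mode), you can list what a snapshot directory contains before loading it:

```python
# Sketch: inspect a snapshot directory for pickled epoch files.
# The itr_<N>.pkl naming is illustrative; actual file names depend on
# the experiment's snapshot mode, and we simulate the directory here.
import glob
import os
import tempfile

snapshot_dir = tempfile.mkdtemp()

# Simulate a few saved epochs.
for itr in (0, 1, 2):
    open(os.path.join(snapshot_dir, f'itr_{itr}.pkl'), 'wb').close()

pickles = sorted(glob.glob(os.path.join(snapshot_dir, 'itr_*.pkl')))
names = [os.path.basename(p) for p in pickles]
print(names)
```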
Extracting a trained policy from a snapshot
To extract the trained policy from a saved experiment, you only need a few lines of code:
```python
# Load the policy
from garage.experiment import Snapshotter
import tensorflow as tf  # optional, only for TensorFlow as we need a tf.Session

snapshotter = Snapshotter()

with tf.compat.v1.Session():  # optional, only for TensorFlow
    data = snapshotter.load('path/to/snapshot/dir')
    policy = data['algo'].policy

    # You can also access other components of the experiment
    env = data['env']
```
This code makes use of a Garage Snapshotter instance. It calls cloudpickle behind the scenes, and should continue to work even if we change how we pickle (we used to use joblib, for example).
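The essential idea is a pickle round-trip: the snapshot file stores a serialized dict of experiment components, and loading it gives that dict back. Here is a minimal sketch using the stdlib pickle module in place of cloudpickle, with a hypothetical `DummyPolicy` standing in for a real trained policy:

```python
# Sketch of the snapshot round-trip idea, using stdlib pickle in place of
# cloudpickle. DummyPolicy and the dict keys here are illustrative, not
# garage's actual snapshot layout.
import pickle


class DummyPolicy:
    def __init__(self, scale):
        self.scale = scale

    def get_action(self, obs):
        return self.scale * obs


snapshot = {'policy': DummyPolicy(2.0), 'env_name': 'CartPole-v1'}
blob = pickle.dumps(snapshot)    # roughly what a snapshot file stores
restored = pickle.loads(blob)    # roughly what loading gives back
print(restored['policy'].get_action(3.0))  # → 6.0
```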
Applying the policy to an environment
In order to use your newly-loaded trained policy, you first have to make sure that the shapes of its observation and action spaces match those of the target environment. An easy way to guarantee this is to run the policy in the same environment in which it was trained.
Once you have an environment initialized, the basic idea is this:
```python
steps, max_steps = 0, 150
done = False
obs = env.reset()  # The initial observation
policy.reset()

while steps < max_steps and not done:
    action, _ = policy.get_action(obs)  # get_action() returns (action, agent_info)
    obs, rew, done, _ = env.step(action)
    env.render()  # Render the environment to see what's going on (optional)
    steps += 1

env.close()
```
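To see the loop's two termination conditions (step budget exhausted, or the environment signalling `done`) in action without garage or a simulator, here is a self-contained sketch with a toy environment and policy (both hypothetical, following the same `reset`/`step`/`get_action` shape as above):

```python
# Self-contained sketch of the evaluation loop, using a toy environment
# and policy (both hypothetical) so it runs without garage or a simulator.
class ToyEnv:
    def reset(self):
        self._t = 0
        return 0.0

    def step(self, action):
        self._t += 1
        done = self._t >= 5  # episode ends after 5 steps
        return float(self._t), 1.0, done, {}

    def close(self):
        pass


class ToyPolicy:
    def reset(self):
        pass

    def get_action(self, obs):
        return 0, {}  # (action, agent_info) pair


env, policy = ToyEnv(), ToyPolicy()
steps, max_steps = 0, 150
done, total = False, 0.0
obs = env.reset()
policy.reset()

while steps < max_steps and not done:
    action, _ = policy.get_action(obs)
    obs, rew, done, _ = env.step(action)
    total += rew
    steps += 1

env.close()
print(steps, total)  # → 5 5.0 (the episode ends before the step budget)
```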
This logic is bundled up in a more robust way in the rollout() function from garage.sampler.utils. Let's bring everything together:
```python
# Load the policy and the env in which it was trained
from garage.experiment import Snapshotter
import tensorflow as tf  # optional, only for TensorFlow as we need a tf.Session

snapshotter = Snapshotter()

with tf.compat.v1.Session():  # optional, only for TensorFlow
    data = snapshotter.load('path/to/snapshot/dir')
    policy = data['algo'].policy
    env = data['env']

    # See what the trained policy can accomplish
    from garage import rollout
    path = rollout(env, policy, animated=True)
    print(path)
```
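The printed path is a record of the whole episode. To make its shape concrete, here is a minimal, hypothetical re-implementation of the rollout idea: it runs the same loop as before but collects per-step data into a dict (the key names, toy environment, and toy policy below are illustrative, not garage's actual implementation):

```python
# Minimal, hypothetical sketch of a rollout helper that collects per-step
# data into a path dict. Key names and the toy env/policy are illustrative.
def toy_rollout(env, policy, max_steps=150):
    path = {'observations': [], 'actions': [], 'rewards': []}
    obs = env.reset()
    policy.reset()
    for _ in range(max_steps):
        action, _ = policy.get_action(obs)
        next_obs, rew, done, _ = env.step(action)
        path['observations'].append(obs)
        path['actions'].append(action)
        path['rewards'].append(rew)
        obs = next_obs
        if done:
            break
    return path


class CountdownEnv:
    """Toy env whose episode ends after three steps."""

    def reset(self):
        self._left = 3
        return self._left

    def step(self, action):
        self._left -= 1
        return self._left, 1.0, self._left == 0, {}


class ZeroPolicy:
    def reset(self):
        pass

    def get_action(self, obs):
        return 0, {}


path = toy_rollout(CountdownEnv(), ZeroPolicy())
print(len(path['rewards']))  # → 3
```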
This page was authored by Hayden Shively (@haydenshively)