Load and Use a Trained Policy
In this section you will learn how to extract a trained policy from an experiment snapshot, as well as how to evaluate that policy in an environment.
Obtaining an experiment snapshot
Please refer to this page for information on how to save an experiment snapshot. The snapshot contains data such as:

- The trainer's `setup_args` and `train_args`
  - Random seed
  - Batch size
  - Number of epochs
  - And more
- The experiment's `stats`
- The environment
- The algorithm (which includes the policy we want to evaluate)
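Once loaded, a snapshot behaves like a plain Python dictionary keyed by component name. The stub below only illustrates that layout with stand-in values (the exact keys depend on your garage version, and real snapshots hold live garage objects, not strings):

```python
# Illustration only: a stand-in for a loaded snapshot dictionary.
class StubAlgo:
    """Minimal stand-in for a garage algorithm, which exposes a .policy."""
    def __init__(self, policy):
        self.policy = policy

data = {
    'stats': {'total_epoch': 99},          # the experiment's stats
    'env': '<environment object>',         # the environment
    'algo': StubAlgo('<policy object>'),   # the algorithm, holding the policy
}

print(data['algo'].policy)  # → <policy object>
```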
Extracting a trained policy from a snapshot
To extract the trained policy from a saved experiment, you only need a few lines of code:
```python
# Load the policy
from garage.experiment import Snapshotter
import tensorflow as tf  # optional, only for TensorFlow as we need a tf.Session

snapshotter = Snapshotter()
with tf.compat.v1.Session():  # optional, only for TensorFlow
    data = snapshotter.load('path/to/snapshot/dir')
    policy = data['algo'].policy

    # You can also access other components of the experiment
    env = data['env']
```
This code makes use of a garage `Snapshotter` instance. It calls cloudpickle behind the scenes, and should continue to work even if we change how we pickle (we used to use joblib, for example).
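The serialization idea can be sketched with a minimal round-trip using the standard-library `pickle` module. This is only an illustration of the mechanism: garage itself uses cloudpickle, which can also serialize objects plain pickle cannot, such as lambdas and locally defined classes.

```python
import pickle

# A toy stand-in for experiment state; real snapshots hold live garage objects.
experiment_state = {'epoch': 10, 'seed': 42}

blob = pickle.dumps(experiment_state)  # serialize, as done when saving
restored = pickle.loads(blob)          # deserialize, as done when loading

print(restored)  # → {'epoch': 10, 'seed': 42}
```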
Applying the policy to an environment
In order to use your newly loaded trained policy, you first have to make sure that the shapes of its observation and action spaces match those of the target environment. An easy way to do this is to run the policy in the same environment in which it was trained.
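If you do want to run the policy in a different environment, you can compare shapes up front. The helper below is a hypothetical sketch using stand-in objects, not garage API; with real objects you would compare `env.observation_space.shape` (and the action equivalent) against the spaces the policy was trained on.

```python
from collections import namedtuple

# Stand-ins for environment specs; real gym/garage spaces expose .shape too.
Space = namedtuple('Space', 'shape')
Spec = namedtuple('Spec', ['observation_space', 'action_space'])

def shapes_compatible(policy_spec, env_spec):
    """Return True when observation and action shapes line up."""
    return (policy_spec.observation_space.shape == env_spec.observation_space.shape
            and policy_spec.action_space.shape == env_spec.action_space.shape)

trained_on = Spec(Space((4,)), Space((2,)))  # shapes the policy was trained on
target = Spec(Space((4,)), Space((2,)))      # shapes of the target environment
print(shapes_compatible(trained_on, target))  # → True
```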
Once you have an environment initialized, the basic idea is this:
```python
steps, max_steps = 0, 150
done = False
obs = env.reset()  # The initial observation
policy.reset()

while steps < max_steps and not done:
    action, _ = policy.get_action(obs)  # get_action() also returns agent info
    obs, rew, done, _ = env.step(action)
    env.render()  # Render the environment to see what's going on (optional)
    steps += 1

env.close()
```
This logic is bundled up in a more robust way in the `rollout()` function from `garage.sampler.utils`. Let's bring everything together:
```python
# Load the policy and the env in which it was trained
from garage.experiment import Snapshotter
import tensorflow as tf  # optional, only for TensorFlow as we need a tf.Session

snapshotter = Snapshotter()
with tf.compat.v1.Session():  # optional, only for TensorFlow
    data = snapshotter.load('path/to/snapshot/dir')
    policy = data['algo'].policy
    env = data['env']

    # See what the trained policy can accomplish
    from garage import rollout
    path = rollout(env, policy, animated=True)
    print(path)
```
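The `path` returned by `rollout()` is a dictionary of per-step trajectory data. The key names below follow the convention used by garage's rollout utility (`'observations'`, `'actions'`, `'rewards'`, and so on), but verify them against your version; this mock uses plain lists (real paths hold numpy arrays) just to show how you might summarize a trajectory:

```python
# Mock trajectory in the same layout rollout() produces: one entry per step.
path = {
    'observations': [[0.0] * 4] * 3,   # 3 steps, 4-dimensional observations
    'actions': [[0.0] * 2] * 3,        # 3 steps, 2-dimensional actions
    'rewards': [1.0, 0.5, 2.0],        # 3 per-step rewards
}

total_return = sum(path['rewards'])
print('steps:', len(path['rewards']), 'return:', total_return)  # → steps: 3 return: 3.5
```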
This page was authored by Hayden Shively (@haydenshively)