"""Renders rollouts of the policy as it trains."""
import atexit
from collections import namedtuple
from enum import Enum
import platform
from queue import Empty, Queue
from threading import Thread

import numpy as np
import tensorflow as tf

from garage.sampler.utils import rollout as default_rollout
__all__ = ['Plotter']
class Op(Enum):
    """Message types understood by the Plotter worker thread."""

    # Ask the worker loop to exit.
    STOP = 0
    # Replace the environment and policy used by the worker.
    UPDATE = 1
    # Load new policy parameters and render a rollout with them.
    DEMO = 2
# Envelope for commands placed on the plotter queue: `op` is the message
# type, `args` its positional payload and `kwargs` its keyword payload
# (always None for the messages built in this module).
Message = namedtuple(
    'Message',
    ('op', 'args', 'kwargs'),
)
class Plotter:
    """Renders rollouts of the policy as it trains.

    Usually, this class is used by sending plot=True to LocalRunner.train().

    Rollouts are rendered on a daemon worker thread which is driven by
    `Message` objects placed on an internal queue: `start()` launches the
    thread, `update_plot()` sends it new policy parameters and `close()`
    stops it.

    Args:
        env (gym.Env): The environment to perform rollouts in. This will be
            used without copying in the current process but in a separate
            thread, so it should be given a unique copy (in particular, do
            not pass the environment here, then try to pickle it, or you
            will occasionally get crashes).
        policy (garage.tf.Policy): The policy to do the rollouts with.
        sess (tf.Session): The TensorFlow session to use.
        graph (tf.Graph): The TensorFlow graph to use.
        rollout (callable): The rollout function to call.

    """

    # List containing all plotters instantiated in the process
    __plotters = []

    def __init__(self,
                 env,
                 policy,
                 sess=None,
                 graph=None,
                 rollout=default_rollout):
        Plotter.__plotters.append(self)
        self._env = env
        self.sess = tf.compat.v1.Session() if sess is None else sess
        self.graph = (tf.compat.v1.get_default_graph()
                      if graph is None else graph)
        # The plotter works on its own clone of the policy; new parameter
        # values are copied into it by DEMO messages.
        with self.sess.as_default(), self.graph.as_default():
            self._policy = policy.clone('plotter_policy')
        self.rollout = rollout
        self.worker_thread = Thread(target=self._start_worker, daemon=True)
        self.queue = Queue()

        # Needed in order to draw glfw window on the main thread
        if 'Darwin' in platform.platform():
            self._render(np.inf)

    def _render(self, max_length):
        """Render one animated rollout of the current policy.

        Args:
            max_length (int or float): Passed to the rollout function as
                `max_path_length`.

        """
        self.rollout(self._env,
                     self._policy,
                     max_path_length=max_length,
                     animated=True,
                     speedup=5)

    def _fetch_messages(self, block):
        """Take every pending message off the queue.

        Args:
            block (bool): If True, wait until at least one message is
                available before draining the rest.

        Returns:
            tuple[dict, int]: The most recent message of each op type, keyed
                by op, and the total number of messages removed from the
                queue (each of which still needs a `task_done()` call).

        """
        latest = {}
        fetched = 0
        if block:
            msg = self.queue.get()
            latest[msg.op] = msg
            fetched += 1
        # get_nowait() instead of an empty()/get() pair: close() may drain
        # the queue from another thread between the check and the fetch.
        while True:
            try:
                msg = self.queue.get_nowait()
            except Empty:
                break
            # Only keep the last message of each type
            latest[msg.op] = msg
            fetched += 1
        return latest, fetched

    def _start_worker(self):
        """Run the worker loop until a STOP message is received.

        Each iteration processes ALL messages currently in the queue. Until
        the first DEMO message arrives the loop blocks waiting for
        messages; afterwards it polls and keeps replaying rollouts with
        the most recent `max_length`.
        """
        max_length = None
        initial_rollout = True
        try:
            # NOTE(review): this enters self.sess.graph rather than
            # self.graph (which __init__ uses) -- confirm they are always
            # the same graph.
            with self.sess.as_default(), self.sess.graph.as_default():
                while True:
                    msgs, fetched = self._fetch_messages(
                        block=initial_rollout)
                    try:
                        if Op.STOP in msgs:
                            break
                        if Op.UPDATE in msgs:
                            self._env, self._policy = msgs[Op.UPDATE].args
                        if Op.DEMO in msgs:
                            param_values, max_length = msgs[Op.DEMO].args
                            self._policy.set_param_values(param_values)
                            initial_rollout = False
                            self._render(max_length)
                        elif max_length:
                            self._render(max_length)
                    finally:
                        # Acknowledge every message taken off the queue,
                        # including superseded ones, so that the
                        # queue.join() in close() can return.
                        for _ in range(fetched):
                            self.queue.task_done()
        except KeyboardInterrupt:
            pass

    def close(self):
        """Stop the Plotter's worker thread.

        Discards any messages that have not been picked up yet, asks the
        worker to stop and waits for it to exit. Does nothing when the
        worker thread is not running.
        """
        if not self.worker_thread.is_alive():
            return
        # Drop pending work with get_nowait(): the worker drains the same
        # queue, so a blocking get() after an empty() check could wait
        # forever on a message the worker has already taken.
        while True:
            try:
                self.queue.get_nowait()
            except Empty:
                break
            self.queue.task_done()
        self.queue.put(Message(op=Op.STOP, args=None, kwargs=None))
        self.queue.join()
        self.worker_thread.join()

    @staticmethod
    def get_plotters():
        """Return all garage.tf.Plotter's.

        Returns:
            list[garage.tf.Plotter]: All the garage.tf.Plotter's

        """
        return Plotter.__plotters

    def start(self):
        """Start the Plotter's worker thread.

        Does nothing if the worker thread is already running. Otherwise the
        thread is started, an initial UPDATE message carrying the current
        environment and policy is queued, and `close()` is registered to
        run at interpreter exit.
        """
        if not self.worker_thread.is_alive():
            tf.compat.v1.get_variable_scope().reuse_variables()
            self.worker_thread.start()
            self.queue.put(
                Message(op=Op.UPDATE,
                        args=(self._env, self._policy),
                        kwargs=None))
            atexit.register(self.close)

    def update_plot(self, policy, max_length=np.inf):
        """Update the policy being plotted.

        The message is only sent while the worker thread is running.

        Args:
            policy (garage.tf.Policy): The policy to rollout.
            max_length (int or float): The maximum length to allow a rollout
                to be. Defaults to infinity.

        """
        if self.worker_thread.is_alive():
            self.queue.put(
                Message(op=Op.DEMO,
                        args=(policy.get_param_values(), max_length),
                        kwargs=None))