# Source code for garage.tf.plotter.plotter

import atexit
from collections import namedtuple
from enum import Enum
import platform
from queue import Empty
from queue import Queue
from threading import Thread

import numpy as np
import tensorflow as tf

from garage.sampler.utils import rollout as default_rollout

# Public interface of this module: only the Plotter class is exported.
__all__ = ['Plotter']


class Op(Enum):
    """Message kinds the main thread sends to the plotter worker thread."""

    STOP = 0  # ask the worker loop to exit
    UPDATE = 1  # replace the (env, policy) pair being plotted
    DEMO = 2  # load new policy parameters and show an animated rollout


# Queue envelope: ``op`` is an Op member, ``args`` carries the positional
# payload (or None); every Message built in this module passes kwargs=None.
Message = namedtuple('Message', ('op', 'args', 'kwargs'))


class Plotter:
    """Animate a policy's rollouts in a background worker thread.

    The main thread talks to the worker through a queue of ``Message``
    objects tagged with an ``Op``. ``start`` launches the worker,
    ``update_plot`` asks it to replay a rollout with fresh policy parameters,
    and ``close`` stops it.

    Args:
        env: Environment handed to ``rollout`` when plotting.
        policy: Policy handed to ``rollout``; the worker loads new parameters
            into it with ``set_param_values``.
        sess: Session-like object (needs ``as_default()`` and ``graph``) the
            worker runs in. Defaults to the current default session.
        graph: Stored as ``self.graph``; defaults to the current default
            graph. NOTE(review): the worker enters ``self.sess.graph``, not
            this graph.
        rollout (callable): Called as ``rollout(env, policy,
            max_path_length=..., animated=True, speedup=5)``.

    """

    # Static variable used to disable the plotter
    enable = True
    # List containing all plotters instantiated in the process
    __plotters = []

    def __init__(self,
                 env,
                 policy,
                 sess=None,
                 graph=None,
                 rollout=default_rollout):
        Plotter.__plotters.append(self)
        self.env = env
        self.policy = policy
        self.sess = (tf.compat.v1.get_default_session()
                     if sess is None else sess)
        self.graph = (tf.compat.v1.get_default_graph()
                      if graph is None else graph)
        self.rollout = rollout
        self.worker_thread = Thread(target=self._start_worker, daemon=True)
        self.queue = Queue()

        # Needed in order to draw glfw window on the main thread.
        # platform.system() is used because since Python 3.8
        # platform.platform() reports "macOS-..." instead of "Darwin-..."
        # on macOS, so a substring test for 'Darwin' never matched.
        if platform.system() == 'Darwin':
            self.rollout(env,
                         policy,
                         max_path_length=np.inf,
                         animated=True,
                         speedup=5)

    def _collect_messages(self, block):
        """Dequeue every pending message, keeping the newest one per Op.

        Args:
            block (bool): If True, wait for at least one message before
                draining the rest of the queue without blocking.

        Returns:
            tuple: ``(msgs, count)`` where ``msgs`` maps each Op to the most
                recent Message of that type and ``count`` is how many
                messages were dequeued in total (each needs a matching
                ``task_done()``).

        """
        msgs = {}
        count = 0
        if block:
            msg = self.queue.get()
            msgs[msg.op] = msg
            count += 1
        # get_nowait() + Empty instead of empty()/get(): close() may drain
        # the queue concurrently, so an empty() check can go stale.
        while True:
            try:
                msg = self.queue.get_nowait()
            except Empty:
                break
            msgs[msg.op] = msg
            count += 1
        return msgs, count

    def _mark_done(self, count):
        """Acknowledge ``count`` dequeued messages with ``task_done()``."""
        for _ in range(count):
            self.queue.task_done()

    def _start_worker(self):
        """Worker loop: apply queued messages and animate rollouts.

        Runs inside ``self.sess`` and its graph. Each iteration dequeues all
        pending messages and acts only on the newest one of each type:
        STOP ends the loop, UPDATE replaces the env/policy, DEMO loads new
        policy parameters and runs an animated rollout. Until the first DEMO
        arrives the loop blocks on the queue; afterwards it polls without
        blocking and, when no DEMO is pending, replays a rollout with the
        latest parameters.

        """
        env = None
        policy = None
        max_length = None
        initial_rollout = True
        try:
            with self.sess.as_default(), self.sess.graph.as_default():
                while True:
                    # Block (yielding the processor) only before the first
                    # demo; afterwards keep animating between messages.
                    msgs, n_received = self._collect_messages(
                        block=initial_rollout)

                    if Op.STOP in msgs:
                        self._mark_done(n_received)
                        break
                    if Op.UPDATE in msgs:
                        env, policy = msgs[Op.UPDATE].args
                    if Op.DEMO in msgs:
                        param_values, max_length = msgs[Op.DEMO].args
                        policy.set_param_values(param_values)
                        initial_rollout = False
                        self.rollout(env,
                                     policy,
                                     max_path_length=max_length,
                                     animated=True,
                                     speedup=5)
                    elif max_length:
                        # No new parameters: replay the latest policy.
                        self.rollout(env,
                                     policy,
                                     max_path_length=max_length,
                                     animated=True,
                                     speedup=5)
                    # One task_done() per dequeued message, including ones
                    # superseded by a newer message of the same Op; otherwise
                    # Queue.join() in close() would never return.
                    self._mark_done(n_received)
        except KeyboardInterrupt:
            # Best-effort guard; Python normally raises KeyboardInterrupt
            # only in the main thread.
            pass

    def close(self):
        """Stop the worker thread and wait for it to exit.

        Discards pending messages, queues a STOP message, then waits on both
        the queue and the thread. Does nothing if the worker is not running.
        """
        if self.worker_thread.is_alive():
            # Non-blocking drain: the worker may consume items concurrently.
            while True:
                try:
                    self.queue.get_nowait()
                except Empty:
                    break
                self.queue.task_done()
            self.queue.put(Message(op=Op.STOP, args=None, kwargs=None))
            self.queue.join()
            self.worker_thread.join()

    @staticmethod
    def disable():
        """Disable all instances of the Plotter class."""
        Plotter.enable = False

    @staticmethod
    def get_plotters():
        """Return the class-level list of every Plotter created so far.

        Returns:
            list: The shared list itself, not a copy.

        """
        return Plotter.__plotters

    def start(self):
        """Start the worker and send it the initial env and policy.

        Does nothing if plotting is disabled or the worker is already alive.
        Registers ``close`` to run at interpreter exit. Because a Thread can
        only be started once, calling this after the worker has exited raises
        RuntimeError from ``Thread.start``.
        """
        if not Plotter.enable:
            return
        if not self.worker_thread.is_alive():
            # Put the current variable scope in reuse mode before the worker
            # starts. NOTE(review): rationale not visible in this file.
            tf.compat.v1.get_variable_scope().reuse_variables()
            self.worker_thread.start()
            self.queue.put(
                Message(op=Op.UPDATE,
                        args=(self.env, self.policy),
                        kwargs=None))
            atexit.register(self.close)

    def update_plot(self, policy, max_length=np.inf):
        """Queue a demo rollout using ``policy``'s current parameters.

        The parameters are snapshotted here via ``get_param_values()``; the
        worker loads them into the policy it received through UPDATE.
        No-op if plotting is disabled or the worker is not running.

        Args:
            policy: Object providing ``get_param_values()``.
            max_length: Maximum rollout length passed to ``rollout``.

        """
        if not Plotter.enable:
            return
        if self.worker_thread.is_alive():
            self.queue.put(
                Message(op=Op.DEMO,
                        args=(policy.get_param_values(), max_length),
                        kwargs=None))