"""garage.plotter.plotter: render policy rollouts in a background worker."""

import atexit
from collections import namedtuple
from enum import Enum
from multiprocessing import JoinableQueue
from multiprocessing import Process
import platform
import queue
from threading import Thread

import numpy as np

from garage.sampler.utils import rollout

# Public API of this module: only the Plotter class is exported.
__all__ = ["Plotter"]


class Op(Enum):
    STOP = 0
    UPDATE = 1
    DEMO = 2


# Envelope for everything placed on the worker queue: ``op`` is an Op member,
# ``args`` carries its payload (``kwargs`` is passed as None by the code here).
Message = namedtuple('Message', 'op args kwargs')


class Plotter:
    """Render policy rollouts in a background worker.

    The caller communicates with the worker through a ``JoinableQueue`` of
    ``Message`` objects tagged with an ``Op``. On macOS the worker is a
    daemon thread; everywhere else it is a daemon process.
    """

    # Static variable used to disable the plotter
    enable = True
    # List containing all plotters instantiated in the process
    __plotters = []

    def __init__(self, standalone=False):
        """Create a plotter and register it in the class-wide list.

        Args:
            standalone (bool): Kept for backward compatibility; not used by
                this class.

        """
        Plotter.__plotters.append(self)
        self._process = None
        self._queue = None

    @staticmethod
    def _is_darwin():
        """Return True when running on macOS.

        ``platform.system()`` is used instead of searching
        ``platform.platform()`` for ``'Darwin'``: since Python 3.8 the
        latter reports e.g. ``'macOS-10.15.7-...'`` on macOS, so the old
        substring check silently failed there.
        """
        return platform.system() == 'Darwin'

    @staticmethod
    def _animate(env, policy, max_length):
        """Run one animated rollout of ``policy`` in ``env``."""
        rollout(
            env, policy, max_path_length=max_length, animated=True, speedup=5)

    def _drain_queue(self, block):
        """Fetch every pending message, keeping only the newest per op.

        Args:
            block (bool): If True, first wait for at least one message
                (yielding the processor while idle).

        Returns:
            dict: Maps each message's ``op`` to the latest message with it.

        """
        latest = {}
        if block:
            msg = self._queue.get()
            latest[msg.op] = msg
        while True:
            # get_nowait + Empty rather than empty() followed by a blocking
            # get(): empty() is only a hint, and close() may consume from
            # the same queue concurrently, which could leave get() hanging.
            try:
                msg = self._queue.get_nowait()
            except queue.Empty:
                break
            latest[msg.op] = msg
        return latest

    def _worker_start(self):
        """Worker loop: apply queued messages and keep replaying the policy.

        Runs until an ``Op.STOP`` message arrives or a KeyboardInterrupt is
        raised.
        """
        env = None
        policy = None
        max_length = None
        try:
            while True:
                # Block whenever there is nothing to replay (no DEMO received
                # yet, or it carried a falsy max_length) instead of spinning
                # on an empty queue.
                msgs = self._drain_queue(block=not max_length)
                if Op.STOP in msgs:
                    break
                # UPDATE and DEMO are handled independently so that a DEMO
                # arriving in the same batch as an UPDATE is not dropped.
                if Op.UPDATE in msgs:
                    env, policy = msgs[Op.UPDATE].args
                if Op.DEMO in msgs:
                    param_values, max_length = msgs[Op.DEMO].args
                    policy.set_param_values(param_values)
                    self._animate(env, policy, max_length)
                elif not msgs and max_length:
                    # Nothing new arrived: keep replaying the current policy.
                    self._animate(env, policy, max_length)
        except KeyboardInterrupt:
            pass

    def close(self):
        """Stop the worker if it is running, then close the queue.

        Registered with ``atexit`` by ``init_worker``. Once the worker has
        been joined, further calls do nothing.
        """
        if not Plotter.enable:
            return
        if self._process and self._process.is_alive():
            # Discard pending work so STOP is the next thing the worker sees.
            while True:
                try:
                    self._queue.get_nowait()
                except queue.Empty:
                    break
                self._queue.task_done()
            self._queue.put(Message(op=Op.STOP, args=None, kwargs=None))
            # Join before closing the queue, so the queue is never torn down
            # while the worker (a thread in this process on macOS) may still
            # be reading from it.
            self._process.join()
            self._queue.close()

    @staticmethod
    def disable():
        """Disable all instances of the Plotter class."""
        Plotter.enable = False

    @staticmethod
    def get_plotters():
        """Return the list of every Plotter created in this process."""
        return Plotter.__plotters

    def init_worker(self):
        """Create the message queue and start the background worker.

        The worker is a daemon thread on macOS and a daemon process
        elsewhere. ``close`` is registered to run at interpreter exit.
        """
        if not Plotter.enable:
            return
        self._queue = JoinableQueue()
        if self._is_darwin():
            self._process = Thread(target=self._worker_start)
        else:
            self._process = Process(target=self._worker_start)
        self._process.daemon = True
        self._process.start()
        atexit.register(self.close)

    def init_plot(self, env, policy):
        """Start the worker if needed and hand it ``env`` and ``policy``.

        Args:
            env: Environment to roll out in.
            policy: Policy to animate.

        """
        if not Plotter.enable:
            return
        if not (self._process and self._queue):
            self.init_worker()
        # Needed in order to draw glfw window on the main thread
        if self._is_darwin():
            self._animate(env, policy, np.inf)
        self._queue.put(Message(op=Op.UPDATE, args=(env, policy), kwargs=None))

    def update_plot(self, policy, max_length=np.inf):
        """Send the policy's current parameters to the worker for a demo.

        Args:
            policy: Policy whose ``get_param_values()`` are sent.
            max_length: Maximum rollout length used by the worker.

        """
        if not Plotter.enable:
            return
        self._queue.put(
            Message(
                op=Op.DEMO,
                args=(policy.get_param_values(), max_length),
                kwargs=None))