"""Default TensorFlow sampler Worker."""
import tensorflow as tf
from garage.sampler import Worker
class TFWorkerClassWrapper:
    """Acts like a Worker class, but is actually an object.

    When called, constructs the wrapped class and wraps it in a
    TFWorkerWrapper.

    Args:
        wrapped_class (type): The class to wrap. Should be a subclass of
            garage.sampler.Worker.

    """

    # pylint: disable=too-few-public-methods

    def __init__(self, wrapped_class):
        self._wrapped_class = wrapped_class

    def __call__(self, *args, **kwargs):
        """Construct the inner class and wrap it.

        If constructing the inner class fails, any Session that the wrapper
        entered is exited again before the exception propagates.

        Args:
            *args: Passed on to inner worker class.
            **kwargs: Passed on to inner worker class.

        Returns:
            TFWorkerWrapper: The wrapped worker.

        """
        # pylint: disable=protected-access
        wrapper = TFWorkerWrapper()
        try:
            # Need to construct the wrapped class after we've entered the
            # Session.
            wrapper._inner_worker = self._wrapped_class(*args, **kwargs)
        except BaseException:
            # The wrapper's constructor may have entered a Session of its
            # own. Nobody will ever call shutdown() on this half-built
            # wrapper, so release that Session here (using the same
            # condition as TFWorkerWrapper.shutdown) and re-raise.
            if tf.compat.v1.get_default_session() and wrapper._sess_entered:
                wrapper._sess_entered = False
                wrapper._sess.__exit__(None, None, None)
            raise
        return wrapper
class TFWorkerWrapper(Worker):
    """Wrapper around another worker that initializes a TensorFlow Session.

    All sampling-related calls are forwarded to the inner worker, which is
    assigned to `_inner_worker` after construction.

    """

    def __init__(self):
        # pylint: disable=super-init-not-called
        self._inner_worker = None
        self._sess = None
        # Only truthy when this wrapper created and entered the Session
        # itself, and is therefore responsible for exiting it.
        self._sess_entered = None
        self.worker_init()

    def worker_init(self):
        """Initialize a worker.

        Uses the current default TensorFlow Session if there is one.
        Otherwise creates a new Session, enters it, and records that this
        wrapper has to exit it on shutdown.

        """
        self._sess = tf.compat.v1.get_default_session()
        if not self._sess:
            # create a tf session for all
            # sampler worker processes in
            # order to execute the policy.
            self._sess = tf.compat.v1.Session()
            self._sess.__enter__()
            # Set the flag only once the Session has actually been entered,
            # so a failed __enter__ never leads to an unmatched __exit__.
            self._sess_entered = True

    def shutdown(self):
        """Perform shutdown processes for TF.

        Shuts down the inner worker (if one has been set) and then exits
        the Session if this wrapper entered it. The Session is exited even
        when the inner worker's shutdown raises.

        """
        try:
            if self._inner_worker is not None:
                self._inner_worker.shutdown()
        finally:
            if tf.compat.v1.get_default_session() and self._sess_entered:
                self._sess_entered = False
                self._sess.__exit__(None, None, None)

    @property
    def agent(self):
        """Returns the worker's agent.

        Returns:
            garage.Policy: the worker's agent.

        """
        return self._inner_worker.agent

    @agent.setter
    def agent(self, agent):
        """Sets the worker's agent.

        Args:
            agent (garage.Policy): The agent.

        """
        self._inner_worker.agent = agent

    @property
    def env(self):
        """Returns the worker's environment.

        Returns:
            gym.Env: the worker's environment.

        """
        return self._inner_worker.env

    @env.setter
    def env(self, env):
        """Sets the worker's environment.

        Args:
            env (gym.Env): The environment.

        """
        self._inner_worker.env = env

    def update_agent(self, agent_update):
        """Update the worker's agent, using agent_update.

        Args:
            agent_update(object): An agent update. The exact type of this
                argument depends on the `Worker` implementation.

        """
        self._inner_worker.update_agent(agent_update)

    def update_env(self, env_update):
        """Update the worker's env, using env_update.

        Args:
            env_update(object): An environment update. The exact type of this
                argument depends on the `Worker` implementation.

        """
        self._inner_worker.update_env(env_update)

    def rollout(self):
        """Sample a single rollout of the agent in the environment.

        Returns:
            garage.TrajectoryBatch: Batch of sampled trajectories. May be
                truncated if max_path_length is set.

        """
        return self._inner_worker.rollout()

    def start_rollout(self):
        """Begin a new rollout."""
        self._inner_worker.start_rollout()

    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Returns:
            bool: True iff the path is done, either due to the environment
                indicating termination or due to reaching `max_path_length`.

        """
        return self._inner_worker.step_rollout()

    def collect_rollout(self):
        """Collect the current rollout, clearing the internal buffer.

        Returns:
            garage.TrajectoryBatch: Batch of sampled trajectories. May be
                truncated if the rollouts haven't completed yet.

        """
        return self._inner_worker.collect_rollout()