Source code for garage.tf.policies.discrete_qf_derived_policy

"""A Discrete QFunction-derived policy.

This policy chooses the action that yields the largest Q-value.
"""
import akro
import numpy as np
import tensorflow as tf

from garage.tf.policies.policy import Policy


class DiscreteQfDerivedPolicy(Policy):
    """DiscreteQfDerived policy.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        qf (garage.q_functions.QFunction): The q-function used.
        name (str): Name of the policy.

    """

    def __init__(self, env_spec, qf, name='DiscreteQfDerivedPolicy'):
        super().__init__(name, env_spec)

        assert isinstance(env_spec.action_space, akro.Discrete)
        self._env_spec = env_spec
        self._qf = qf

        self._initialize()

    def _initialize(self):
        # Build a callable that feeds observations to the q-function's model
        # input and returns its Q-values, bound to the current default session.
        self._f_qval = tf.compat.v1.get_default_session().make_callable(
            self._qf.q_vals, feed_list=[self._qf.model.input])

    @property
    def vectorized(self):
        """Vectorized or not.

        Returns:
            Bool: True if primitive supports vectorized operations.

        """
        return True
    def get_action(self, observation):
        """Get action from this policy for the input observation.

        Args:
            observation (numpy.ndarray): Observation from environment.

        Returns:
            numpy.ndarray: Single optimal action from this policy.
            dict: Predicted action and agent information. It returns an
                empty dict since there is no parameterization.

        """
        if isinstance(self.env_spec.observation_space, akro.Image) and \
                len(observation.shape) < \
                len(self.env_spec.observation_space.shape):
            observation = self.env_spec.observation_space.unflatten(
                observation)
        q_vals = self._f_qval([observation])
        opt_action = np.argmax(q_vals)

        return opt_action, dict()
    def get_actions(self, observations):
        """Get actions from this policy for the input observations.

        Args:
            observations (numpy.ndarray): Observations from environment.

        Returns:
            numpy.ndarray: Optimal actions from this policy.
            dict: Predicted action and agent information. It returns an
                empty dict since there is no parameterization.

        """
        if isinstance(self.env_spec.observation_space, akro.Image) and \
                len(observations[0].shape) < \
                len(self.env_spec.observation_space.shape):
            observations = self.env_spec.observation_space.unflatten_n(
                observations)
        q_vals = self._f_qval(observations)
        opt_actions = np.argmax(q_vals, axis=1)

        return opt_actions, dict()
    def __getstate__(self):
        """Object.__getstate__.

        Returns:
            dict: the state to be pickled for the instance.

        """
        new_dict = self.__dict__.copy()
        del new_dict['_f_qval']
        return new_dict

    def __setstate__(self, state):
        """Object.__setstate__.

        Args:
            state (dict): Unpickled state.

        """
        self.__dict__.update(state)
        self._initialize()
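
Usage sketch (not part of the module source): a minimal example of driving this policy from an existing Q-function. Here `env_spec` and `qf` are placeholders for user-constructed objects with the interfaces the class assumes (a discrete `action_space`, and `q_vals`/`model.input` attributes on the Q-function), and `make_env_spec_and_qf` is a hypothetical helper. A default TensorFlow session must be active when the policy is constructed, because `_initialize` calls `tf.compat.v1.get_default_session()`.

import numpy as np
import tensorflow as tf

from garage.tf.policies import DiscreteQfDerivedPolicy

# Hypothetical helper; build your env spec and Q-function however you normally do.
env_spec, qf = make_env_spec_and_qf()

with tf.compat.v1.Session() as sess:
    # Construct inside the session so make_callable can bind the Q-value graph to it.
    policy = DiscreteQfDerivedPolicy(env_spec=env_spec, qf=qf)
    sess.run(tf.compat.v1.global_variables_initializer())

    obs = env_spec.observation_space.sample()
    action, _ = policy.get_action(obs)                       # argmax over Q-values
    actions, _ = policy.get_actions(np.stack([obs, obs]))    # batched variant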