garage.tf.policies.base module

Base class for Policies.

class Policy(name, env_spec)[source]

Bases: abc.ABC

Base class for Policies.

Parameters:
action_space

The action space for the environment.

Type:akro.Space
env_spec

Policy environment specification.

Type:garage.EnvSpec
flat_to_params(flattened_params, **tags)[source]

Unflatten tensors according to their respective shapes.

Parameters:
  • flattened_params (np.ndarray) – A numpy array of flattened params.
  • tags (dict) – A map specifying the parameters and their shapes.
Returns:

A list of parameters reshaped to the shapes specified.

Return type:

tensors (List[np.ndarray])
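
As an illustration of the unflattening step, the NumPy sketch below splits a flat parameter vector at the boundaries implied by a list of shapes and reshapes each chunk. The shapes and values are invented for the example and are not taken from a real policy:

    import numpy as np

    # Hypothetical parameter shapes, e.g. one weight matrix and one bias.
    shapes = [(2, 3), (3,)]
    flat = np.arange(9.0)  # 6 + 3 = 9 flattened values

    # Split the flat vector at the cumulative sizes and reshape each chunk.
    sizes = [int(np.prod(s)) for s in shapes]
    chunks = np.split(flat, np.cumsum(sizes)[:-1])
    params = [c.reshape(s) for c, s in zip(chunks, shapes)]

    print([p.shape for p in params])  # [(2, 3), (3,)]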

get_action(observation)[source]

Get action sampled from the policy.

Parameters:observation (np.ndarray) – Observation from the environment.
Returns:Action sampled from the policy.
Return type:(np.ndarray)
get_actions(observations)[source]

Get actions sampled from the policy.

Parameters:observations (list[np.ndarray]) – Observations from the environment.
Returns:Actions sampled from the policy.
Return type:(np.ndarray)
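
The toy class below is not part of garage; it only mirrors the documented get_action / get_actions contract shown above (one observation in, one action out, plus a batched variant), using uniform random actions so the sketch stays self-contained:

    import numpy as np

    class RandomPolicy:
        """Stand-in that follows the documented interface; not a garage class."""

        def __init__(self, action_dim):
            self._action_dim = action_dim

        def get_action(self, observation):
            # Ignore the observation and sample uniformly in [-1, 1].
            return np.random.uniform(-1.0, 1.0, size=self._action_dim)

        def get_actions(self, observations):
            # One action per observation, stacked into a single array.
            return np.stack([self.get_action(o) for o in observations])

    policy = RandomPolicy(action_dim=2)
    single = policy.get_action(np.zeros(4))        # shape (2,)
    batch = policy.get_actions([np.zeros(4)] * 3)  # shape (3, 2)
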
get_global_vars()[source]

Get global variables.

Returns:A list of global variables in the current variable scope.
Return type:List[tf.Variable]
get_param_shapes(**tags)[source]

Get parameter shapes.

get_param_values(**tags)[source]

Get param values.

Parameters:tags (dict) – A map of parameters for which the values are required.
Returns:Values of the parameters evaluated in the current session.
Return type:param_values (np.ndarray)
get_params(trainable=True)[source]

Get the trainable variables.

Returns:A list of trainable variables in the current variable scope.
Return type:List[tf.Variable]
get_trainable_vars()[source]

Get trainable variables.

Returns:A list of trainable variables in the current variable scope.
Return type:List[tf.Variable]
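
Accessors like get_trainable_vars() and get_global_vars() are typically thin wrappers over TensorFlow's graph collections, filtered by the policy's variable scope. The TF1-style sketch below shows that mechanism with a made-up scope name; it is not the garage implementation itself:

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    with tf.compat.v1.variable_scope('MyPolicy'):  # hypothetical scope name
        tf.compat.v1.get_variable('w', shape=(4, 2))
        tf.compat.v1.get_variable('b', shape=(2,), trainable=False)

    trainable = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='MyPolicy')
    global_vars = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='MyPolicy')

    print([v.name for v in trainable])    # ['MyPolicy/w:0']
    print([v.name for v in global_vars])  # ['MyPolicy/w:0', 'MyPolicy/b:0']
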
log_diagnostics(paths)[source]

Log extra information per iteration based on the collected paths.

name

Name of the policy model and the variable scope.

Type:str
observation_space

The observation space of the environment.

Type:akro.Space
recurrent

Whether the policy is recurrent.

Type:bool
reset(dones=None)[source]

Reset the policy.

If dones is None, it defaults to np.array([True]), which implies the policy is not “vectorized”, i.e. only one parallel environment is used to sample training data.

Parameters:dones (numpy.ndarray) – Bools indicating which environment(s) reached a terminal state.
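
The toy state holder below illustrates the dones convention for a vectorized, stateful policy: each entry marks one parallel environment, and only terminated environments have their state cleared. The hidden-state layout is invented for the example:

    import numpy as np

    class ToyRecurrentState:
        def __init__(self, n_envs, hidden_dim):
            self.hidden = np.zeros((n_envs, hidden_dim))

        def reset(self, dones=None):
            if dones is None:
                dones = np.array([True])  # single-environment default
            self.hidden[np.asarray(dones, dtype=bool)] = 0.0

    state = ToyRecurrentState(n_envs=3, hidden_dim=8)
    state.hidden[:] = 1.0
    state.reset(dones=[True, False, True])
    print(state.hidden.sum(axis=1))  # [0. 8. 0.] -> envs 0 and 2 were reset
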
set_param_values(param_values, name=None, **tags)[source]

Set param values.

Parameters:
  • param_values (np.ndarray) – A numpy array of parameter values.
  • tags (dict) – A map of parameters for which the values should be loaded.
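
A common pattern built on get_param_values() and set_param_values() is to snapshot the flat parameter vector, try a perturbed vector, and roll back. The stand-in below stores its parameters as a plain NumPy vector so the sketch runs without a real policy:

    import numpy as np

    class PolicyLike:
        def __init__(self, n_params):
            self._values = np.random.randn(n_params)

        def get_param_values(self):
            return self._values.copy()

        def set_param_values(self, param_values):
            self._values = np.asarray(param_values).copy()

    policy = PolicyLike(n_params=10)
    snapshot = policy.get_param_values()     # save the current parameters
    policy.set_param_values(snapshot + 0.1)  # try a perturbed vector
    policy.set_param_values(snapshot)        # roll back to the snapshot
    assert np.allclose(policy.get_param_values(), snapshot)
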
state_info_keys

State info keys.

Returns:Keys for the information related to the policy’s state when taking an action.
Return type:List[str]
state_info_specs

State info specification.

Returns:Keys and shapes for the information related to the policy’s state when taking an action.
Return type:List[str]
terminate()[source]

Clean up operation.

vectorized

Whether the policy is vectorized.

Returns:Indicates whether the policy is vectorized. If True, it should implement get_actions(), and support resetting with multiple simultaneous states.
Return type:bool
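
As a usage sketch, a sampler can branch on this flag: call get_actions() once for a whole batch when the policy is vectorized, otherwise fall back to per-observation get_action() calls. The helper below is illustrative and not part of garage:

    import numpy as np

    def sample_actions(policy, observations):
        # `policy` is any object exposing the documented interface.
        if policy.vectorized:
            # One call for the whole batch of parallel observations.
            return policy.get_actions(observations)
        # Per-observation calls for non-vectorized policies.
        return np.stack([policy.get_action(obs) for obs in observations])
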
class StochasticPolicy(name, env_spec)[source]

Bases: garage.tf.policies.base.Policy

StochasticPolicy.

dist_info(obs, state_infos)[source]

Distribution info.

Return the distribution information about the actions.

Parameters:
  • obs (tf.Tensor) – Observation values.
  • state_infos (dict) – A dictionary whose values should contain information about the state of the policy at the time it received the observation.
dist_info_sym(obs_var, state_info_vars, name='dist_info_sym')[source]

Symbolic graph of the distribution.

Return the symbolic distribution information about the actions.

Parameters:
  • obs_var (tf.Tensor) – Symbolic variable for observations.
  • state_info_vars (dict) – A dictionary whose values should contain information about the state of the policy at the time it received the observation.
  • name (str) – Name of the symbolic graph.
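
For a rough sense of what a Gaussian subclass might build here, the TF1-style sketch below maps an observation tensor to mean / log_std tensors under a name scope. The layer sizes, names, and the mean/log_std keys are assumptions for illustration, not the garage implementation:

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    def dist_info_sym(obs_var, state_info_vars=None, name='dist_info_sym'):
        with tf.compat.v1.variable_scope(name):
            hidden = tf.compat.v1.layers.dense(obs_var, 32, activation=tf.nn.tanh)
            mean = tf.compat.v1.layers.dense(hidden, 2)
            log_std = tf.compat.v1.get_variable(
                'log_std', shape=(2,),
                initializer=tf.compat.v1.zeros_initializer())
        # Distribution parameters, keyed the way a diagonal Gaussian expects.
        return dict(mean=mean, log_std=tf.broadcast_to(log_std, tf.shape(mean)))

    obs_ph = tf.compat.v1.placeholder(tf.float32, shape=(None, 4), name='obs')
    dist_info_vars = dist_info_sym(obs_ph)
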
distribution

Distribution.