Source code for garage.torch.policies.gaussian_mlp_policy

"""GaussianMLPPolicy."""
import torch

from garage.torch.modules import GaussianMLPModule
from garage.torch.policies import Policy


class GaussianMLPPolicy(Policy, GaussianMLPModule):
    """GaussianMLPPolicy.

    A policy that uses an MLP to parameterize a Gaussian distribution over
    actions.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        **kwargs: Keyword arguments forwarded to the GaussianMLPModule
            constructor.

    """

    def __init__(self, env_spec, **kwargs):
        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim

        Policy.__init__(self, env_spec)
        GaussianMLPModule.__init__(self,
                                   input_dim=self._obs_dim,
                                   output_dim=self._action_dim,
                                   **kwargs)

    def forward(self, inputs):
        """Compute the action distribution for the given inputs."""
        # Coerce inputs (e.g. numpy arrays) to a tensor before delegating
        # to GaussianMLPModule.forward.
        return super().forward(torch.Tensor(inputs))

    def get_action(self, observation):
        """Get a single action given an observation."""
        with torch.no_grad():
            # Add a batch dimension before the forward pass, then strip it
            # from the sampled action and the distribution statistics.
            observation = observation.unsqueeze(0)
            dist = self.forward(observation)
            return (dist.rsample().squeeze(0).numpy(),
                    dict(mean=dist.mean.squeeze(0).numpy(),
                         log_std=(dist.variance**.5).log().squeeze(0).numpy()))

    def get_actions(self, observations):
        """Get actions given observations."""
        with torch.no_grad():
            dist = self.forward(observations)
            return (dist.rsample().numpy(),
                    dict(mean=dist.mean.numpy(),
                         log_std=(dist.variance**.5).log().numpy()))

    def log_likelihood(self, observation, action):
        """Get the log likelihood of an action given an observation."""
        dist = self.forward(observation)
        return dist.log_prob(action)

    def get_entropy(self, observation):
        """Get the entropy of the action distribution given an observation."""
        dist = self.forward(observation)
        return dist.entropy()

    def reset(self, dones=None):
        """Reset the policy; a no-op, since this policy keeps no state."""
        pass

    @property
    def vectorized(self):
        """Vectorized or not."""
        return True
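
A minimal usage sketch, not part of the original module: it assumes only what the class actually reads from env_spec (observation_space.flat_dim and action_space.flat_dim) and that GaussianMLPModule accepts a hidden_sizes keyword forwarded through **kwargs. The SimpleNamespace stand-in below is hypothetical, not the real garage.envs.env_spec.EnvSpec.

from types import SimpleNamespace

import numpy as np
import torch

from garage.torch.policies import GaussianMLPPolicy

# Hypothetical stand-in exposing only the two fields the policy reads.
env_spec = SimpleNamespace(
    observation_space=SimpleNamespace(flat_dim=4),
    action_space=SimpleNamespace(flat_dim=2))

# hidden_sizes is assumed to be a GaussianMLPModule keyword argument.
policy = GaussianMLPPolicy(env_spec, hidden_sizes=(32, 32))

# Single observation: get_action adds and removes the batch dimension.
obs = torch.Tensor(np.random.rand(4))
action, info = policy.get_action(obs)       # action: shape (2,)
print(info['mean'], info['log_std'])

# Batched observations: get_actions operates on the whole batch at once.
batch = torch.Tensor(np.random.rand(10, 4))
actions, infos = policy.get_actions(batch)  # actions: shape (10, 2)

Note that get_action and get_actions run under torch.no_grad(), so they return plain numpy arrays; use log_likelihood and get_entropy when a differentiable quantity is needed for a loss.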