Source code for garage.torch.value_functions.value_function

"""Base class for all baselines."""
import abc

import torch.nn as nn


class ValueFunction(abc.ABC, nn.Module):
    """Abstract base class for all baselines (value functions).

    The class is both an ``abc.ABC`` and a ``torch.nn.Module``; concrete
    subclasses must implement :meth:`compute_loss` before they can be
    instantiated.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
            Stored on the instance as ``_mdp_spec``.
        name (str): Value function name, also the variable scope. Stored on
            the instance as ``name``.

    """

    def __init__(self, env_spec, name):
        # nn.Module must be initialized before any attribute is assigned,
        # because Module.__setattr__ relies on state set up by its __init__.
        super().__init__()

        self._mdp_spec = env_spec
        self.name = name

    @abc.abstractmethod
    def compute_loss(self, obs, returns):
        r"""Compute mean value of loss.

        Args:
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N \cdot [T], O*)`.
            returns (torch.Tensor): Acquired returns with shape :math:`(N, )`.

        Returns:
            torch.Tensor: Calculated negative mean scalar value of
                objective (float).

        """