Source code for garage.np.baselines.baseline

"""Base class for all baselines."""
import abc


class Baseline(abc.ABC):
    """Base class for all baselines.

    Subclasses must implement :meth:`get_param_values`,
    :meth:`set_param_values`, :meth:`fit` and :meth:`predict`;
    :meth:`log_diagnostics` is optional and does nothing by default.

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.

    """

    def __init__(self, env_spec):
        # Keep the environment specification for use by subclasses.
        self._mdp_spec = env_spec

    @abc.abstractmethod
    def get_param_values(self):
        """Get parameter values.

        Returns:
            List[np.ndarray]: A list of values of each parameter.

        """

    @abc.abstractmethod
    def set_param_values(self, flattened_params):
        """Set param values.

        Args:
            flattened_params (np.ndarray): A numpy array of parameter values.

        """

    @abc.abstractmethod
    def fit(self, paths):
        """Fit regressor based on paths.

        Args:
            paths (dict[numpy.ndarray]): Sample paths.

        """
        # NOTE(review): log_diagnostics documents `paths` as list[dict];
        # confirm which container type fit actually receives.

    @abc.abstractmethod
    def predict(self, path):
        """Predict value based on a single path.

        Args:
            path (dict[numpy.ndarray]): A single sample path.

        Returns:
            numpy.ndarray: Predicted value.

        """

    def log_diagnostics(self, paths):
        """Log diagnostic information.

        The base implementation is a no-op; subclasses may override it.

        Args:
            paths (list[dict]): A list of collected paths.

        """