Source code for garage.experiment.deterministic

"""Utilities for ensuring that experiments are deterministic."""
import random
import sys
import warnings

import numpy as np

# Process-wide seed most recently applied by set_seed(); None until a seed
# has been set.
seed_ = None
# TensorFlow Probability SeedStream built by set_seed(); stays None unless
# set_seed() ran with tensorflow imported and tensorflow_probability available.
seed_stream_ = None


def set_seed(seed):
    """Set the process-wide random seed.

    Seeds Python's ``random`` module and NumPy's global RNG unconditionally.
    TensorFlow and PyTorch are seeded only when they have already been
    imported by the process (i.e. they are present in ``sys.modules``).

    Args:
        seed (int): A positive integer. It is reduced modulo 4294967294
            before being applied, and the reduced value is what is stored
            in the module-level ``seed_``.

    """
    seed %= 4294967294
    # pylint: disable=global-statement
    global seed_
    global seed_stream_
    seed_ = seed
    random.seed(seed)
    np.random.seed(seed)
    if 'tensorflow' in sys.modules:
        import tensorflow as tf  # pylint: disable=import-outside-toplevel
        tf.compat.v1.set_random_seed(seed)
        try:
            # pylint: disable=import-outside-toplevel
            import tensorflow_probability as tfp
            seed_stream_ = tfp.util.SeedStream(seed_, salt='garage')
        except ImportError:
            # tensorflow_probability is optional; without it no seed stream
            # is created and seed_stream_ keeps its previous value.
            pass
    if 'torch' in sys.modules:
        # stacklevel=2 attributes the warning to the caller of set_seed
        # rather than to this helper.
        warnings.warn(
            'Enabling deterministic mode in PyTorch can have a performance '
            'impact when using GPU.',
            stacklevel=2)
        import torch  # pylint: disable=import-outside-toplevel
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def get_seed():
    """Report the seed currently in effect for the whole process.

    Returns:
        int: The process-wide random seed, i.e. the current value of the
            module-level ``seed_`` (``None`` until a seed has been set).

    """
    current_seed = seed_
    return current_seed
def get_tf_seed_stream():
    """Get the next seed for TensorFlow ops from the process-wide seed stream.

    If no seed stream exists yet, one is requested by calling ``set_seed``
    with the seed already in effect, so that a seed chosen earlier (for
    example before TensorFlow was imported) is not overwritten. Only when no
    seed has ever been set does this fall back to seed 0.

    Returns:
        int: A seed generated by a PRNG with fixed global seed, reduced
            modulo 4294967294.

    Raises:
        RuntimeError: If the seed stream is still unavailable after
            re-seeding, which happens when tensorflow has not been imported
            or tensorflow_probability cannot be imported.

    """
    if seed_stream_ is None:
        # Re-apply the current seed instead of unconditionally resetting the
        # whole process to seed 0.
        set_seed(0 if seed_ is None else seed_)
    if seed_stream_ is None:
        # set_seed() only builds the stream when tensorflow is loaded and
        # tensorflow_probability is importable; fail with a clear message
        # instead of "'NoneType' object is not callable".
        raise RuntimeError(
            'TensorFlow seed stream is unavailable: tensorflow must be '
            'imported and tensorflow_probability must be installed before '
            'calling get_tf_seed_stream().')
    return seed_stream_() % 4294967294