An example of training a task with the DQN algorithm.
It creates an Atari gym environment (e.g. PongNoFrameskip-v4) and trains a DQN on it for 50k steps.
- main(env=None, seed=24, n=psutil.cpu_count(logical=False), buffer_size=None, n_steps=None, max_episode_length=None)
Wrapper to set up the logging directory.
env (str) – Name of the Atari environment; can be either the game prefix or the full name, e.g. ‘Pong’ or ‘PongNoFrameskip-v4’. If the prefix is given, the environment used will be env + ‘NoFrameskip-v4’.
seed (int) – Seed to use for the RNG.
n (int) – Number of workers to use. Defaults to the number of CPU cores available.
buffer_size (int) – Size of the replay buffer in transitions. If None, defaults to hyperparams[‘buffer_size’]. Used by the integration tests.
n_steps (float) – Total number of environment steps to run, not including evaluation. If this is not None, n_epochs will be recalculated based on this value.
max_episode_length (int) – Max length of an episode. If None, defaults to the timelimit specific to the environment. Used by integration tests.
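The env-name defaulting described above can be sketched as a small standalone helper. This is illustrative only (the helper name and the `'NoFrameskip'` membership check are assumptions, not garage's actual code); the real resolution happens inside main:

```python
def resolve_env_name(env):
    """Append the default suffix when only the game prefix is given.

    A name already containing 'NoFrameskip' is assumed to be fully
    qualified; anything else is treated as a game prefix like 'Pong'.
    """
    if 'NoFrameskip' not in env:
        env = env + 'NoFrameskip-v4'
    return env

print(resolve_env_name('Pong'))                # PongNoFrameskip-v4
print(resolve_env_name('PongNoFrameskip-v4'))  # PongNoFrameskip-v4
```

With this convention, passing either form on the command line yields the same registered gym environment id.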
- dqn_atari(ctxt=None, env=None, seed=24, n_workers=psutil.cpu_count(logical=False), max_episode_length=None, **kwargs)
Train DQN with PongNoFrameskip-v4 environment.
ctxt (garage.experiment.ExperimentContext) – The experiment configuration used by Trainer to create the snapshotter.
env (str) – Name of the Atari environment, e.g. ‘PongNoFrameskip-v4’.
seed (int) – Seed for the random number generator, to make runs deterministic.
n_workers (int) – Number of workers to use. Defaults to the number of CPU cores available.
kwargs (dict) – Hyperparameters to be saved to variant.json.
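As the last parameter notes, the extra keyword arguments are persisted to variant.json alongside the experiment's logs. A minimal sketch of that bookkeeping (the helper name `save_variant` and the exact file layout are assumptions, not garage's implementation) might look like:

```python
import json
import os
import tempfile

def save_variant(log_dir, **kwargs):
    """Dump the experiment's hyperparameters to variant.json in log_dir."""
    path = os.path.join(log_dir, 'variant.json')
    with open(path, 'w') as f:
        json.dump(kwargs, f, indent=2, sort_keys=True)
    return path

# Usage: record hyperparameters next to the experiment snapshot so a run
# can later be identified and reproduced from its logging directory.
log_dir = tempfile.mkdtemp()
path = save_variant(log_dir, seed=24, buffer_size=10000, lr=1e-4)
with open(path) as f:
    print(json.load(f)['buffer_size'])  # 10000
```

Keeping the variant file in the logging directory means each snapshot carries the exact configuration that produced it.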