"""Simple Q trainer: default config, policy selection and trainer class."""

import logging

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.dqn import DQNTrainer

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Exploration Settings (Experimental) ===
    "exploration_config": {
        # Which Exploration class to use.
        "type": "EpsilonGreedy",
        # Arguments passed to the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        # Number of timesteps over which epsilon is annealed.
        "epsilon_timesteps": 10000,

        # For soft_q, use:
        # "exploration_config": {
        #     "type": "SoftQ",
        #     "temperature": [float, e.g. 1.0],
        # }
    },
    # Evaluation workers act greedily (exploration switched off).
    "evaluation_config": {
        "explore": False,
    },

    # Minimum number of env steps to optimize for per train call. This only
    # changes the length of an iteration; it has no effect on learning.
    "timesteps_per_iteration": 1000,
    # The target network is updated every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Replay buffer size. Note that if async_updates is set, every worker
    # gets its own replay buffer of this size.
    "buffer_size": 50000,
    # Whether observations are LZ4-compressed.
    "compress_observations": True,

    # === Optimization ===
    # Adam optimizer learning rate.
    "lr": 5e-4,
    # Learning rate schedule.
    "lr_schedule": None,
    # Epsilon hyperparameter of Adam.
    "adam_epsilon": 1e-8,
    # Gradients are clipped at this value during optimization unless None.
    "grad_clip": 40,
    # Number of model steps to sample before learning begins.
    "learning_starts": 1000,
    # Number of samples added to the replay buffer at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a training batch sampled from the replay buffer. Note that if
    # async_updates is set, each worker returns gradients for a batch of
    # this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers used for collecting samples. Raising this only makes
    # sense if your environment is particularly slow to sample, or if you're
    # using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Lower bound (in seconds) on the duration of one iteration.
    "min_iter_time_s": 1,
})
# __sphinx_doc_end__
# yapf: enable


def get_policy_class(config):
    """Pick the Simple Q policy class matching the configured framework.

    Args:
        config (dict): Trainer config; only the "use_pytorch" entry is read.

    Returns:
        SimpleQTorchPolicy if config["use_pytorch"] is truthy, otherwise
        SimpleQTFPolicy.
    """
    if not config["use_pytorch"]:
        return SimpleQTFPolicy

    # Deferred import: the torch policy module is only loaded on this path.
    from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
    return SimpleQTorchPolicy


# Simple Q trainer, derived from DQNTrainer with the overrides given below.
SimpleQTrainer = DQNTrainer.with_updates(
    default_policy=SimpleQTFPolicy,
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
)