ray/rllib/agents/dqn/simple_q.py
Sven Mika 22ccc43670
[RLlib] DQN torch version. (#7597)
* Fix.

* Rollback.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* Fix.

* Fix.

* Fix.

* Fix.

* Fix.

* WIP.

* WIP.

* Fix.

* Test case fixes.

* Test case fixes and LINT.

* Test case fixes and LINT.

* Rollback.

* WIP.

* WIP.

* Test case fixes.

* Fix.

* Fix.

* Fix.

* Add regression test for DQN w/ param noise.

* Fixes and LINT.

* Fixes and LINT.

* Fixes and LINT.

* Fixes and LINT.

* Fixes and LINT.

* Comment

* Regression test case.

* WIP.

* WIP.

* LINT.

* LINT.

* WIP.

* Fix.

* Fix.

* Fix.

* LINT.

* Fix (SAC does currently not support eager).

* Fix.

* WIP.

* LINT.

* Update rllib/evaluation/sampler.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Update rllib/evaluation/sampler.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Update rllib/utils/exploration/exploration.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Update rllib/utils/exploration/exploration.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* WIP.

* WIP.

* Fix.

* LINT.

* LINT.

* Fix and LINT.

* WIP.

* WIP.

* WIP.

* WIP.

* Fix.

* LINT.

* Fix.

* Fix and LINT.

* Update rllib/utils/exploration/exploration.py

* Update rllib/policy/dynamic_tf_policy.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Update rllib/policy/dynamic_tf_policy.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Update rllib/policy/dynamic_tf_policy.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Fixes.

* WIP.

* LINT.

* Fixes and LINT.

* LINT and fixes.

* LINT.

* Move action_dist back into torch extra_action_out_fn and LINT.

* Working SimpleQ learning cartpole on both torch AND tf.

* Working Rainbow learning cartpole on tf.

* Working Rainbow learning cartpole on tf.

* WIP.

* LINT.

* LINT.

* Update docs and add torch to APEX test.

* LINT.

* Fix.

* LINT.

* Fix.

* Fix.

* Fix and docstrings.

* Fix broken RLlib tests in master.

* Split BAZEL learning tests into cartpole and pendulum (reached the 60min barrier).

* Fix error_outputs option in BAZEL for RLlib regression tests.

* Fix.

* Tune param-noise tests.

* LINT.

* Fix.

* Fix.

* test

* test

* test

* Fix.

* Fix.

* WIP.

* WIP.

* WIP.

* WIP.

* LINT.

* WIP.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 11:56:16 -07:00

87 lines
2.9 KiB
Python

import logging
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.dqn import DQNTrainer
logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Exploration Settings (Experimental) ===
"exploration_config": {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.02,
"epsilon_timesteps": 10000, # Timesteps over which to anneal epsilon.
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
},
# Switch to greedy actions in evaluation workers.
"evaluation_config": {
"explore": False,
},
# Minimum env steps to optimize for per train call. This value does
# not affect learning, only the length of iterations.
"timesteps_per_iteration": 1000,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
# Whether to LZ4 compress observations
"compress_observations": True,
# === Optimization ===
# Learning rate for adam optimizer
"lr": 5e-4,
# Learning rate schedule
"lr_schedule": None,
# Adam epsilon hyper parameter
"adam_epsilon": 1e-8,
# If not None, clip gradients during optimization at this value
"grad_clip": 40,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 4,
# Size of a batch sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 32,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
})
# __sphinx_doc_end__
# yapf: enable
def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.dqn.simple_q_torch_policy import \
SimpleQTorchPolicy
return SimpleQTorchPolicy
else:
return SimpleQTFPolicy
SimpleQTrainer = DQNTrainer.with_updates(
default_policy=SimpleQTFPolicy,
get_policy_class=get_policy_class,
default_config=DEFAULT_CONFIG)