[RLlib] APEX-DQN and R2D2 config objects. (#25067)

Sven Mika 2022-05-23 12:15:45 +02:00 committed by GitHub
parent c6edfdd2a0
commit ec89fe5203
11 changed files with 401 additions and 232 deletions

View file

@ -1,5 +1,5 @@
from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
@ -13,20 +13,23 @@ from ray.rllib.algorithms.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
__all__ = [ __all__ = [
"ApexConfig",
"ApexTrainer", "ApexTrainer",
"APEX_DEFAULT_CONFIG", "DQNConfig",
"DQNTFPolicy", "DQNTFPolicy",
"DQNTorchPolicy", "DQNTorchPolicy",
"DQNTrainer", "DQNTrainer",
"DEFAULT_CONFIG",
"R2D2TorchPolicy", "R2D2TorchPolicy",
"R2D2Trainer", "R2D2Trainer",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
"SimpleQConfig", "SimpleQConfig",
"SimpleQTFPolicy", "SimpleQTFPolicy",
"SimpleQTorchPolicy", "SimpleQTorchPolicy",
"SimpleQTrainer", "SimpleQTrainer",
# Deprecated.
"APEX_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
] ]
from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.deprecation import deprecation_warning
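The reorganized exports can be exercised as follows; a minimal sketch (assuming a Ray build in which the legacy `ray.rllib.agents.dqn` paths still resolve):

from ray.rllib.agents.dqn import APEX_DEFAULT_CONFIG, ApexConfig

config = ApexConfig()                    # new-style config object
print(config.n_step)                     # attribute access; APEX default is 3
legacy = APEX_DEFAULT_CONFIG["n_step"]   # still importable, but dict access now logs a deprecation warning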

View file

@ -16,31 +16,25 @@ from collections import defaultdict
import copy import copy
import platform import platform
import random import random
from typing import Tuple, Dict, List, DefaultDict, Set from typing import Dict, List, DefaultDict, Set
import ray import ray
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer from ray.rllib.agents import Trainer
from ray.rllib.algorithms.dqn.dqn import ( from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import ( from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER, STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER, STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
) )
from ray.rllib.execution.parallel_requests import ( from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests, asynchronous_parallel_requests,
wait_asynchronous_requests, wait_asynchronous_requests,
) )
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics import ( from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS, LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_SAMPLED,
@ -53,7 +47,6 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER, TARGET_NET_UPDATE_TIMER,
) )
from ray.rllib.utils.typing import ( from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict, TrainerConfigDict,
ResultDict, ResultDict,
PartialTrainerConfigDict, PartialTrainerConfigDict,
@ -61,31 +54,96 @@ from ray.rllib.utils.typing import (
) )
from ray.tune.trainable import Trainable from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.util.ml_utils.dict import merge_dicts
# fmt: off
# __sphinx_doc_begin__ class ApexConfig(DQNConfig):
APEX_DEFAULT_CONFIG = merge_dicts( """Defines a configuration class from which an ApexTrainer can be built.
# See also the options in dqn.py, which are also supported.
DQN_DEFAULT_CONFIG, Example:
{ >>> from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer
"optimizer": merge_dicts( >>> config = ApexConfig()
DQN_DEFAULT_CONFIG["optimizer"], { >>> print(config.replay_buffer_config)
>>> replay_config = config.replay_buffer_config
>>> replay_config.update(
>>> {
>>> "capacity": 100000,
>>> "prioritized_replay_alpha": 0.45,
>>> "prioritized_replay_beta": 0.55,
>>> "prioritized_replay_eps": 3e-6,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = ApexTrainer(config=config)
>>> while True:
>>> trainer.train()
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> from ray import tune
>>> config = ApexConfig()
>>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "APEX",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "type": "EpsilonGreedy",
>>> "initial_epsilon": 0.96,
>>> "final_epsilon": 0.01,
>>> "epsilone_timesteps": 5000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "type": "SoftQ",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or ApexTrainer)
# fmt: off
# __sphinx_doc_begin__
# APEX-DQN settings overriding DQN ones:
# .training()
self.optimizer = merge_dicts(
DQNConfig().optimizer, {
"max_weight_sync_delay": 400, "max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4, "num_replay_buffer_shards": 4,
"debug": False "debug": False
}), })
"n_step": 3, self.n_step = 3
"num_gpus": 1, self.train_batch_size = 512
"num_workers": 32, self.target_network_update_freq = 500000
self.training_intensity = 1
# TODO(jungong) : add proper replay_buffer_config after # APEX-DQN is using a distributed (non local) replay buffer.
# DistributedReplayBuffer type is supported. self.replay_buffer_config = {
"replay_buffer_config": {
"no_local_replay_buffer": True, "no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports # Specify prioritized replay by supplying a buffer type that supports
# prioritization # prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer", "type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000, "capacity": 2000000,
"replay_batch_size": 32, "replay_batch_size": 32,
@ -106,63 +164,26 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# on which the learner is located. # on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True, "replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True, "worker_side_prioritization": True,
}, # Deprecated key.
"prioritized_replay": DEPRECATED_VALUE,
}
"train_batch_size": 512, # .rollouts()
"rollout_fragment_length": 50, self.num_workers = 32
# Update the target network every `target_network_update_freq` sample timesteps. self.rollout_fragment_length = 50
"target_network_update_freq": 500000, self.exploration_config = {
# Minimum env sampling timesteps to accumulate within a single `train()` call. "type": "PerWorkerEpsilonGreedy",
# This value does not affect learning, only the number of times }
# `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
# `step_attempt()`, the env sampling timestep count has not been reached, will
# perform n more `step_attempt()` calls until the minimum timesteps have been
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
# buffer timesteps. Must be greater than 0.
# TODO: Find a way to support None again as a means to replay
# proceeding as fast as possible.
"training_intensity": 1,
},
)
# __sphinx_doc_end__
# fmt: on
# .resources()
self.num_gpus = 1
# Update worker weights as they finish generating experiences. # .reporting()
class UpdateWorkerWeights: self.min_time_s_per_reporting = 30
def __init__( self.min_sample_timesteps_per_reporting = 25000
self,
learner_thread: LearnerThread,
workers: WorkerSet,
max_weight_sync_delay: int,
):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None
def __call__(self, item: Tuple[ActorHandle, SampleBatchType]): # fmt: on
actor, batch = item # __sphinx_doc_end__
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors.
if self.weights is None or self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
self.weights = ray.put(self.workers.local_worker().get_weights())
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
self.steps_since_update[actor] = 0
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_syncs"] += 1
class ApexTrainer(DQNTrainer): class ApexTrainer(DQNTrainer):
@ -232,7 +253,7 @@ class ApexTrainer(DQNTrainer):
@classmethod @classmethod
@override(DQNTrainer) @override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict: def get_default_config(cls) -> TrainerConfigDict:
return APEX_DEFAULT_CONFIG return ApexConfig().to_dict()
@override(DQNTrainer) @override(DQNTrainer)
def validate_config(self, config): def validate_config(self, config):
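Because `get_default_config()` now returns `ApexConfig().to_dict()`, the trainer's default dict and the config object cannot drift apart. A hedged sanity-check sketch (not part of this diff):

from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer

# Both views should agree on APEX-specific defaults such as n_step=3:
assert ApexTrainer.get_default_config()["n_step"] == ApexConfig().n_step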
@ -548,3 +569,20 @@ class ApexTrainer(DQNTrainer):
), ),
strategy=config.get("placement_strategy", "PACK"), strategy=config.get("placement_strategy", "PACK"),
) )
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(ApexConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.apex.ApexConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
APEX_DEFAULT_CONFIG = _deprecated_default_config()
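The shim keeps dict-style lookups working while steering users toward the config class; a hedged sketch of the intended behavior:

from ray.rllib.agents.dqn.apex import APEX_DEFAULT_CONFIG, ApexConfig

old_style = APEX_DEFAULT_CONFIG["target_network_update_freq"]  # logs deprecation_warning(error=False), still returns the value
new_style = ApexConfig().target_network_update_freq            # preferred going forward
assert old_style == new_style == 500000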

View file

@ -1,37 +1,99 @@
import logging import logging
from typing import Type from typing import Optional, Type
from ray.rllib.algorithms.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG from ray.rllib.algorithms.dqn import DQNConfig, DQNTrainer
from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
DQN_DEFAULT_CONFIG, # See keys in dqn.py, which are also supported.
{
# Learning rate for adam optimizer.
"lr": 1e-4,
# Discount factor.
"gamma": 0.997,
# Train batch size (in number of single timesteps).
"train_batch_size": 64,
# Adam epsilon hyper parameter
"adam_epsilon": 1e-3,
# Run in parallel by default.
"num_workers": 2,
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",
# === Replay buffer === class R2D2Config(DQNConfig):
"replay_buffer_config": { """Defines a configuration class from which a R2D2Trainer can be built.
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config, R2D2Trainer
>>> config = R2D2Config()
>>> print(config.h_function_epsilon)
>>> replay_config = config.replay_buffer_config
>>> replay_config.update(
>>> {
>>> "capacity": 1000000,
>>> "replay_burn_in": 20,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = R2D2Trainer(config=config)
>>> while True:
>>> trainer.train()
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> from ray import tune
>>> config = R2D2Config()
>>> config.training(train_batch_size=tune.grid_search([256, 64]))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "R2D2",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "initial_epsilon": 1.0,
>>> "final_epsilon": 0.1,
>>> "epsilone_timesteps": 200000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>> {
>>> "type": "SoftQ",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or R2D2Trainer)
# fmt: off
# __sphinx_doc_begin__
# R2D2-specific settings:
self.zero_init_states = True
self.use_h_function = True
self.h_function_epsilon = 1e-3
# R2D2 settings overriding DQN ones:
# .training()
self.adam_epsilon = 1e-3
self.lr = 1e-4
self.gamma = 0.997
self.train_batch_size = 64
self.target_network_update_freq = 2500
# R2D2 is using a buffer that stores sequences.
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer", "type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports # Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer. # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
@ -53,34 +115,54 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# used for loss calculation is `n - replay_burn_in` time steps # used for loss calculation is `n - replay_burn_in` time steps
# (n=LSTMs/attention nets max_seq_len). # (n=LSTMs/attention nets max_seq_len).
"replay_burn_in": 0, "replay_burn_in": 0,
}, }
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
# it (as initial state when running through the network for training),
# and update that initial state during training (from the internal
# state outputs of the immediately preceding sequence).
"zero_init_states": True,
# Whether to use the h-function from the paper [1] to scale target # .rollouts()
# values in the R2D2-loss function: self.num_workers = 2
# h(x) = sign(x)(√(|x| + 1) − 1) + εx self.batch_mode = "complete_episodes"
"use_h_function": True,
# The epsilon parameter from the R2D2 loss function (only used
# if `use_h_function`=True.
"h_function_epsilon": 1e-3,
# Update the target network every `target_network_update_freq` sample steps. # fmt: on
"target_network_update_freq": 2500, # __sphinx_doc_end__
# Deprecated keys: self.burn_in = DEPRECATED_VALUE
# Use config["replay_buffer_config"]["replay_burn_in"] instead
"burn_in": DEPRECATED_VALUE def training(
}, self,
_allow_unknown_configs=True, *,
) zero_init_states: Optional[bool] = None,
# __sphinx_doc_end__ use_h_function: Optional[bool] = None,
# fmt: on h_function_epsilon: Optional[float] = None,
**kwargs,
) -> "R2D2Config":
"""Sets the training related configuration.
Args:
zero_init_states: If True, assume a zero-initialized state input (no
matter where in the episode the sequence is located).
If False, store the initial states along with each SampleBatch, use
it (as initial state when running through the network for training),
and update that initial state during training (from the internal
state outputs of the immediately preceding sequence).
use_h_function: Whether to use the h-function from the paper [1] to scale
target values in the R2D2-loss function:
h(x) = sign(x)(√(|x| + 1) − 1) + εx
h_function_epsilon: The epsilon parameter from the R2D2 loss function (only
used if `use_h_function`=True).
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if zero_init_states is not None:
self.zero_init_states = zero_init_states
if use_h_function is not None:
self.use_h_function = use_h_function
if h_function_epsilon is not None:
self.h_function_epsilon = h_function_epsilon
return self
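With the override above, the three R2D2-specific knobs become chainable next to the inherited DQN/Trainer arguments; a hedged usage sketch:

from ray.rllib.agents.dqn.r2d2 import R2D2Config

config = (
    R2D2Config()
    .training(
        zero_init_states=False,    # store and replay real initial RNN states
        use_h_function=True,
        h_function_epsilon=1e-2,
        lr=1e-4,                   # inherited argument, forwarded via **kwargs
    )
    .rollouts(num_rollout_workers=2)
)
assert config.h_function_epsilon == 1e-2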
# Build an R2D2 trainer, which uses the framework specific Policy # Build an R2D2 trainer, which uses the framework specific Policy
@ -103,7 +185,7 @@ class R2D2Trainer(DQNTrainer):
@classmethod @classmethod
@override(DQNTrainer) @override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict: def get_default_config(cls) -> TrainerConfigDict:
return R2D2_DEFAULT_CONFIG return R2D2Config().to_dict()
@override(DQNTrainer) @override(DQNTrainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]: def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@ -136,3 +218,20 @@ class R2D2Trainer(DQNTrainer):
if config.get("batch_mode") != "complete_episodes": if config.get("batch_mode") != "complete_episodes":
raise ValueError("`batch_mode` must be 'complete_episodes'!") raise ValueError("`batch_mode` must be 'complete_episodes'!")
# Deprecated: Use ray.rllib.agents.dqn.r2d2.R2D2Config instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(R2D2Config().to_dict())
@Deprecated(
old="ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.r2d2.R2D2Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
R2D2_DEFAULT_CONFIG = _deprecated_default_config()
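For reference, the value rescaling selected by `use_h_function` and `h_function_epsilon` can be written out explicitly. A hedged NumPy sketch of h(x) = sign(x)(√(|x| + 1) − 1) + εx and its closed-form inverse (illustrative only, not code from this diff):

import numpy as np

def h(x, eps=1e-3):
    # Squash large targets: sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x

def h_inv(x, eps=1e-3):
    # Closed-form inverse, applied when un-squashing Bellman targets.
    return np.sign(x) * (
        ((np.sqrt(1.0 + 4.0 * eps * (np.abs(x) + 1.0 + eps)) - 1.0) / (2.0 * eps)) ** 2 - 1.0
    )

x = np.array([-10.0, 0.0, 10.0])
assert np.allclose(h_inv(h(x)), x)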

View file

@ -21,17 +21,26 @@ class TestApexDQN(unittest.TestCase):
ray.shutdown() ray.shutdown()
def test_apex_zero_workers(self): def test_apex_zero_workers(self):
config = apex.APEX_DEFAULT_CONFIG.copy() config = (
config["num_workers"] = 0 apex.ApexConfig()
config["num_gpus"] = 0 .rollouts(num_rollout_workers=0)
config["replay_buffer_config"] = { .resources(num_gpus=0)
"learning_starts": 1000, .training(
} replay_buffer_config={
config["min_sample_timesteps_per_reporting"] = 100 "learning_starts": 1000,
config["min_time_s_per_reporting"] = 1 },
config["optimizer"]["num_replay_buffer_shards"] = 1 optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)
for _ in framework_iterator(config): for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0") trainer = config.build(env="CartPole-v0")
results = trainer.train() results = trainer.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
@ -39,19 +48,26 @@ class TestApexDQN(unittest.TestCase):
def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
"""Test whether an APEX-DQNTrainer can be built on all frameworks.""" """Test whether an APEX-DQNTrainer can be built on all frameworks."""
config = apex.APEX_DEFAULT_CONFIG.copy() config = (
config["num_workers"] = 3 apex.ApexConfig()
config["num_gpus"] = 0 .rollouts(num_rollout_workers=3)
config["replay_buffer_config"] = { .resources(num_gpus=0)
"learning_starts": 1000, .training(
} replay_buffer_config={
config["min_sample_timesteps_per_reporting"] = 100 "learning_starts": 1000,
config["min_time_s_per_reporting"] = 1 },
config["optimizer"]["num_replay_buffer_shards"] = 1 optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)
for _ in framework_iterator(config, with_eager_tracing=True): for _ in framework_iterator(config, with_eager_tracing=True):
plain_config = config.copy() trainer = config.build(env="CartPole-v0")
trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
# Test per-worker epsilon distribution. # Test per-worker epsilon distribution.
infos = trainer.workers.foreach_policy( infos = trainer.workers.foreach_policy(
@ -77,37 +93,43 @@ class TestApexDQN(unittest.TestCase):
trainer.stop() trainer.stop()
def test_apex_lr_schedule(self): def test_apex_lr_schedule(self):
config = apex.APEX_DEFAULT_CONFIG.copy() config = (
config["num_workers"] = 1 apex.ApexConfig()
config["num_gpus"] = 0 .rollouts(
config["train_batch_size"] = 10 num_rollout_workers=1,
config["rollout_fragment_length"] = 5 rollout_fragment_length=5,
config["replay_buffer_config"] = { )
"no_local_replay_buffer": True, .resources(num_gpus=0)
"type": "MultiAgentPrioritizedReplayBuffer", .training(
"learning_starts": 10, train_batch_size=10,
"capacity": 100, optimizer={
"replay_batch_size": 10, "num_replay_buffer_shards": 1,
"prioritized_replay_alpha": 0.6, # This makes sure learning schedule is checked every 10 timesteps.
# Beta parameter for sampling from prioritized replay buffer. "max_weight_sync_delay": 10,
"prioritized_replay_beta": 0.4, },
# Epsilon to add to the TD errors when updating priorities. replay_buffer_config={
"prioritized_replay_eps": 1e-6, "no_local_replay_buffer": True,
} "type": "MultiAgentPrioritizedReplayBuffer",
config["min_sample_timesteps_per_reporting"] = 10 "learning_starts": 10,
# 0 metrics reporting delay, this makes sure timestep, "capacity": 100,
# which lr depends on, is updated after each worker rollout. "replay_batch_size": 10,
config["min_time_s_per_reporting"] = 0 "prioritized_replay_alpha": 0.6,
config["optimizer"]["num_replay_buffer_shards"] = 1 # Beta parameter for sampling from prioritized replay buffer.
# This makes sure learning schedule is checked every 10 timesteps. "prioritized_replay_beta": 0.4,
config["optimizer"]["max_weight_sync_delay"] = 10 # Epsilon to add to the TD errors when updating priorities.
# Initial lr, doesn't really matter because of the schedule below. "prioritized_replay_eps": 1e-6,
config["lr"] = 0.2 },
lr_schedule = [ # Initial lr, doesn't really matter because of the schedule below.
[0, 0.2], lr=0.2,
[100, 0.001], lr_schedule=[[0, 0.2], [100, 0.001]],
] )
config["lr_schedule"] = lr_schedule .reporting(
min_sample_timesteps_per_reporting=10,
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
min_time_s_per_reporting=0,
)
)
def _step_n_times(trainer, n: int): def _step_n_times(trainer, n: int):
"""Step trainer n times. """Step trainer n times.
@ -122,7 +144,7 @@ class TestApexDQN(unittest.TestCase):
] ]
for _ in framework_iterator(config): for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0") trainer = config.build(env="CartPole-v0")
lr = _step_n_times(trainer, 5) # 50 timesteps lr = _step_n_times(trainer, 5) # 50 timesteps
# Close to 0.2 # Close to 0.2
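The updated tests build trainers straight from the config object; a minimal sketch of that pattern, mirroring the zero-worker test's settings (requires a working Ray/RLlib installation):

from ray.rllib.agents.dqn.apex import ApexConfig

config = (
    ApexConfig()
    .rollouts(num_rollout_workers=0)
    .resources(num_gpus=0)
    .training(optimizer={"num_replay_buffer_shards": 1})
    .reporting(min_sample_timesteps_per_reporting=100, min_time_s_per_reporting=1)
)
trainer = config.build(env="CartPole-v0")   # replaces ApexTrainer(config=..., env="CartPole-v0")
results = trainer.train()
trainer.stop()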

View file

@ -47,26 +47,30 @@ class TestR2D2(unittest.TestCase):
def test_r2d2_compilation(self): def test_r2d2_compilation(self):
"""Test whether a R2D2Trainer can be built on all frameworks.""" """Test whether a R2D2Trainer can be built on all frameworks."""
config = dqn.R2D2_DEFAULT_CONFIG.copy() config = (
config["num_workers"] = 0 # Run locally. dqn.r2d2.R2D2Config()
# Wrap with an LSTM and use a very simple base-model. .rollouts(num_rollout_workers=0)
config["model"]["use_lstm"] = True .training(
config["model"]["max_seq_len"] = 20 model={
config["model"]["fcnet_hiddens"] = [32] # Wrap with an LSTM and use a very simple base-model.
config["model"]["lstm_cell_size"] = 64 "use_lstm": True,
"max_seq_len": 20,
config["replay_buffer_config"]["replay_burn_in"] = 20 "fcnet_hiddens": [32],
config["zero_init_states"] = True "lstm_cell_size": 64,
},
config["dueling"] = False dueling=False,
config["lr"] = 5e-4 lr=5e-4,
config["exploration_config"]["epsilon_timesteps"] = 100000 zero_init_states=True,
replay_buffer_config={"replay_burn_in": 20},
)
.exploration(exploration_config={"epsilon_timesteps": 100000})
)
num_iterations = 1 num_iterations = 1
# Test building an R2D2 agent in all frameworks. # Test building an R2D2 agent in all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True): for _ in framework_iterator(config, with_eager_tracing=True):
trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0") trainer = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = trainer.train()
check_train_results(results) check_train_results(results)

View file

@ -19,7 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import ( from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling, WeightedImportanceSampling,
) )
from ray.rllib.utils import deep_update from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import ( from ray.rllib.utils.typing import (
EnvConfigDict, EnvConfigDict,
@ -715,7 +715,7 @@ class TrainerConfig:
if model is not None: if model is not None:
self.model = model self.model = model
if optimizer is not None: if optimizer is not None:
self.optimizer = optimizer self.optimizer = merge_dicts(self.optimizer, optimizer)
return self return self
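Since `training()` now merges a passed-in `optimizer` dict into the existing one instead of replacing it, partial overrides keep the remaining defaults. A hedged sketch using the APEX defaults from this same PR:

from ray.rllib.agents.dqn.apex import ApexConfig

config = ApexConfig()
config.training(optimizer={"max_weight_sync_delay": 10})
assert config.optimizer["max_weight_sync_delay"] == 10
assert config.optimizer["num_replay_buffer_shards"] == 4   # untouched default survives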
@ -1061,7 +1061,7 @@ class TrainerConfig:
timestep count has not been reached, will perform n more timestep count has not been reached, will perform n more
`step_attempt()` calls until the minimum timesteps have been executed. `step_attempt()` calls until the minimum timesteps have been executed.
Set to 0 for no minimum timesteps. Set to 0 for no minimum timesteps.
min_sample_timesteps_per_reporting: Minimum env samplingtimesteps to min_sample_timesteps_per_reporting: Minimum env sampling timesteps to
accumulate within a single `train()` call. This value does not affect accumulate within a single `train()` call. This value does not affect
learning, only the number of times `Trainer.step_attempt()` is called by learning, only the number of times `Trainer.step_attempt()` is called by
`Trainer.train()`. If - after one `step_attempt()`, the env sampling `Trainer.train()`. If - after one `step_attempt()`, the env sampling

View file

@ -29,8 +29,8 @@ class DDPGConfig(SimpleQConfig):
>>> from ray import tune >>> from ray import tune
>>> config = DDPGConfig() >>> config = DDPGConfig()
>>> # Print out some default values. >>> # Print out some default values.
>>> print(config.lr) >>> print(config.lr) # doctest: +SKIP
0.0004 0.0004
>>> # Update the config object. >>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001])) >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env. >>> # Set the config object's env.

View file

@ -1,5 +1,5 @@
# from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG # from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
@ -16,6 +16,7 @@ from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
__all__ = [ __all__ = [
"ApexTrainer", "ApexTrainer",
"APEX_DEFAULT_CONFIG", "APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy", "DQNTFPolicy",
"DQNTorchPolicy", "DQNTorchPolicy",
"DQNTrainer", "DQNTrainer",

View file

@ -114,15 +114,13 @@ class DQNConfig(SimpleQConfig):
>>> .exploration(exploration_config=explore_config) >>> .exploration(exploration_config=explore_config)
""" """
def __init__(self): def __init__(self, trainer_class=None):
"""Initializes a DQNConfig instance.""" """Initializes a DQNConfig instance."""
super().__init__() super().__init__(trainer_class=trainer_class or DQNTrainer)
# DQN specific # DQN specific config settings.
# fmt: off # fmt: off
# __sphinx_doc_begin__ # __sphinx_doc_begin__
#
self.trainer_class = DQNTrainer
self.num_atoms = 1 self.num_atoms = 1
self.v_min = -10.0 self.v_min = -10.0
self.v_max = 10.0 self.v_max = 10.0
@ -282,20 +280,6 @@ class DQNConfig(SimpleQConfig):
return self return self
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]: def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps""" """Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]: if not config["training_intensity"]:
@ -435,6 +419,20 @@ class DQNTrainer(SimpleQTrainer):
return train_results return train_results
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config() DEFAULT_CONFIG = _deprecated_default_config()
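One ordering constraint worth noting: `DQNConfig()` now defaults its `trainer_class` to `DQNTrainer`, so the module-level `DEFAULT_CONFIG = _deprecated_default_config()` assignment has to come after the `DQNTrainer` definition. A toy, self-contained illustration (generic names, not RLlib's):

class Config:
    def __init__(self, trainer_class=None):
        # Falls back to a class defined later in the same module.
        self.trainer_class = trainer_class or Trainer

class Trainer:
    pass

# Fine here: Trainer already exists by the time Config() is instantiated.
DEFAULT_CONFIG = vars(Config())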

View file

@ -26,8 +26,8 @@ class PGConfig(TrainerConfig):
>>> from ray import tune >>> from ray import tune
>>> config = PGConfig() >>> config = PGConfig()
>>> # Print out some default values. >>> # Print out some default values.
>>> print(config.lr) >>> print(config.lr) # doctest: +SKIP
... 0.0004 0.0004
>>> # Update the config object. >>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001])) >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env. >>> # Set the config object's env.

View file

@ -1,8 +1,12 @@
# Counters for sampling and training steps (env- and agent steps). # Counters for sampling and training steps (env- and agent steps).
NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled" NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled" NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"
NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"
NUM_ENV_STEPS_TRAINED = "num_env_steps_trained" NUM_ENV_STEPS_TRAINED = "num_env_steps_trained"
NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained" NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"
NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
# Counters to track target network updates. # Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts" LAST_TARGET_UPDATE_TS = "last_target_update_ts"
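The two new *_THIS_ITER names follow the existing counter-naming scheme; the constants are plain string keys used in metric/result dicts. A small, self-contained sketch (the dict below is hypothetical; key presence in real results varies by algorithm):

from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
    NUM_ENV_STEPS_SAMPLED_THIS_ITER,
)

assert NUM_ENV_STEPS_SAMPLED_THIS_ITER == "num_env_steps_sampled_this_iter"

fake_results = {NUM_ENV_STEPS_SAMPLED_THIS_ITER: 25000}          # hypothetical result dict
print(fake_results.get(NUM_AGENT_STEPS_SAMPLED_THIS_ITER, 0))    # -> 0 (key absent)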