[RLlib] APEX-DQN and R2D2 config objects. (#25067)

Sven Mika 2022-05-23 12:15:45 +02:00 committed by GitHub
parent c6edfdd2a0
commit ec89fe5203
11 changed files with 401 additions and 232 deletions

View file

@ -1,5 +1,5 @@
from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
@ -13,20 +13,23 @@ from ray.rllib.algorithms.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
__all__ = [
"ApexConfig",
"ApexTrainer",
"APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy",
"DQNTorchPolicy",
"DQNTrainer",
"DEFAULT_CONFIG",
"R2D2TorchPolicy",
"R2D2Trainer",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
"SimpleQConfig",
"SimpleQTFPolicy",
"SimpleQTorchPolicy",
"SimpleQTrainer",
# Deprecated.
"APEX_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
]
from ray.rllib.utils.deprecation import deprecation_warning

View file

@ -16,31 +16,25 @@ from collections import defaultdict
import copy
import platform
import random
from typing import Tuple, Dict, List, DefaultDict, Set
from typing import Dict, List, DefaultDict, Set
import ray
from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer
from ray.rllib.algorithms.dqn.dqn import (
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests,
wait_asynchronous_requests,
)
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@ -53,7 +47,6 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict,
ResultDict,
PartialTrainerConfigDict,
@ -61,31 +54,96 @@ from ray.rllib.utils.typing import (
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.util.ml_utils.dict import merge_dicts
# fmt: off
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
# See also the options in dqn.py, which are also supported.
DQN_DEFAULT_CONFIG,
{
"optimizer": merge_dicts(
DQN_DEFAULT_CONFIG["optimizer"], {
class ApexConfig(DQNConfig):
"""Defines a configuration class from which an ApexTrainer can be built.
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.replay_buffer_config)
>>> replay_config = config.replay_buffer_config
>>> replay_config.update(
>>> {
>>> "capacity": 100000,
>>> "prioritized_replay_alpha": 0.45,
>>> "prioritized_replay_beta": 0.55,
>>> "prioritized_replay_eps": 3e-6,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = ApexTrainer(config=config)
>>> while True:
>>> trainer.train()
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> from ray import tune
>>> config = ApexConfig()
>>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "APEX",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "type": "EpsilonGreedy",
>>> "initial_epsilon": 0.96,
>>> "final_epsilon": 0.01,
>>> "epsilone_timesteps": 5000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "type": "SoftQ",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or ApexTrainer)
# fmt: off
# __sphinx_doc_begin__
# APEX-DQN settings overriding DQN ones:
# .training()
self.optimizer = merge_dicts(
DQNConfig().optimizer, {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
}),
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
})
self.n_step = 3
self.train_batch_size = 512
self.target_network_update_freq = 500000
self.training_intensity = 1
# APEX-DQN is using a distributed (non local) replay buffer.
self.replay_buffer_config = {
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,
@ -106,63 +164,26 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
# Deprecated key.
"prioritized_replay": DEPRECATED_VALUE,
}
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
# Minimum env sampling timesteps to accumulate within a single `train()` call.
# This value does not affect learning, only the number of times
# `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
# `step_attempt()`, the env sampling timestep count has not been reached, will
# perform n more `step_attempt()` calls until the minimum timesteps have been
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
# buffer timesteps. Must be greater than 0.
# TODO: Find a way to support None again as a means to replay
# proceeding as fast as possible.
"training_intensity": 1,
},
)
# __sphinx_doc_end__
# fmt: on
# .rollouts()
self.num_workers = 32
self.rollout_fragment_length = 50
self.exploration_config = {
"type": "PerWorkerEpsilonGreedy",
}
# .resources()
self.num_gpus = 1
# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
def __init__(
self,
learner_thread: LearnerThread,
workers: WorkerSet,
max_weight_sync_delay: int,
):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None
# .reporting()
self.min_time_s_per_reporting = 30
self.min_sample_timesteps_per_reporting = 25000
def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
actor, batch = item
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors.
if self.weights is None or self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
self.weights = ray.put(self.workers.local_worker().get_weights())
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
self.steps_since_update[actor] = 0
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_syncs"] += 1
# fmt: on
# __sphinx_doc_end__
class ApexTrainer(DQNTrainer):
@ -232,7 +253,7 @@ class ApexTrainer(DQNTrainer):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DEFAULT_CONFIG
return ApexConfig().to_dict()
@override(DQNTrainer)
def validate_config(self, config):
@ -548,3 +569,20 @@ class ApexTrainer(DQNTrainer):
),
strategy=config.get("placement_strategy", "PACK"),
)
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(ApexConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.apex.ApexConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
APEX_DEFAULT_CONFIG = _deprecated_default_config()
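For orientation (not part of the diff): a minimal sketch of how downstream code migrates from the deprecated APEX_DEFAULT_CONFIG dict to the new ApexConfig object. The builder methods used (.rollouts(), .resources(), .training(), .build()) are the ones exercised by the updated tests below; the env name and hyperparameter values are illustrative only.

from ray.rllib.agents.dqn.apex import APEX_DEFAULT_CONFIG, ApexConfig

# Old style: copy the (now deprecated) default-config dict and mutate keys.
old_config = APEX_DEFAULT_CONFIG.copy()
old_config["num_workers"] = 4

# New style: chainable config object that is built into a trainer at the end.
config = (
    ApexConfig()
    .rollouts(num_rollout_workers=4)
    .resources(num_gpus=0)
    .training(optimizer={"num_replay_buffer_shards": 1})
)
trainer = config.build(env="CartPole-v0")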

View file

@ -1,37 +1,99 @@
import logging
from typing import Type
from typing import Optional, Type
from ray.rllib.algorithms.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn import DQNConfig, DQNTrainer
from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
DQN_DEFAULT_CONFIG, # See keys in dqn.py, which are also supported.
{
# Learning rate for adam optimizer.
"lr": 1e-4,
# Discount factor.
"gamma": 0.997,
# Train batch size (in number of single timesteps).
"train_batch_size": 64,
# Adam epsilon hyper parameter
"adam_epsilon": 1e-3,
# Run in parallel by default.
"num_workers": 2,
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",
# === Replay buffer ===
"replay_buffer_config": {
class R2D2Config(DQNConfig):
"""Defines a configuration class from which a R2D2Trainer can be built.
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.h_function_epsilon)
>>> replay_config = config.replay_buffer_config
>>> replay_config.update(
>>> {
>>> "capacity": 1000000,
>>> "replay_burn_in": 20,
>>> }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = R2D2Trainer(config=config)
>>> while True:
>>> trainer.train()
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> from ray import tune
>>> config = R2D2Config()
>>> config.training(train_batch_size=tune.grid_search([256, 64]))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>> "R2D2",
>>> stop={"episode_reward_mean":200},
>>> config=config.to_dict()
>>> )
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "initial_epsilon": 1.0,
>>> "final_epsilon": 0.1,
>>> "epsilone_timesteps": 200000,
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config
>>> explore_config.update(
>>> {
>>> "type": "SoftQ",
>>> "temperature": [1.0],
>>> }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or R2D2Trainer)
# fmt: off
# __sphinx_doc_begin__
# R2D2-specific settings:
self.zero_init_states = True
self.use_h_function = True
self.h_function_epsilon = 1e-3
# R2D2 settings overriding DQN ones:
# .training()
self.adam_epsilon = 1e-3
self.lr = 1e-4
self.gamma = 0.997
self.train_batch_size = 64
self.target_network_update_freq = 2500
# R2D2 is using a buffer that stores sequences.
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
@ -53,34 +115,54 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# used for loss calculation is `n - replay_burn_in` time steps
# (n=LSTMs/attention nets max_seq_len).
"replay_burn_in": 0,
},
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
# it (as initial state when running through the network for training),
# and update that initial state during training (from the internal
# state outputs of the immediately preceding sequence).
"zero_init_states": True,
}
# Whether to use the h-function from the paper [1] to scale target
# values in the R2D2-loss function:
# h(x) = sign(x)(√(|x| + 1) - 1) + εx
"use_h_function": True,
# The epsilon parameter from the R2D2 loss function (only used
# if `use_h_function`=True).
"h_function_epsilon": 1e-3,
# .rollouts()
self.num_workers = 2
self.batch_mode = "complete_episodes"
# Update the target network every `target_network_update_freq` sample steps.
"target_network_update_freq": 2500,
# fmt: on
# __sphinx_doc_end__
# Deprecated keys:
# Use config["replay_buffer_config"]["replay_burn_in"] instead
"burn_in": DEPRECATED_VALUE
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# fmt: on
self.burn_in = DEPRECATED_VALUE
def training(
self,
*,
zero_init_states: Optional[bool] = None,
use_h_function: Optional[bool] = None,
h_function_epsilon: Optional[float] = None,
**kwargs,
) -> "R2D2Config":
"""Sets the training related configuration.
Args:
zero_init_states: If True, assume a zero-initialized state input (no
matter where in the episode the sequence is located).
If False, store the initial states along with each SampleBatch, use
it (as initial state when running through the network for training),
and update that initial state during training (from the internal
state outputs of the immediately preceding sequence).
use_h_function: Whether to use the h-function from the paper [1] to scale
target values in the R2D2-loss function:
h(x) = sign(x)(√(|x| + 1) - 1) + εx
h_function_epsilon: The epsilon parameter from the R2D2 loss function (only
used if `use_h_function`=True).
Returns:
This updated R2D2Config object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if zero_init_states is not None:
self.zero_init_states = zero_init_states
if use_h_function is not None:
self.use_h_function = use_h_function
if h_function_epsilon is not None:
self.h_function_epsilon = h_function_epsilon
return self
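As a reading aid (not part of the diff): a small, self-contained sketch of the value-rescaling h-function described in the docstring above, together with the standard closed-form inverse. This is an illustrative NumPy reimplementation, not the code path RLlib's R2D2 loss actually calls.

import numpy as np

def h(x: np.ndarray, epsilon: float = 1e-3) -> np.ndarray:
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x."""
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(y: np.ndarray, epsilon: float = 1e-3) -> np.ndarray:
    """Closed-form inverse of h (the standard companion formula)."""
    return np.sign(y) * (
        ((np.sqrt(1.0 + 4.0 * epsilon * (np.abs(y) + 1.0 + epsilon)) - 1.0)
         / (2.0 * epsilon)) ** 2
        - 1.0
    )

x = np.linspace(-100.0, 100.0, 5)
assert np.allclose(h_inverse(h(x)), x, atol=1e-4)  # round-trip sanity check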
# Build an R2D2 trainer, which uses the framework specific Policy
@ -103,7 +185,7 @@ class R2D2Trainer(DQNTrainer):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return R2D2_DEFAULT_CONFIG
return R2D2Config().to_dict()
@override(DQNTrainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@ -136,3 +218,20 @@ class R2D2Trainer(DQNTrainer):
if config.get("batch_mode") != "complete_episodes":
raise ValueError("`batch_mode` must be 'complete_episodes'!")
# Deprecated: Use ray.rllib.agents.dqn.r2d2.R2D2Config instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(R2D2Config().to_dict())
@Deprecated(
old="ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.r2d2.R2D2Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
R2D2_DEFAULT_CONFIG = _deprecated_default_config()
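A minimal usage sketch (not part of the diff) of the new R2D2Config, showing the training() keywords added above; all values are illustrative, and model= is simply forwarded to the base TrainerConfig.training() via **kwargs.

from ray.rllib.agents.dqn.r2d2 import R2D2Config

config = (
    R2D2Config()
    .training(
        zero_init_states=False,   # keep and reuse real initial RNN states
        use_h_function=True,      # rescale targets with h(x)
        h_function_epsilon=1e-2,  # illustrative value; default is 1e-3
        model={"use_lstm": True, "max_seq_len": 20},
    )
    .rollouts(num_rollout_workers=2)
)
trainer = config.build(env="CartPole-v0")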

View file

@ -21,17 +21,26 @@ class TestApexDQN(unittest.TestCase):
ray.shutdown()
def test_apex_zero_workers(self):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["num_gpus"] = 0
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
config = (
apex.ApexConfig()
.rollouts(num_rollout_workers=0)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)
for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
results = trainer.train()
check_train_results(results)
print(results)
@ -39,19 +48,26 @@ class TestApexDQN(unittest.TestCase):
def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
"""Test whether an APEX-DQNTrainer can be built on all frameworks."""
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_gpus"] = 0
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
config = (
apex.ApexConfig()
.rollouts(num_rollout_workers=3)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)
for _ in framework_iterator(config, with_eager_tracing=True):
plain_config = config.copy()
trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
# Test per-worker epsilon distribution.
infos = trainer.workers.foreach_policy(
@ -77,37 +93,43 @@ class TestApexDQN(unittest.TestCase):
trainer.stop()
def test_apex_lr_schedule(self):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_gpus"] = 0
config["train_batch_size"] = 10
config["rollout_fragment_length"] = 5
config["replay_buffer_config"] = {
"no_local_replay_buffer": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
}
config["min_sample_timesteps_per_reporting"] = 10
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
config["min_time_s_per_reporting"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
# This makes sure learning schedule is checked every 10 timesteps.
config["optimizer"]["max_weight_sync_delay"] = 10
# Initial lr, doesn't really matter because of the schedule below.
config["lr"] = 0.2
lr_schedule = [
[0, 0.2],
[100, 0.001],
]
config["lr_schedule"] = lr_schedule
config = (
apex.ApexConfig()
.rollouts(
num_rollout_workers=1,
rollout_fragment_length=5,
)
.resources(num_gpus=0)
.training(
train_batch_size=10,
optimizer={
"num_replay_buffer_shards": 1,
# This makes sure learning schedule is checked every 10 timesteps.
"max_weight_sync_delay": 10,
},
replay_buffer_config={
"no_local_replay_buffer": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Initial lr, doesn't really matter because of the schedule below.
lr=0.2,
lr_schedule=[[0, 0.2], [100, 0.001]],
)
.reporting(
min_sample_timesteps_per_reporting=10,
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
min_time_s_per_reporting=0,
)
)
def _step_n_times(trainer, n: int):
"""Step trainer n times.
@ -122,7 +144,7 @@ class TestApexDQN(unittest.TestCase):
]
for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
lr = _step_n_times(trainer, 5) # 50 timesteps
# Close to 0.2

View file

@ -47,26 +47,30 @@ class TestR2D2(unittest.TestCase):
def test_r2d2_compilation(self):
"""Test whether a R2D2Trainer can be built on all frameworks."""
config = dqn.R2D2_DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
# Wrap with an LSTM and use a very simple base-model.
config["model"]["use_lstm"] = True
config["model"]["max_seq_len"] = 20
config["model"]["fcnet_hiddens"] = [32]
config["model"]["lstm_cell_size"] = 64
config["replay_buffer_config"]["replay_burn_in"] = 20
config["zero_init_states"] = True
config["dueling"] = False
config["lr"] = 5e-4
config["exploration_config"]["epsilon_timesteps"] = 100000
config = (
dqn.r2d2.R2D2Config()
.rollouts(num_rollout_workers=0)
.training(
model={
# Wrap with an LSTM and use a very simple base-model.
"use_lstm": True,
"max_seq_len": 20,
"fcnet_hiddens": [32],
"lstm_cell_size": 64,
},
dueling=False,
lr=5e-4,
zero_init_states=True,
replay_buffer_config={"replay_burn_in": 20},
)
.exploration(exploration_config={"epsilon_timesteps": 100000})
)
num_iterations = 1
# Test building an R2D2 agent in all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)

View file

@ -19,7 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils import deep_update
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
EnvConfigDict,
@ -715,7 +715,7 @@ class TrainerConfig:
if model is not None:
self.model = model
if optimizer is not None:
self.optimizer = optimizer
self.optimizer = merge_dicts(self.optimizer, optimizer)
return self
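Side note (not part of the diff): switching from plain assignment to merge_dicts gives .training(optimizer=...) merge semantics instead of replace semantics, which is what lets the updated tests pass only the optimizer keys they care about. A sketch of the difference, with illustrative values:

from ray.rllib.utils import merge_dicts

defaults = {"max_weight_sync_delay": 400, "num_replay_buffer_shards": 4, "debug": False}
override = {"num_replay_buffer_shards": 1}

# Before: self.optimizer = override        -> the default keys are silently dropped.
# After:  self.optimizer = merge_dicts(defaults, override)
print(merge_dicts(defaults, override))
# {'max_weight_sync_delay': 400, 'num_replay_buffer_shards': 1, 'debug': False}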
@ -1061,7 +1061,7 @@ class TrainerConfig:
timestep count has not been reached, will perform n more
`step_attempt()` calls until the minimum timesteps have been executed.
Set to 0 for no minimum timesteps.
min_sample_timesteps_per_reporting: Minimum env samplingtimesteps to
min_sample_timesteps_per_reporting: Minimum env sampling timesteps to
accumulate within a single `train()` call. This value does not affect
learning, only the number of times `Trainer.step_attempt()` is called by
`Trainer.train()`. If - after one `step_attempt()`, the env sampling

View file

@ -29,8 +29,8 @@ class DDPGConfig(SimpleQConfig):
>>> from ray import tune
>>> config = DDPGConfig()
>>> # Print out some default values.
>>> print(config.lr)
0.0004
>>> print(config.lr) # doctest: +SKIP
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.

View file

@ -1,5 +1,5 @@
# from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
@ -16,6 +16,7 @@ from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
__all__ = [
"ApexTrainer",
"APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy",
"DQNTorchPolicy",
"DQNTrainer",

View file

@ -114,15 +114,13 @@ class DQNConfig(SimpleQConfig):
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self):
def __init__(self, trainer_class=None):
"""Initializes a DQNConfig instance."""
super().__init__()
super().__init__(trainer_class=trainer_class or DQNTrainer)
# DQN specific
# DQN specific config settings.
# fmt: off
# __sphinx_doc_begin__
#
self.trainer_class = DQNTrainer
self.num_atoms = 1
self.v_min = -10.0
self.v_max = 10.0
@ -282,20 +280,6 @@ class DQNConfig(SimpleQConfig):
return self
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]:
@ -435,6 +419,20 @@ class DQNTrainer(SimpleQTrainer):
return train_results
# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
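For context (not part of the diff): the shim keeps old dict-style access working while steering users toward the config class. A short sketch of what callers see; the attribute name used is an existing config default.

from ray.rllib.algorithms.dqn.dqn import DEFAULT_CONFIG, DQNConfig

# Old access path: still returns the value, but __getitem__ is wrapped with
# @Deprecated(error=False), so a deprecation warning is logged.
lr = DEFAULT_CONFIG["lr"]

# Preferred path going forward:
lr = DQNConfig().lr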

View file

@ -26,8 +26,8 @@ class PGConfig(TrainerConfig):
>>> from ray import tune
>>> config = PGConfig()
>>> # Print out some default values.
>>> print(config.lr)
... 0.0004
>>> print(config.lr) # doctest: +SKIP
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.

View file

@ -1,8 +1,12 @@
# Counters for sampling and training steps (env- and agent steps).
NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"
NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"
NUM_ENV_STEPS_TRAINED = "num_env_steps_trained"
NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"
NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
# Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts"