Mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
[RLlib] APEX-DQN and R2D2 config objects. (#25067)
commit ec89fe5203 (parent c6edfdd2a0)
11 changed files with 401 additions and 232 deletions
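At a glance, the commit replaces the dict-based APEX_DEFAULT_CONFIG and R2D2_DEFAULT_CONFIG defaults with ApexConfig and R2D2Config classes that subclass DQNConfig. A minimal usage sketch of the new style, assuming a Ray build matching this commit (import paths and method names are taken from the diff below; the environment and worker counts are arbitrary):

from ray.rllib.agents.dqn.apex import ApexConfig

# Build an APEX-DQN trainer from the new config object instead of the old
# APEX_DEFAULT_CONFIG dict. Setter names follow the docstring examples in the diff.
config = (
    ApexConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=4)
    .resources(num_gpus=0)
    .training(train_batch_size=512)
)

trainer = config.build()           # roughly equivalent to ApexTrainer(config=config.to_dict())
print(config.to_dict()["n_step"])  # the old dict keys remain reachable via to_dict()
result = trainer.train()
print(result["episode_reward_mean"])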
@ -1,5 +1,5 @@
from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG

@ -13,20 +13,23 @@ from ray.rllib.algorithms.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy

__all__ = [
"ApexConfig",
"ApexTrainer",
"APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy",
"DQNTorchPolicy",
"DQNTrainer",
"DEFAULT_CONFIG",
"R2D2TorchPolicy",
"R2D2Trainer",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
"SimpleQConfig",
"SimpleQTFPolicy",
"SimpleQTorchPolicy",
"SimpleQTrainer",
# Deprecated.
"APEX_DEFAULT_CONFIG",
"DEFAULT_CONFIG",
"R2D2_DEFAULT_CONFIG",
"SIMPLE_Q_DEFAULT_CONFIG",
]

from ray.rllib.utils.deprecation import deprecation_warning

@ -16,31 +16,25 @@ from collections import defaultdict
import copy
import platform
import random
from typing import Tuple, Dict, List, DefaultDict, Set
from typing import Dict, List, DefaultDict, Set

import ray
from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer
from ray.rllib.algorithms.dqn.dqn import (
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests,
wait_asynchronous_requests,
)
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,

@ -53,7 +47,6 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict,
ResultDict,
PartialTrainerConfigDict,

@ -61,31 +54,96 @@ from ray.rllib.utils.typing import (
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.util.ml_utils.dict import merge_dicts

# fmt: off
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
# See also the options in dqn.py, which are also supported.
DQN_DEFAULT_CONFIG,
{
"optimizer": merge_dicts(
DQN_DEFAULT_CONFIG["optimizer"], {

class ApexConfig(DQNConfig):
"""Defines a configuration class from which an ApexTrainer can be built.

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.replay_buffer_config)
>>> replay_config = config.replay_buffer_config.update(
>>>     {
>>>         "capacity": 100000,
>>>         "prioritized_replay_alpha": 0.45,
>>>         "prioritized_replay_beta": 0.55,
>>>         "prioritized_replay_eps": 3e-6,
>>>     }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>>       .resources(num_gpus=1)\
>>>       .rollouts(num_rollout_workers=30)\
>>>       .environment("CartPole-v1")
>>> trainer = ApexTrainer(config=config)
>>> while True:
>>>     trainer.train()

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> from ray import tune
>>> config = ApexConfig()
>>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>>     "APEX",
>>>     stop={"episode_reward_mean": 200},
>>>     config=config.to_dict()
>>> )

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>>     {
>>>         "type": "EpsilonGreedy",
>>>         "initial_epsilon": 0.96,
>>>         "final_epsilon": 0.01,
>>>         "epsilon_timesteps": 5000,
>>>     }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>>       .exploration(exploration_config=explore_config)

Example:
>>> from ray.rllib.agents.dqn.apex import ApexConfig
>>> config = ApexConfig()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>>     {
>>>         "type": "SoftQ",
>>>         "temperature": [1.0],
>>>     }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>>       .exploration(exploration_config=explore_config)
"""

def __init__(self, trainer_class=None):
"""Initializes an ApexConfig instance."""
super().__init__(trainer_class=trainer_class or ApexTrainer)

# fmt: off
# __sphinx_doc_begin__
# APEX-DQN settings overriding DQN ones:
# .training()
self.optimizer = merge_dicts(
DQNConfig().optimizer, {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
}),
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,

# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
})
self.n_step = 3
self.train_batch_size = 512
self.target_network_update_freq = 500000
self.training_intensity = 1
# APEX-DQN is using a distributed (non local) replay buffer.
self.replay_buffer_config = {
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,

@ -106,63 +164,26 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
# Deprecated key.
"prioritized_replay": DEPRECATED_VALUE,
}

"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
# Minimum env sampling timesteps to accumulate within a single `train()` call.
# This value does not affect learning, only the number of times
# `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
# `step_attempt()`, the env sampling timestep count has not been reached, will
# perform n more `step_attempt()` calls until the minimum timesteps have been
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
# buffer timesteps. Must be greater than 0.
# TODO: Find a way to support None again as a means to replay
# proceeding as fast as possible.
"training_intensity": 1,
},
)
# __sphinx_doc_end__
# fmt: on
# .rollouts()
self.num_workers = 32
self.rollout_fragment_length = 50
self.exploration_config = {
"type": "PerWorkerEpsilonGreedy",
}

# .resources()
self.num_gpus = 1

# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
def __init__(
self,
learner_thread: LearnerThread,
workers: WorkerSet,
max_weight_sync_delay: int,
):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None
# .reporting()
self.min_time_s_per_reporting = 30
self.min_sample_timesteps_per_reporting = 25000

def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
actor, batch = item
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors.
if self.weights is None or self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
self.weights = ray.put(self.workers.local_worker().get_weights())
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
self.steps_since_update[actor] = 0
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_syncs"] += 1
# fmt: on
# __sphinx_doc_end__


class ApexTrainer(DQNTrainer):

@ -232,7 +253,7 @@ class ApexTrainer(DQNTrainer):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DEFAULT_CONFIG
return ApexConfig().to_dict()

@override(DQNTrainer)
def validate_config(self, config):

@ -548,3 +569,20 @@ class ApexTrainer(DQNTrainer):
),
strategy=config.get("placement_strategy", "PACK"),
)


# Deprecated: Use ray.rllib.agents.dqn.apex.ApexConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(ApexConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.apex.ApexConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


APEX_DEFAULT_CONFIG = _deprecated_default_config()

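The _deprecated_default_config wrapper above keeps APEX_DEFAULT_CONFIG importable while nudging users toward ApexConfig. A stripped-down sketch of the same pattern, using stand-in classes rather than RLlib's actual ones:

import warnings

class NewConfig:
    """Stand-in for a config class such as ApexConfig."""
    def to_dict(self):
        return {"n_step": 3, "num_workers": 32}

class _DeprecatedDefaultConfig(dict):
    """A dict that still works but warns on key access, mirroring the diff's wrapper."""
    def __init__(self):
        super().__init__(NewConfig().to_dict())

    def __getitem__(self, item):
        warnings.warn(
            "OLD_DEFAULT_CONFIG is deprecated; use NewConfig(...) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__getitem__(item)

OLD_DEFAULT_CONFIG = _DeprecatedDefaultConfig()
print(OLD_DEFAULT_CONFIG["n_step"])  # still returns 3, but raises a DeprecationWarning on access

Existing user code that copies and indexes the old dict keeps running, while every lookup points at the replacement API.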
@ -1,37 +1,99 @@
import logging
from typing import Type
from typing import Optional, Type

from ray.rllib.algorithms.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn import DQNConfig, DQNTrainer
from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)

# fmt: off
# __sphinx_doc_begin__
R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
DQN_DEFAULT_CONFIG,  # See keys in dqn.py, which are also supported.
{
# Learning rate for adam optimizer.
"lr": 1e-4,
# Discount factor.
"gamma": 0.997,
# Train batch size (in number of single timesteps).
"train_batch_size": 64,
# Adam epsilon hyper parameter
"adam_epsilon": 1e-3,
# Run in parallel by default.
"num_workers": 2,
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",

# === Replay buffer ===
"replay_buffer_config": {

class R2D2Config(DQNConfig):
"""Defines a configuration class from which an R2D2Trainer can be built.

Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.h_function_epsilon)
>>> replay_config = config.replay_buffer_config.update(
>>>     {
>>>         "capacity": 1000000,
>>>         "replay_burn_in": 20,
>>>     }
>>> )
>>> config.training(replay_buffer_config=replay_config)\
>>>       .resources(num_gpus=1)\
>>>       .rollouts(num_rollout_workers=30)\
>>>       .environment("CartPole-v1")
>>> trainer = R2D2Trainer(config=config)
>>> while True:
>>>     trainer.train()

Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> from ray import tune
>>> config = R2D2Config()
>>> config.training(train_batch_size=tune.grid_search([256, 64]))
>>> config.environment(env="CartPole-v1")
>>> tune.run(
>>>     "R2D2",
>>>     stop={"episode_reward_mean": 200},
>>>     config=config.to_dict()
>>> )

Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>>     {
>>>         "initial_epsilon": 1.0,
>>>         "final_epsilon": 0.1,
>>>         "epsilon_timesteps": 200000,
>>>     }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>>       .exploration(exploration_config=explore_config)

Example:
>>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
>>> config = R2D2Config()
>>> print(config.exploration_config)
>>> explore_config = config.exploration_config.update(
>>>     {
>>>         "type": "SoftQ",
>>>         "temperature": [1.0],
>>>     }
>>> )
>>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
>>>       .exploration(exploration_config=explore_config)
"""

def __init__(self, trainer_class=None):
"""Initializes an R2D2Config instance."""
super().__init__(trainer_class=trainer_class or R2D2Trainer)

# fmt: off
# __sphinx_doc_begin__
# R2D2-specific settings:
self.zero_init_states = True
self.use_h_function = True
self.h_function_epsilon = 1e-3

# R2D2 settings overriding DQN ones:
# .training()
self.adam_epsilon = 1e-3
self.lr = 1e-4
self.gamma = 0.997
self.train_batch_size = 64
self.target_network_update_freq = 2500
# R2D2 is using a buffer that stores sequences.
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.

@ -53,34 +115,54 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# used for loss calculation is `n - replay_burn_in` time steps
# (n=LSTM’s/attention net’s max_seq_len).
"replay_burn_in": 0,
},
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
# it (as initial state when running through the network for training),
# and update that initial state during training (from the internal
# state outputs of the immediately preceding sequence).
"zero_init_states": True,
}

# Whether to use the h-function from the paper [1] to scale target
# values in the R2D2-loss function:
# h(x) = sign(x)(√(|x| + 1) − 1) + εx
"use_h_function": True,
# The epsilon parameter from the R2D2 loss function (only used
# if `use_h_function`=True).
"h_function_epsilon": 1e-3,
# .rollouts()
self.num_workers = 2
self.batch_mode = "complete_episodes"

# Update the target network every `target_network_update_freq` sample steps.
"target_network_update_freq": 2500,
# fmt: on
# __sphinx_doc_end__

# Deprecated keys:
# Use config["replay_buffer_config"]["replay_burn_in"] instead
"burn_in": DEPRECATED_VALUE
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# fmt: on
self.burn_in = DEPRECATED_VALUE

def training(
self,
*,
zero_init_states: Optional[bool] = None,
use_h_function: Optional[bool] = None,
h_function_epsilon: Optional[float] = None,
**kwargs,
) -> "R2D2Config":
"""Sets the training related configuration.

Args:
zero_init_states: If True, assume a zero-initialized state input (no
matter where in the episode the sequence is located).
If False, store the initial states along with each SampleBatch, use
it (as initial state when running through the network for training),
and update that initial state during training (from the internal
state outputs of the immediately preceding sequence).
use_h_function: Whether to use the h-function from the paper [1] to scale
target values in the R2D2-loss function:
h(x) = sign(x)(√(|x| + 1) − 1) + εx
h_function_epsilon: The epsilon parameter from the R2D2 loss function (only
used if `use_h_function`=True).

Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

if zero_init_states is not None:
self.zero_init_states = zero_init_states
if use_h_function is not None:
self.use_h_function = use_h_function
if h_function_epsilon is not None:
self.h_function_epsilon = h_function_epsilon

return self


# Build an R2D2 trainer, which uses the framework specific Policy

@ -103,7 +185,7 @@ class R2D2Trainer(DQNTrainer):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return R2D2_DEFAULT_CONFIG
return R2D2Config().to_dict()

@override(DQNTrainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:

@ -136,3 +218,20 @@ class R2D2Trainer(DQNTrainer):

if config.get("batch_mode") != "complete_episodes":
raise ValueError("`batch_mode` must be 'complete_episodes'!")


# Deprecated: Use ray.rllib.agents.dqn.r2d2.R2D2Config instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(R2D2Config().to_dict())

@Deprecated(
old="ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG",
new="ray.rllib.agents.dqn.r2d2.R2D2Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


R2D2_DEFAULT_CONFIG = _deprecated_default_config()

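The use_h_function and h_function_epsilon settings in the R2D2 config refer to the invertible value-rescaling function h(x) = sign(x)(√(|x| + 1) − 1) + εx. A standalone NumPy sketch of that transform and its closed-form inverse (an illustration of the math only, not RLlib's internal implementation):

import numpy as np

EPS = 1e-3  # corresponds to h_function_epsilon in the diff

def h(x, eps=EPS):
    # Invertible value rescaling: compresses large magnitudes while keeping the sign.
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x

def h_inverse(y, eps=EPS):
    # Closed-form inverse, obtained by solving the quadratic in sqrt(|x| + 1).
    return np.sign(y) * (
        ((np.sqrt(1.0 + 4.0 * eps * (np.abs(y) + 1.0 + eps)) - 1.0) / (2.0 * eps)) ** 2
        - 1.0
    )

x = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
assert np.allclose(h_inverse(h(x)), x)  # round-trips within float tolerance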
@ -21,17 +21,26 @@ class TestApexDQN(unittest.TestCase):
ray.shutdown()

def test_apex_zero_workers(self):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["num_gpus"] = 0
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
config = (
apex.ApexConfig()
.rollouts(num_rollout_workers=0)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)

for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
results = trainer.train()
check_train_results(results)
print(results)

@ -39,19 +48,26 @@ class TestApexDQN(unittest.TestCase):

def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
"""Test whether an APEX-DQNTrainer can be built on all frameworks."""
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_gpus"] = 0
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
config = (
apex.ApexConfig()
.rollouts(num_rollout_workers=3)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
optimizer={
"num_replay_buffer_shards": 1,
},
)
.reporting(
min_sample_timesteps_per_reporting=100,
min_time_s_per_reporting=1,
)
)

for _ in framework_iterator(config, with_eager_tracing=True):
plain_config = config.copy()
trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")

# Test per-worker epsilon distribution.
infos = trainer.workers.foreach_policy(

@ -77,37 +93,43 @@ class TestApexDQN(unittest.TestCase):
trainer.stop()

def test_apex_lr_schedule(self):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_gpus"] = 0
config["train_batch_size"] = 10
config["rollout_fragment_length"] = 5
config["replay_buffer_config"] = {
"no_local_replay_buffer": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
}
config["min_sample_timesteps_per_reporting"] = 10
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
config["min_time_s_per_reporting"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
# This makes sure learning schedule is checked every 10 timesteps.
config["optimizer"]["max_weight_sync_delay"] = 10
# Initial lr, doesn't really matter because of the schedule below.
config["lr"] = 0.2
lr_schedule = [
[0, 0.2],
[100, 0.001],
]
config["lr_schedule"] = lr_schedule
config = (
apex.ApexConfig()
.rollouts(
num_rollout_workers=1,
rollout_fragment_length=5,
)
.resources(num_gpus=0)
.training(
train_batch_size=10,
optimizer={
"num_replay_buffer_shards": 1,
# This makes sure learning schedule is checked every 10 timesteps.
"max_weight_sync_delay": 10,
},
replay_buffer_config={
"no_local_replay_buffer": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Initial lr, doesn't really matter because of the schedule below.
lr=0.2,
lr_schedule=[[0, 0.2], [100, 0.001]],
)
.reporting(
min_sample_timesteps_per_reporting=10,
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
min_time_s_per_reporting=0,
)
)

def _step_n_times(trainer, n: int):
"""Step trainer n times.

@ -122,7 +144,7 @@ class TestApexDQN(unittest.TestCase):
]

for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")

lr = _step_n_times(trainer, 5)  # 50 timesteps
# Close to 0.2

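The _step_n_times helper used in the lr-schedule test is cut off in this hunk. A plausible sketch of what such a helper does follows; the result-dict key path for the current learning rate is an assumption for illustration only, not taken from this diff:

def _step_n_times(trainer, n: int):
    """Step trainer n times and return the learner's current learning rate.

    Hypothetical illustration: the exact key path into the result dict below
    is assumed, not copied from the real test.
    """
    for _ in range(n):
        results = trainer.train()
    return results["info"]["learner"]["default_policy"]["learner_stats"]["cur_lr"]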
@ -47,26 +47,30 @@ class TestR2D2(unittest.TestCase):

def test_r2d2_compilation(self):
"""Test whether an R2D2Trainer can be built on all frameworks."""
config = dqn.R2D2_DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.
# Wrap with an LSTM and use a very simple base-model.
config["model"]["use_lstm"] = True
config["model"]["max_seq_len"] = 20
config["model"]["fcnet_hiddens"] = [32]
config["model"]["lstm_cell_size"] = 64

config["replay_buffer_config"]["replay_burn_in"] = 20
config["zero_init_states"] = True

config["dueling"] = False
config["lr"] = 5e-4
config["exploration_config"]["epsilon_timesteps"] = 100000
config = (
dqn.r2d2.R2D2Config()
.rollouts(num_rollout_workers=0)
.training(
model={
# Wrap with an LSTM and use a very simple base-model.
"use_lstm": True,
"max_seq_len": 20,
"fcnet_hiddens": [32],
"lstm_cell_size": 64,
},
dueling=False,
lr=5e-4,
zero_init_states=True,
replay_buffer_config={"replay_burn_in": 20},
)
.exploration(exploration_config={"epsilon_timesteps": 100000})
)

num_iterations = 1

# Test building an R2D2 agent in all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0")
trainer = config.build(env="CartPole-v0")
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)

@ -19,7 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils import deep_update
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
EnvConfigDict,

@ -715,7 +715,7 @@ class TrainerConfig:
if model is not None:
self.model = model
if optimizer is not None:
self.optimizer = optimizer
self.optimizer = merge_dicts(self.optimizer, optimizer)

return self

@ -1061,7 +1061,7 @@ class TrainerConfig:
timestep count has not been reached, will perform n more
`step_attempt()` calls until the minimum timesteps have been executed.
Set to 0 for no minimum timesteps.
min_sample_timesteps_per_reporting: Minimum env samplingtimesteps to
min_sample_timesteps_per_reporting: Minimum env sampling timesteps to
accumulate within a single `train()` call. This value does not affect
learning, only the number of times `Trainer.step_attempt()` is called by
`Trainer.train()`. If - after one `step_attempt()`, the env sampling

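The TrainerConfig.training() change above swaps plain assignment for merge_dicts, so a partial optimizer dict now updates individual keys instead of wiping the defaults (which is what lets the tests pass only {"num_replay_buffer_shards": 1}). A quick illustration with a local stand-in for the helper:

import copy

def merge_dicts(original, new):
    # Minimal stand-in for ray.rllib.utils.merge_dicts: deep-copy `original`,
    # then overwrite only the keys supplied in `new` (the real helper also
    # merges nested dicts).
    merged = copy.deepcopy(original)
    merged.update(new)
    return merged

defaults = {"max_weight_sync_delay": 400, "num_replay_buffer_shards": 4, "debug": False}
partial_update = {"num_replay_buffer_shards": 1}

# Old behavior: `self.optimizer = optimizer` would drop max_weight_sync_delay and debug.
# New behavior: merging keeps the untouched defaults.
merged = merge_dicts(defaults, partial_update)
assert merged == {"max_weight_sync_delay": 400, "num_replay_buffer_shards": 1, "debug": False}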
@ -29,8 +29,8 @@ class DDPGConfig(SimpleQConfig):
>>> from ray import tune
>>> config = DDPGConfig()
>>> # Print out some default values.
>>> print(config.lr)
0.0004
>>> print(config.lr)  # doctest: +SKIP
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.

@ -1,5 +1,5 @@
# from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy

@ -16,6 +16,7 @@ from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
__all__ = [
"ApexTrainer",
"APEX_DEFAULT_CONFIG",
"DQNConfig",
"DQNTFPolicy",
"DQNTorchPolicy",
"DQNTrainer",

@ -114,15 +114,13 @@ class DQNConfig(SimpleQConfig):
>>>       .exploration(exploration_config=explore_config)
"""

def __init__(self):
def __init__(self, trainer_class=None):
"""Initializes a DQNConfig instance."""
super().__init__()
super().__init__(trainer_class=trainer_class or DQNTrainer)

# DQN specific
# DQN specific config settings.
# fmt: off
# __sphinx_doc_begin__
#
self.trainer_class = DQNTrainer
self.num_atoms = 1
self.v_min = -10.0
self.v_max = 10.0

@ -282,20 +280,6 @@ class DQNConfig(SimpleQConfig):
return self


# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())

@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]:

@ -435,6 +419,20 @@ class DQNTrainer(SimpleQTrainer):
return train_results


# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DQNConfig().to_dict())

@Deprecated(
old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()

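training_intensity (set to 1 for APEX-DQN above) expresses the desired ratio of trained to sampled timesteps, and calculate_rr_weights turns it into round-robin weights for the sample and train steps. A hedged sketch of how such weights could be derived; the formula is illustrative only, not RLlib's exact implementation:

def rr_weights(training_intensity, rollout_fragment_length, num_workers, train_batch_size):
    # No configured intensity: alternate sample and train steps 1:1.
    if not training_intensity:
        return [1, 1]
    # Timesteps produced per sampling round vs. consumed per training round.
    sampled_per_round = rollout_fragment_length * max(num_workers, 1)
    native_ratio = train_batch_size / sampled_per_round
    # Weight the train step so that (trained / sampled) approaches training_intensity.
    return [1, training_intensity / native_ratio]

print(rr_weights(training_intensity=1, rollout_fragment_length=50,
                 num_workers=32, train_batch_size=512))  # -> [1, 3.125]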
@ -26,8 +26,8 @@ class PGConfig(TrainerConfig):
>>> from ray import tune
>>> config = PGConfig()
>>> # Print out some default values.
>>> print(config.lr)
... 0.0004
>>> print(config.lr)  # doctest: +SKIP
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.

@ -1,8 +1,12 @@
# Counters for sampling and training steps (env- and agent steps).
NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"
NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"
NUM_ENV_STEPS_TRAINED = "num_env_steps_trained"
NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"
NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"

# Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts"