Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[RLlib] APEX-DQN and R2D2 config objects. (#25067)

Commit ec89fe5203 (parent c6edfdd2a0)
11 changed files with 401 additions and 232 deletions
@@ -1,5 +1,5 @@
-from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
-from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
+from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
 from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
 from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
@@ -13,20 +13,23 @@ from ray.rllib.algorithms.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
 
 __all__ = [
+    "ApexConfig",
     "ApexTrainer",
-    "APEX_DEFAULT_CONFIG",
+    "DQNConfig",
     "DQNTFPolicy",
     "DQNTorchPolicy",
     "DQNTrainer",
-    "DEFAULT_CONFIG",
     "R2D2TorchPolicy",
     "R2D2Trainer",
-    "R2D2_DEFAULT_CONFIG",
-    "SIMPLE_Q_DEFAULT_CONFIG",
     "SimpleQConfig",
     "SimpleQTFPolicy",
     "SimpleQTorchPolicy",
     "SimpleQTrainer",
+    # Deprecated.
+    "APEX_DEFAULT_CONFIG",
+    "DEFAULT_CONFIG",
+    "R2D2_DEFAULT_CONFIG",
+    "SIMPLE_Q_DEFAULT_CONFIG",
 ]
 
 from ray.rllib.utils.deprecation import deprecation_warning
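Not part of the patch: a minimal sketch of what the reshuffled `__all__` above implies for downstream imports. The new config classes and the dict constants now listed under "Deprecated." are assumed to stay importable side by side during the transition; worker and GPU counts below are placeholders.

```python
# Sketch only; assumes the package exports shown in the updated __all__ above.
from ray.rllib.agents.dqn import ApexConfig, ApexTrainer   # new config-object API
from ray.rllib.agents.dqn import APEX_DEFAULT_CONFIG       # kept, but listed under "Deprecated."

config = ApexConfig().rollouts(num_rollout_workers=4).resources(num_gpus=0)
legacy_dict = config.to_dict()  # config objects still convert to the legacy dict format
```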
@@ -16,31 +16,25 @@ from collections import defaultdict
 import copy
 import platform
 import random
-from typing import Tuple, Dict, List, DefaultDict, Set
+from typing import Dict, List, DefaultDict, Set
 
 import ray
 from ray.actor import ActorHandle
-from ray.rllib import RolloutWorker
 from ray.rllib.agents import Trainer
-from ray.rllib.algorithms.dqn.dqn import (
-    DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
-    DQNTrainer,
-)
+from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer
 from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
-from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.execution.common import (
     STEPS_TRAINED_COUNTER,
     STEPS_TRAINED_THIS_ITER_COUNTER,
-    _get_global_vars,
-    _get_shared_metrics,
 )
 from ray.rllib.execution.parallel_requests import (
     asynchronous_parallel_requests,
     wait_asynchronous_requests,
 )
-from ray.rllib.utils import merge_dicts
 from ray.rllib.utils.actors import create_colocated_actors
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,
     NUM_AGENT_STEPS_SAMPLED,
@@ -53,7 +47,6 @@ from ray.rllib.utils.metrics import (
     TARGET_NET_UPDATE_TIMER,
 )
 from ray.rllib.utils.typing import (
-    SampleBatchType,
     TrainerConfigDict,
     ResultDict,
     PartialTrainerConfigDict,
@@ -61,31 +54,96 @@ from ray.rllib.utils.typing import (
 )
 from ray.tune.trainable import Trainable
 from ray.tune.utils.placement_groups import PlacementGroupFactory
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.util.ml_utils.dict import merge_dicts
 
 
+class ApexConfig(DQNConfig):
+    """Defines a configuration class from which an ApexTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.replay_buffer_config)
+        >>> replay_config = config.replay_buffer_config.update(
+        >>> {
+        >>> "capacity": 100000,
+        >>> "prioritized_replay_alpha": 0.45,
+        >>> "prioritized_replay_beta": 0.55,
+        >>> "prioritized_replay_eps": 3e-6,
+        >>> }
+        >>> )
+        >>> config.training(replay_buffer_config=replay_config)\
+        >>> .resources(num_gpus=1)\
+        >>> .rollouts(num_rollout_workers=30)\
+        >>> .environment("CartPole-v1")
+        >>> trainer = ApexTrainer(config=config)
+        >>> while True:
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> from ray import tune
+        >>> config = ApexConfig()
+        >>> config.training(num_atoms=tune.grid_search(list(range(1, 11))))
+        >>> config.environment(env="CartPole-v1")
+        >>> tune.run(
+        >>> "APEX",
+        >>> stop={"episode_reward_mean":200},
+        >>> config=config.to_dict()
+        >>> )
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>> {
+        >>> "type": "EpsilonGreedy",
+        >>> "initial_epsilon": 0.96,
+        >>> "final_epsilon": 0.01,
+        >>> "epsilon_timesteps": 5000,
+        >>> }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>> .exploration(exploration_config=explore_config)
+
+    Example:
+        >>> from ray.rllib.agents.dqn.apex import ApexConfig
+        >>> config = ApexConfig()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>> {
+        >>> "type": "SoftQ",
+        >>> "temperature": [1.0],
+        >>> }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>> .exploration(exploration_config=explore_config)
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes an ApexConfig instance."""
+        super().__init__(trainer_class=trainer_class or ApexTrainer)
+
         # fmt: off
         # __sphinx_doc_begin__
-APEX_DEFAULT_CONFIG = merge_dicts(
-    # See also the options in dqn.py, which are also supported.
-    DQN_DEFAULT_CONFIG,
-    {
-        "optimizer": merge_dicts(
-            DQN_DEFAULT_CONFIG["optimizer"], {
+        # APEX-DQN settings overriding DQN ones:
+        # .training()
+        self.optimizer = merge_dicts(
+            DQNConfig().optimizer, {
                 "max_weight_sync_delay": 400,
                 "num_replay_buffer_shards": 4,
                 "debug": False
-            }),
-        "n_step": 3,
-        "num_gpus": 1,
-        "num_workers": 32,
-        # TODO(jungong) : add proper replay_buffer_config after
-        # DistributedReplayBuffer type is supported.
-        "replay_buffer_config": {
+            })
+        self.n_step = 3
+        self.train_batch_size = 512
+        self.target_network_update_freq = 500000
+        self.training_intensity = 1
+        # APEX-DQN is using a distributed (non local) replay buffer.
+        self.replay_buffer_config = {
             "no_local_replay_buffer": True,
             # Specify prioritized replay by supplying a buffer type that supports
             # prioritization
-            "prioritized_replay": DEPRECATED_VALUE,
             "type": "MultiAgentPrioritizedReplayBuffer",
             "capacity": 2000000,
             "replay_batch_size": 32,
@@ -106,63 +164,26 @@ APEX_DEFAULT_CONFIG = merge_dicts(
             # on which the learner is located.
             "replay_buffer_shards_colocated_with_driver": True,
             "worker_side_prioritization": True,
-        },
-        "train_batch_size": 512,
-        "rollout_fragment_length": 50,
-        # Update the target network every `target_network_update_freq` sample timesteps.
-        "target_network_update_freq": 500000,
-        # Minimum env sampling timesteps to accumulate within a single `train()` call.
-        # This value does not affect learning, only the number of times
-        # `Trainer.step_attempt()` is called by `Trainer.train()`. If - after one
-        # `step_attempt()`, the env sampling timestep count has not been reached, will
-        # perform n more `step_attempt()` calls until the minimum timesteps have been
-        # executed. Set to 0 for no minimum timesteps.
-        "min_sample_timesteps_per_reporting": 25000,
-        "exploration_config": {"type": "PerWorkerEpsilonGreedy"},
-        "min_time_s_per_reporting": 30,
-        # This will set the ratio of replayed from a buffer and learned
-        # on timesteps to sampled from an environment and stored in the replay
-        # buffer timesteps. Must be greater than 0.
-        # TODO: Find a way to support None again as a means to replay
-        # proceeding as fast as possible.
-        "training_intensity": 1,
-    },
-)
-# __sphinx_doc_end__
-# fmt: on
+            # Deprecated key.
+            "prioritized_replay": DEPRECATED_VALUE,
+        }
+
+        # .rollouts()
+        self.num_workers = 32
+        self.rollout_fragment_length = 50
+        self.exploration_config = {
+            "type": "PerWorkerEpsilonGreedy",
+        }
+
+        # .resources()
+        self.num_gpus = 1
+
+        # .reporting()
+        self.min_time_s_per_reporting = 30
+        self.min_sample_timesteps_per_reporting = 25000
+        # fmt: on
+        # __sphinx_doc_end__
 
-
-# Update worker weights as they finish generating experiences.
-class UpdateWorkerWeights:
-    def __init__(
-        self,
-        learner_thread: LearnerThread,
-        workers: WorkerSet,
-        max_weight_sync_delay: int,
-    ):
-        self.learner_thread = learner_thread
-        self.workers = workers
-        self.steps_since_update = defaultdict(int)
-        self.max_weight_sync_delay = max_weight_sync_delay
-        self.weights = None
-
-    def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
-        actor, batch = item
-        self.steps_since_update[actor] += batch.count
-        if self.steps_since_update[actor] >= self.max_weight_sync_delay:
-            # Note that it's important to pull new weights once
-            # updated to avoid excessive correlation between actors.
-            if self.weights is None or self.learner_thread.weights_updated:
-                self.learner_thread.weights_updated = False
-                self.weights = ray.put(self.workers.local_worker().get_weights())
-            actor.set_weights.remote(self.weights, _get_global_vars())
-            # Also update global vars of the local worker.
-            self.workers.local_worker().set_global_vars(_get_global_vars())
-            self.steps_since_update[actor] = 0
-            # Update metrics.
-            metrics = _get_shared_metrics()
-            metrics.counters["num_weight_syncs"] += 1
 
 
 class ApexTrainer(DQNTrainer):
@@ -232,7 +253,7 @@ class ApexTrainer(DQNTrainer):
     @classmethod
     @override(DQNTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return APEX_DEFAULT_CONFIG
+        return ApexConfig().to_dict()
 
     @override(DQNTrainer)
     def validate_config(self, config):
@@ -548,3 +569,20 @@ class ApexTrainer(DQNTrainer):
             ),
             strategy=config.get("placement_strategy", "PACK"),
         )
+
+
+# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(ApexConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.dqn.apex.APEX_DEFAULT_CONFIG",
+        new="ray.rllib.agents.dqn.apex.ApexConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+APEX_DEFAULT_CONFIG = _deprecated_default_config()
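For orientation (not part of the patch): a hedged sketch of the usage flow that the new `ApexConfig` docstring and the `_deprecated_default_config` shim above describe. The environment name and worker counts are placeholders; the build-call pattern mirrors the updated tests later in this commit.

```python
# Sketch of the new fluent API described by the ApexConfig docstring above.
from ray.rllib.agents.dqn.apex import ApexConfig, APEX_DEFAULT_CONFIG

config = (
    ApexConfig()
    .rollouts(num_rollout_workers=4)
    .resources(num_gpus=0)
)
trainer = config.build(env="CartPole-v1")  # roughly what ApexTrainer(config=config.to_dict()) did
result = trainer.train()

# The old module-level dict still works, but __getitem__ now goes through the
# @Deprecated-decorated shim above and logs a deprecation warning instead of raising.
n_step = APEX_DEFAULT_CONFIG["n_step"]
```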
@@ -1,37 +1,99 @@
 import logging
-from typing import Type
+from typing import Optional, Type
 
-from ray.rllib.algorithms.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_DEFAULT_CONFIG
+from ray.rllib.algorithms.dqn import DQNConfig, DQNTrainer
 from ray.rllib.agents.dqn.r2d2_tf_policy import R2D2TFPolicy
 from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
-from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 
 logger = logging.getLogger(__name__)
 
 
+class R2D2Config(DQNConfig):
+    """Defines a configuration class from which a R2D2Trainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.h_function_epsilon)
+        >>> replay_config = config.replay_buffer_config.update(
+        >>> {
+        >>> "capacity": 1000000,
+        >>> "replay_burn_in": 20,
+        >>> }
+        >>> )
+        >>> config.training(replay_buffer_config=replay_config)\
+        >>> .resources(num_gpus=1)\
+        >>> .rollouts(num_rollout_workers=30)\
+        >>> .environment("CartPole-v1")
+        >>> trainer = R2D2Trainer(config=config)
+        >>> while True:
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> from ray import tune
+        >>> config = R2D2Config()
+        >>> config.training(train_batch_size=tune.grid_search([256, 64]))
+        >>> config.environment(env="CartPole-v1")
+        >>> tune.run(
+        >>> "R2D2",
+        >>> stop={"episode_reward_mean":200},
+        >>> config=config.to_dict()
+        >>> )
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>> {
+        >>> "initial_epsilon": 1.0,
+        >>> "final_epsilon": 0.1,
+        >>> "epsilon_timesteps": 200000,
+        >>> }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>> .exploration(exploration_config=explore_config)
+
+    Example:
+        >>> from ray.rllib.agents.dqn.r2d2 import R2D2Config
+        >>> config = R2D2Config()
+        >>> print(config.exploration_config)
+        >>> explore_config = config.exploration_config.update(
+        >>> {
+        >>> "type": "SoftQ",
+        >>> "temperature": [1.0],
+        >>> }
+        >>> )
+        >>> config.training(lr_schedule=[[1, 1e-3], [500, 5e-3]])\
+        >>> .exploration(exploration_config=explore_config)
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a R2D2Config instance."""
+        super().__init__(trainer_class=trainer_class or R2D2Trainer)
+
         # fmt: off
         # __sphinx_doc_begin__
-R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
-    DQN_DEFAULT_CONFIG,  # See keys in dqn.py, which are also supported.
-    {
-        # Learning rate for adam optimizer.
-        "lr": 1e-4,
-        # Discount factor.
-        "gamma": 0.997,
-        # Train batch size (in number of single timesteps).
-        "train_batch_size": 64,
-        # Adam epsilon hyper parameter
-        "adam_epsilon": 1e-3,
-        # Run in parallel by default.
-        "num_workers": 2,
-        # Batch mode must be complete_episodes.
-        "batch_mode": "complete_episodes",
-
-        # === Replay buffer ===
-        "replay_buffer_config": {
+        # R2D2-specific settings:
+        self.zero_init_states = True
+        self.use_h_function = True
+        self.h_function_epsilon = 1e-3
+
+        # R2D2 settings overriding DQN ones:
+        # .training()
+        self.adam_epsilon = 1e-3
+        self.lr = 1e-4
+        self.gamma = 0.997
+        self.train_batch_size = 64
+        self.target_network_update_freq = 2500
+        # R2D2 is using a buffer that stores sequences.
+        self.replay_buffer_config = {
             "type": "MultiAgentReplayBuffer",
             # Specify prioritized replay by supplying a buffer type that supports
             # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
@@ -53,34 +115,54 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
             # used for loss calculation is `n - replay_burn_in` time steps
             # (n=LSTM’s/attention net’s max_seq_len).
             "replay_burn_in": 0,
-        },
-        # If True, assume a zero-initialized state input (no matter where in
-        # the episode the sequence is located).
-        # If False, store the initial states along with each SampleBatch, use
-        # it (as initial state when running through the network for training),
-        # and update that initial state during training (from the internal
-        # state outputs of the immediately preceding sequence).
-        "zero_init_states": True,
-
-        # Whether to use the h-function from the paper [1] to scale target
-        # values in the R2D2-loss function:
-        # h(x) = sign(x)(√(|x| + 1) − 1) + εx
-        "use_h_function": True,
-        # The epsilon parameter from the R2D2 loss function (only used
-        # if `use_h_function`=True.
-        "h_function_epsilon": 1e-3,
-
-        # Update the target network every `target_network_update_freq` sample steps.
-        "target_network_update_freq": 2500,
-
-        # Deprecated keys:
-        # Use config["replay_buffer_config"]["replay_burn_in"] instead
-        "burn_in": DEPRECATED_VALUE
-    },
-    _allow_unknown_configs=True,
-)
-# __sphinx_doc_end__
-# fmt: on
+        }
+
+        # .rollouts()
+        self.num_workers = 2
+        self.batch_mode = "complete_episodes"
+        # fmt: on
+        # __sphinx_doc_end__
+
+        self.burn_in = DEPRECATED_VALUE
+
+    def training(
+        self,
+        *,
+        zero_init_states: Optional[bool] = None,
+        use_h_function: Optional[bool] = None,
+        h_function_epsilon: Optional[float] = None,
+        **kwargs,
+    ) -> "R2D2Config":
+        """Sets the training related configuration.
+
+        Args:
+            zero_init_states: If True, assume a zero-initialized state input (no
+                matter where in the episode the sequence is located).
+                If False, store the initial states along with each SampleBatch, use
+                it (as initial state when running through the network for training),
+                and update that initial state during training (from the internal
+                state outputs of the immediately preceding sequence).
+            use_h_function: Whether to use the h-function from the paper [1] to scale
+                target values in the R2D2-loss function:
+                h(x) = sign(x)(√(|x| + 1) − 1) + εx
+            h_function_epsilon: The epsilon parameter from the R2D2 loss function (only
+                used if `use_h_function`=True).
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if zero_init_states is not None:
+            self.zero_init_states = zero_init_states
+        if use_h_function is not None:
+            self.use_h_function = use_h_function
+        if h_function_epsilon is not None:
+            self.h_function_epsilon = h_function_epsilon
+
+        return self
+
+
 # Build an R2D2 trainer, which uses the framework specific Policy
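Not part of the patch: a small sketch exercising the new `R2D2Config.training()` keywords introduced above. The numeric values are illustrative only; the keyword names follow the signature in the hunk.

```python
# Sketch only; keyword names follow the training() signature added above.
from ray.rllib.agents.dqn.r2d2 import R2D2Config

config = R2D2Config().training(
    zero_init_states=False,   # store and replay real initial RNN states
    use_h_function=True,      # rescale targets with h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x
    h_function_epsilon=1e-2,  # illustrative; the default set in __init__ is 1e-3
    train_batch_size=64,
).rollouts(num_rollout_workers=2)

assert config.h_function_epsilon == 1e-2  # each keyword is mirrored onto an attribute
```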
@@ -103,7 +185,7 @@ class R2D2Trainer(DQNTrainer):
     @classmethod
     @override(DQNTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return R2D2_DEFAULT_CONFIG
+        return R2D2Config().to_dict()
 
     @override(DQNTrainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@@ -136,3 +218,20 @@ class R2D2Trainer(DQNTrainer):
 
         if config.get("batch_mode") != "complete_episodes":
             raise ValueError("`batch_mode` must be 'complete_episodes'!")
+
+
+# Deprecated: Use ray.rllib.agents.dqn.r2d2.R2D2Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(R2D2Config().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG",
+        new="ray.rllib.agents.dqn.r2d2.R2D2Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+R2D2_DEFAULT_CONFIG = _deprecated_default_config()
@@ -21,17 +21,26 @@ class TestApexDQN(unittest.TestCase):
         ray.shutdown()
 
     def test_apex_zero_workers(self):
-        config = apex.APEX_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0
-        config["num_gpus"] = 0
-        config["replay_buffer_config"] = {
-            "learning_starts": 1000,
-        }
-        config["min_sample_timesteps_per_reporting"] = 100
-        config["min_time_s_per_reporting"] = 1
-        config["optimizer"]["num_replay_buffer_shards"] = 1
+        config = (
+            apex.ApexConfig()
+            .rollouts(num_rollout_workers=0)
+            .resources(num_gpus=0)
+            .training(
+                replay_buffer_config={
+                    "learning_starts": 1000,
+                },
+                optimizer={
+                    "num_replay_buffer_shards": 1,
+                },
+            )
+            .reporting(
+                min_sample_timesteps_per_reporting=100,
+                min_time_s_per_reporting=1,
+            )
+        )
 
         for _ in framework_iterator(config):
-            trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
+            trainer = config.build(env="CartPole-v0")
             results = trainer.train()
             check_train_results(results)
             print(results)
@@ -39,19 +48,26 @@ class TestApexDQN(unittest.TestCase):
 
     def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
         """Test whether an APEX-DQNTrainer can be built on all frameworks."""
-        config = apex.APEX_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 3
-        config["num_gpus"] = 0
-        config["replay_buffer_config"] = {
-            "learning_starts": 1000,
-        }
-        config["min_sample_timesteps_per_reporting"] = 100
-        config["min_time_s_per_reporting"] = 1
-        config["optimizer"]["num_replay_buffer_shards"] = 1
+        config = (
+            apex.ApexConfig()
+            .rollouts(num_rollout_workers=3)
+            .resources(num_gpus=0)
+            .training(
+                replay_buffer_config={
+                    "learning_starts": 1000,
+                },
+                optimizer={
+                    "num_replay_buffer_shards": 1,
+                },
+            )
+            .reporting(
+                min_sample_timesteps_per_reporting=100,
+                min_time_s_per_reporting=1,
+            )
+        )
 
         for _ in framework_iterator(config, with_eager_tracing=True):
-            plain_config = config.copy()
-            trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
+            trainer = config.build(env="CartPole-v0")
 
             # Test per-worker epsilon distribution.
             infos = trainer.workers.foreach_policy(
@@ -77,12 +93,21 @@ class TestApexDQN(unittest.TestCase):
         trainer.stop()
 
     def test_apex_lr_schedule(self):
-        config = apex.APEX_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 1
-        config["num_gpus"] = 0
-        config["train_batch_size"] = 10
-        config["rollout_fragment_length"] = 5
-        config["replay_buffer_config"] = {
+        config = (
+            apex.ApexConfig()
+            .rollouts(
+                num_rollout_workers=1,
+                rollout_fragment_length=5,
+            )
+            .resources(num_gpus=0)
+            .training(
+                train_batch_size=10,
+                optimizer={
+                    "num_replay_buffer_shards": 1,
+                    # This makes sure learning schedule is checked every 10 timesteps.
+                    "max_weight_sync_delay": 10,
+                },
+                replay_buffer_config={
                     "no_local_replay_buffer": True,
                     "type": "MultiAgentPrioritizedReplayBuffer",
                     "learning_starts": 10,
@@ -93,21 +118,18 @@ class TestApexDQN(unittest.TestCase):
                     "prioritized_replay_beta": 0.4,
                     # Epsilon to add to the TD errors when updating priorities.
                     "prioritized_replay_eps": 1e-6,
-        }
-        config["min_sample_timesteps_per_reporting"] = 10
-        # 0 metrics reporting delay, this makes sure timestep,
-        # which lr depends on, is updated after each worker rollout.
-        config["min_time_s_per_reporting"] = 0
-        config["optimizer"]["num_replay_buffer_shards"] = 1
-        # This makes sure learning schedule is checked every 10 timesteps.
-        config["optimizer"]["max_weight_sync_delay"] = 10
-        # Initial lr, doesn't really matter because of the schedule below.
-        config["lr"] = 0.2
-        lr_schedule = [
-            [0, 0.2],
-            [100, 0.001],
-        ]
-        config["lr_schedule"] = lr_schedule
+                },
+                # Initial lr, doesn't really matter because of the schedule below.
+                lr=0.2,
+                lr_schedule=[[0, 0.2], [100, 0.001]],
+            )
+            .reporting(
+                min_sample_timesteps_per_reporting=10,
+                # 0 metrics reporting delay, this makes sure timestep,
+                # which lr depends on, is updated after each worker rollout.
+                min_time_s_per_reporting=0,
+            )
+        )
 
         def _step_n_times(trainer, n: int):
             """Step trainer n times.
@@ -122,7 +144,7 @@ class TestApexDQN(unittest.TestCase):
             ]
 
         for _ in framework_iterator(config):
-            trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
+            trainer = config.build(env="CartPole-v0")
 
             lr = _step_n_times(trainer, 5)  # 50 timesteps
             # Close to 0.2
@@ -47,26 +47,30 @@ class TestR2D2(unittest.TestCase):
 
     def test_r2d2_compilation(self):
         """Test whether a R2D2Trainer can be built on all frameworks."""
-        config = dqn.R2D2_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
-        # Wrap with an LSTM and use a very simple base-model.
-        config["model"]["use_lstm"] = True
-        config["model"]["max_seq_len"] = 20
-        config["model"]["fcnet_hiddens"] = [32]
-        config["model"]["lstm_cell_size"] = 64
-        config["replay_buffer_config"]["replay_burn_in"] = 20
-        config["zero_init_states"] = True
-        config["dueling"] = False
-        config["lr"] = 5e-4
-        config["exploration_config"]["epsilon_timesteps"] = 100000
+        config = (
+            dqn.r2d2.R2D2Config()
+            .rollouts(num_rollout_workers=0)
+            .training(
+                model={
+                    # Wrap with an LSTM and use a very simple base-model.
+                    "use_lstm": True,
+                    "max_seq_len": 20,
+                    "fcnet_hiddens": [32],
+                    "lstm_cell_size": 64,
+                },
+                dueling=False,
+                lr=5e-4,
+                zero_init_states=True,
+                replay_buffer_config={"replay_burn_in": 20},
+            )
+            .exploration(exploration_config={"epsilon_timesteps": 100000})
+        )
 
         num_iterations = 1
 
         # Test building an R2D2 agent in all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0")
+            trainer = config.build(env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
                 check_train_results(results)
@@ -19,7 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
-from ray.rllib.utils import deep_update
+from ray.rllib.utils import deep_update, merge_dicts
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import (
     EnvConfigDict,
@@ -715,7 +715,7 @@ class TrainerConfig:
         if model is not None:
             self.model = model
         if optimizer is not None:
-            self.optimizer = optimizer
+            self.optimizer = merge_dicts(self.optimizer, optimizer)
 
         return self
 
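Not part of the patch: a sketch of what the `merge_dicts` change above buys. A partial `optimizer` dict passed to `TrainerConfig.training()` now overlays the existing defaults instead of replacing them wholesale, which is what lets the updated tests pass only `num_replay_buffer_shards`.

```python
# Sketch only; Ape-X's optimizer defaults ("max_weight_sync_delay": 400, etc.)
# come from the ApexConfig changes earlier in this commit.
from ray.rllib.agents.dqn.apex import ApexConfig

config = ApexConfig().training(optimizer={"num_replay_buffer_shards": 1})
assert config.optimizer["num_replay_buffer_shards"] == 1  # explicitly overridden
assert config.optimizer["max_weight_sync_delay"] == 400   # untouched defaults survive the merge
```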
@@ -29,7 +29,7 @@ class DDPGConfig(SimpleQConfig):
         >>> from ray import tune
         >>> config = DDPGConfig()
         >>> # Print out some default values.
-        >>> print(config.lr)
+        >>> print(config.lr)  # doctest: +SKIP
         0.0004
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
@@ -1,5 +1,5 @@
 # from ray.rllib.agents.dqn.apex import ApexTrainer, APEX_DEFAULT_CONFIG
-from ray.rllib.algorithms.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
 from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
 
@@ -16,6 +16,7 @@ from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
 __all__ = [
     "ApexTrainer",
     "APEX_DEFAULT_CONFIG",
+    "DQNConfig",
     "DQNTFPolicy",
     "DQNTorchPolicy",
     "DQNTrainer",
@@ -114,15 +114,13 @@ class DQNConfig(SimpleQConfig):
         >>> .exploration(exploration_config=explore_config)
     """
 
-    def __init__(self):
+    def __init__(self, trainer_class=None):
         """Initializes a DQNConfig instance."""
-        super().__init__()
+        super().__init__(trainer_class=trainer_class or DQNTrainer)
 
-        # DQN specific
+        # DQN specific config settings.
         # fmt: off
         # __sphinx_doc_begin__
-        #
-        self.trainer_class = DQNTrainer
         self.num_atoms = 1
         self.v_min = -10.0
         self.v_max = 10.0
@@ -282,20 +280,6 @@ class DQNConfig(SimpleQConfig):
         return self
 
 
-# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
-class _deprecated_default_config(dict):
-    def __init__(self):
-        super().__init__(DQNConfig().to_dict())
-
-    @Deprecated(
-        old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
-        new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
-        error=False,
-    )
-    def __getitem__(self, item):
-        return super().__getitem__(item)
-
-
 def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
     """Calculate the round robin weights for the rollout and train steps"""
     if not config["training_intensity"]:
@@ -435,6 +419,20 @@ class DQNTrainer(SimpleQTrainer):
         return train_results
 
 
+# Deprecated: Use ray.rllib.algorithms.dqn.DQNConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(DQNConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.dqn.dqn.DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.dqn.dqn.DQNConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
 DEFAULT_CONFIG = _deprecated_default_config()
 
 
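Not part of the patch: how the relocated `_deprecated_default_config` shim above is meant to behave. `DEFAULT_CONFIG` still acts like a plain dict, but indexing it routes through the `@Deprecated(error=False)` wrapper, which only warns.

```python
# Sketch only; behavior inferred from the _deprecated_default_config class above.
from ray.rllib.algorithms.dqn.dqn import DEFAULT_CONFIG, DQNConfig

assert isinstance(DEFAULT_CONFIG, dict)
lr_old = DEFAULT_CONFIG["lr"]  # still works, but emits a deprecation warning
lr_new = DQNConfig().lr        # preferred, warning-free access
```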
@@ -26,8 +26,8 @@ class PGConfig(TrainerConfig):
         >>> from ray import tune
         >>> config = PGConfig()
         >>> # Print out some default values.
-        >>> print(config.lr)
-        ... 0.0004
+        >>> print(config.lr)  # doctest: +SKIP
+        0.0004
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
@@ -1,8 +1,12 @@
 # Counters for sampling and training steps (env- and agent steps).
 NUM_ENV_STEPS_SAMPLED = "num_env_steps_sampled"
 NUM_AGENT_STEPS_SAMPLED = "num_agent_steps_sampled"
+NUM_ENV_STEPS_SAMPLED_THIS_ITER = "num_env_steps_sampled_this_iter"
+NUM_AGENT_STEPS_SAMPLED_THIS_ITER = "num_agent_steps_sampled_this_iter"
 NUM_ENV_STEPS_TRAINED = "num_env_steps_trained"
 NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
+NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"
+NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
 
 # Counters to track target network updates.
 LAST_TARGET_UPDATE_TS = "last_target_update_ts"
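Not part of the patch: the four new `*_THIS_ITER` counters above complement the existing lifetime counters. A hypothetical helper, assuming these counters appear under the same key names in a `Trainer.train()` result dict:

```python
# Sketch only; assumes train() results expose the counters under these key names.
from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_THIS_ITER,
)


def env_steps_sampled(result: dict):
    """Return (lifetime, this-iteration) env steps sampled from a train() result."""
    return result.get(NUM_ENV_STEPS_SAMPLED), result.get(NUM_ENV_STEPS_SAMPLED_THIS_ITER)
```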