Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] TD3 config objects. (#25065)
This commit is contained in: parent 09886d7ab8, commit baf8c2fa1e
6 changed files with 133 additions and 63 deletions
@@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import (
     EnvConfigDict,
@@ -754,7 +755,16 @@ class TrainerConfig:
         if explore is not None:
             self.explore = explore
         if exploration_config is not None:
-            self.exploration_config = exploration_config
+            # Override entire `exploration_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_exploration_config = deep_update(
+                {"exploration_config": self.exploration_config},
+                {"exploration_config": exploration_config},
+                False,
+                ["exploration_config"],
+                ["exploration_config"],
+            )
+            self.exploration_config = new_exploration_config["exploration_config"]

         return self

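For reference, the merge semantics introduced here can be sketched as follows (a minimal example, assuming `ray.rllib.utils.deep_update` keeps the positional arguments used above: new_keys_allowed, allow_new_subkey_list, override_all_if_type_changes):

from ray.rllib.utils import deep_update

# Partial update: no "type" key in the new dict -> existing keys are kept,
# only the keys that are passed in get overridden.
base = {"exploration_config": {"type": "GaussianNoise", "stddev": 0.1}}
deep_update(
    base,
    {"exploration_config": {"stddev": 0.5}},
    False,                   # new_keys_allowed
    ["exploration_config"],  # allow_new_subkey_list
    ["exploration_config"],  # override_all_if_type_changes
)
# base["exploration_config"] is now {"type": "GaussianNoise", "stddev": 0.5}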
@@ -1,6 +1,6 @@
 from ray.rllib.algorithms.ddpg.apex import ApexDDPGTrainer
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer, DEFAULT_CONFIG
-from ray.rllib.algorithms.ddpg.td3 import TD3Trainer
+from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3Trainer


 __all__ = [
@@ -8,5 +8,6 @@ __all__ = [
     "DDPGConfig",
     "DDPGTrainer",
     "DEFAULT_CONFIG",
+    "TD3Config",
     "TD3Trainer",
 ]
@@ -21,7 +21,7 @@ class DDPGConfig(SimpleQConfig):
         >>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1)
         >>> print(config.to_dict())
         >>> # Build a Trainer object from the config and run one training iteration.
-        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer = config.build(env="Pendulum-v1")
         >>> trainer.train()

     Example:
@@ -34,7 +34,7 @@ class DDPGConfig(SimpleQConfig):
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
-        >>> config.environment(env="CartPole-v1")
+        >>> config.environment(env="Pendulum-v1")
         >>> # Use to_dict() to get the old-style python config dict
         >>> # when running with tune.
         >>> tune.run(
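The docstring examples switch from CartPole-v1 to Pendulum-v1 because DDPG (and TD3) act in continuous action spaces, while CartPole's action space is discrete. A quick check (assuming a gym version that ships Pendulum-v1, i.e. gym >= 0.21):

import gym

print(gym.make("Pendulum-v1").action_space)  # Box(...): continuous, usable with DDPG/TD3
print(gym.make("CartPole-v1").action_space)  # Discrete(2): not supported by DDPG/TD3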
@@ -5,20 +5,79 @@ TD3 paper.
 """
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE

-TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
-    DDPGConfig().to_dict(),
-    {
-        # largest changes: twin Q functions, delayed policy updates, and target
-        # smoothing
-        "twin_q": True,
-        "policy_delay": 2,
-        "smooth_target_policy": True,
-        "target_noise": 0.2,
-        "target_noise_clip": 0.5,
-        "exploration_config": {
+
+class TD3Config(DDPGConfig):
+    """Defines a configuration class from which a TD3Trainer can be built.
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run one training iteration.
+        >>> trainer = config.build(env="Pendulum-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> from ray import tune
+        >>> config = TD3Config()
+        >>> # Print out some default values.
+        >>> print(config.lr)
+        0.0004
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="Pendulum-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "TD3",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a TD3Config instance."""
+        super().__init__(trainer_class=trainer_class or TD3Trainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+
+        # Override some of DDPG/SimpleQ/Trainer's default values with TD3-specific
+        # values.
+
+        # .training()
+
+        # largest changes: twin Q functions, delayed policy updates, target
+        # smoothing, no l2-regularization.
+        self.twin_q = True
+        self.policy_delay = 2
+        self.smooth_target_policy = True
+        self.l2_reg = 0.0
+        # Different tau (affecting target network update).
+        self.tau = 5e-3
+        # Different batch size.
+        self.train_batch_size = 100
+        # No prioritized replay by default (we may want to change this at some
+        # point).
+        self.replay_buffer_config = {
+            "type": "MultiAgentReplayBuffer",
+            # Specify prioritized replay by supplying a buffer type that supports
+            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
+            "prioritized_replay": DEPRECATED_VALUE,
+            "capacity": 1000000,
+            "learning_starts": 10000,
+            "worker_side_prioritization": False,
+        }
+
+        # .exploration()
+        # TD3 uses Gaussian Noise by default.
+        self.exploration_config = {
             # TD3 uses simple Gaussian noise on top of deterministic NN-output
             # actions (after a possible pure random phase of n timesteps).
             "type": "GaussianNoise",
@@ -34,40 +93,30 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
             "initial_scale": 1.0,
             "final_scale": 1.0,
             "scale_timesteps": 1,
-        },
-        # other changes & things we want to keep fixed:
-        # larger actor learning rate, no l2 regularisation, no Huber loss, etc.
-        "actor_hiddens": [400, 300],
-        "critic_hiddens": [400, 300],
-        "n_step": 1,
-        "gamma": 0.99,
-        "actor_lr": 1e-3,
-        "critic_lr": 1e-3,
-        "l2_reg": 0.0,
-        "tau": 5e-3,
-        "train_batch_size": 100,
-        "use_huber": False,
-        # Update the target network every `target_network_update_freq` sample timesteps.
-        "target_network_update_freq": 0,
-        "num_workers": 0,
-        "num_gpus_per_worker": 0,
-        "clip_rewards": False,
-        "use_state_preprocessor": False,
-        "replay_buffer_config": {
-            "type": "MultiAgentReplayBuffer",
-            # Specify prioritized replay by supplying a buffer type that supports
-            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
-            "prioritized_replay": DEPRECATED_VALUE,
-            "capacity": 1000000,
-            "learning_starts": 10000,
-            "worker_side_prioritization": False,
-        },
-    },
-)
+        }
+        # __sphinx_doc_end__
+        # fmt: on


 class TD3Trainer(DDPGTrainer):
     @classmethod
     @override(DDPGTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return TD3_DEFAULT_CONFIG
+        return TD3Config().to_dict()
+
+
+# Deprecated: Use ray.rllib.algorithms.ddpg.td3.TD3Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(TD3Config().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.ddpg.td3.TD3Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+TD3_DEFAULT_CONFIG = _deprecated_default_config()
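A minimal usage sketch of the new config class and the deprecation wrapper above (assuming ray[rllib] at this commit, running a single training iteration on a CPU-only machine):

from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3_DEFAULT_CONFIG

config = TD3Config().resources(num_gpus=0).environment(env="Pendulum-v1")
trainer = config.build()
results = trainer.train()  # one training iteration
trainer.stop()

# The old dict-style entry point still works, but reading a key from it now
# logs a deprecation warning (error=False) pointing to TD3Config.
assert TD3_DEFAULT_CONFIG["policy_delay"] == 2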
@@ -25,12 +25,11 @@ class TestTD3(unittest.TestCase):

     def test_td3_compilation(self):
         """Test whether a TD3Trainer can be built with both frameworks."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config()

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = td3.TD3Trainer(config=config, env="Pendulum-v1")
+            trainer = config.build(env="Pendulum-v1")
             num_iterations = 1
             for i in range(num_iterations):
                 results = trainer.train()
@@ -41,15 +40,23 @@ class TestTD3(unittest.TestCase):

     def test_td3_exploration_and_with_random_prerun(self):
         """Tests TD3's Exploration (w/ random actions for n timesteps)."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config().environment(env="Pendulum-v1")
+        no_random_init = config.exploration_config.copy()
+        random_init = {
+            # Act randomly at beginning ...
+            "random_timesteps": 30,
+            # Then act very closely to deterministic actions thereafter.
+            "stddev": 0.001,
+            "initial_scale": 0.001,
+            "final_scale": 0.001,
+        }
         obs = np.array([0.0, 0.1, -0.1])

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            lcl_config = config.copy()
+            config.exploration(exploration_config=no_random_init)
             # Default GaussianNoise setup.
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            trainer = config.build()
             # Setting explore=False should always return the same action.
             a_ = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
@@ -66,15 +73,8 @@ class TestTD3(unittest.TestCase):
             trainer.stop()

             # Check randomness at beginning.
-            lcl_config["exploration_config"] = {
-                # Act randomly at beginning ...
-                "random_timesteps": 30,
-                # Then act very closely to deterministic actions thereafter.
-                "stddev": 0.001,
-                "initial_scale": 0.001,
-                "final_scale": 0.001,
-            }
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            config.exploration(exploration_config=random_init)
+            trainer = config.build()
             # ts=0 (get a deterministic action as per explore=False).
             deterministic_action = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
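The updated tests rely on the new merge behavior of exploration(exploration_config=...): a partial dict that keeps (or omits) the `type` key is merged into the default GaussianNoise setup rather than replacing it. A small sketch of that pattern (assuming ray[rllib] at this commit):

from ray.rllib.algorithms.ddpg.td3 import TD3Config

config = TD3Config().environment(env="Pendulum-v1")
# No "type" given -> merged into the default GaussianNoise exploration config.
config.exploration(exploration_config={"random_timesteps": 30, "stddev": 0.001})
assert config.exploration_config["type"] == "GaussianNoise"
assert config.exploration_config["random_timesteps"] == 30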
@@ -29,6 +29,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
@@ -238,7 +239,16 @@ class SimpleQConfig(TrainerConfig):
         if target_network_update_freq is not None:
             self.target_network_update_freq = target_network_update_freq
         if replay_buffer_config is not None:
-            self.replay_buffer_config = replay_buffer_config
+            # Override entire `replay_buffer_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_replay_buffer_config = deep_update(
+                {"replay_buffer_config": self.replay_buffer_config},
+                {"replay_buffer_config": replay_buffer_config},
+                False,
+                ["replay_buffer_config"],
+                ["replay_buffer_config"],
+            )
+            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
         if store_buffer_in_checkpoints is not None:
             self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
         if lr_schedule is not None:
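The same deep_update call also gives replay_buffer_config the "override on type change" behavior described in the comment: if the new dict names a different buffer `type`, the old sub-dict is dropped entirely instead of merged. A minimal sketch (same assumptions about deep_update's positional signature as above):

from ray.rllib.utils import deep_update

base = {
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": 1000000,
        "learning_starts": 10000,
    }
}
deep_update(
    base,
    {"replay_buffer_config": {"type": "MultiAgentPrioritizedReplayBuffer", "capacity": 50000}},
    False,
    ["replay_buffer_config"],
    ["replay_buffer_config"],
)
# The `type` changed, so the whole sub-dict was replaced:
# base["replay_buffer_config"] == {"type": "MultiAgentPrioritizedReplayBuffer", "capacity": 50000}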