Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] TD3 config objects. (#25065)
Commit baf8c2fa1e (parent 09886d7ab8)
6 changed files with 133 additions and 63 deletions
@@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import (
     EnvConfigDict,
@@ -754,7 +755,16 @@ class TrainerConfig:
         if explore is not None:
             self.explore = explore
         if exploration_config is not None:
-            self.exploration_config = exploration_config
+            # Override entire `exploration_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_exploration_config = deep_update(
+                {"exploration_config": self.exploration_config},
+                {"exploration_config": exploration_config},
+                False,
+                ["exploration_config"],
+                ["exploration_config"],
+            )
+            self.exploration_config = new_exploration_config["exploration_config"]

         return self

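The hunk above changes TrainerConfig's exploration() setter so that a partially specified exploration_config is merged into the existing one via deep_update; only a change of the "type" key replaces the sub-dict wholesale. A minimal sketch of that behavior, calling deep_update with the same positional arguments as the new code (the concrete noise values are illustrative, not taken from the commit):

import copy

from ray.rllib.utils import deep_update

base = {
    "exploration_config": {
        "type": "GaussianNoise",
        "random_timesteps": 10000,
        "stddev": 0.1,
    }
}

# Same (or unspecified) "type": the partial update is deep-merged, other keys survive.
merged = deep_update(
    copy.deepcopy(base),
    {"exploration_config": {"stddev": 0.5}},
    False,
    ["exploration_config"],
    ["exploration_config"],
)
assert merged["exploration_config"]["random_timesteps"] == 10000

# Different "type": the entire exploration_config is overridden.
replaced = deep_update(
    copy.deepcopy(base),
    {"exploration_config": {"type": "EpsilonGreedy"}},
    False,
    ["exploration_config"],
    ["exploration_config"],
)
assert "random_timesteps" not in replaced["exploration_config"]
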
@@ -1,6 +1,6 @@
 from ray.rllib.algorithms.ddpg.apex import ApexDDPGTrainer
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer, DEFAULT_CONFIG
-from ray.rllib.algorithms.ddpg.td3 import TD3Trainer
+from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3Trainer


 __all__ = [
@@ -8,5 +8,6 @@ __all__ = [
     "DDPGConfig",
     "DDPGTrainer",
     "DEFAULT_CONFIG",
+    "TD3Config",
     "TD3Trainer",
 ]

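With TD3Config re-exported in __all__, it can now be imported from the ddpg package directly. A quick sketch (requires a working RLlib installation at the state of this commit):

from ray.rllib.algorithms.ddpg import TD3Config, TD3Trainer

config = TD3Config().environment(env="Pendulum-v1")
trainer = config.build()
assert isinstance(trainer, TD3Trainer)
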
@@ -21,7 +21,7 @@ class DDPGConfig(SimpleQConfig):
         >>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1)
         >>> print(config.to_dict())
         >>> # Build a Trainer object from the config and run one training iteration.
-        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer = config.build(env="Pendulum-v1")
         >>> trainer.train()

     Example:
@@ -34,7 +34,7 @@ class DDPGConfig(SimpleQConfig):
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
-        >>> config.environment(env="CartPole-v1")
+        >>> config.environment(env="Pendulum-v1")
         >>> # Use to_dict() to get the old-style python config dict
         >>> # when running with tune.
         >>> tune.run(

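The doctests switch from CartPole-v1 to Pendulum-v1 because DDPG (and TD3) act in continuous action spaces, while CartPole's action space is discrete. A quick illustration (assumes a gym version that registers the -v1 environment IDs, as the doctests here do):

import gym

print(gym.make("Pendulum-v1").action_space)  # continuous Box space: usable with DDPG/TD3
print(gym.make("CartPole-v1").action_space)  # Discrete(2): not usable with DDPG/TD3
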
@@ -5,20 +5,79 @@ TD3 paper.
 """
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE

-TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
-    DDPGConfig().to_dict(),
-    {
-        # largest changes: twin Q functions, delayed policy updates, and target
-        # smoothing
-        "twin_q": True,
-        "policy_delay": 2,
-        "smooth_target_policy": True,
-        "target_noise": 0.2,
-        "target_noise_clip": 0.5,
-        "exploration_config": {
+class TD3Config(DDPGConfig):
+    """Defines a configuration class from which a TD3Trainer can be built.
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run one training iteration.
+        >>> trainer = config.build(env="Pendulum-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> from ray import tune
+        >>> config = TD3Config()
+        >>> # Print out some default values.
+        >>> print(config.lr)
+        0.0004
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="Pendulum-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "TD3",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a TD3Config instance."""
+        super().__init__(trainer_class=trainer_class or TD3Trainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+
+        # Override some of DDPG/SimpleQ/Trainer's default values with TD3-specific
+        # values.
+
+        # .training()
+
+        # largest changes: twin Q functions, delayed policy updates, target
+        # smoothing, no l2-regularization.
+        self.twin_q = True
+        self.policy_delay = 2
+        self.smooth_target_policy = True,
+        self.l2_reg = 0.0
+        # Different tau (affecting target network update).
+        self.tau = 5e-3
+        # Different batch size.
+        self.train_batch_size = 100
+        # No prioritized replay by default (we may want to change this at some
+        # point).
+        self.replay_buffer_config = {
+            "type": "MultiAgentReplayBuffer",
+            # Specify prioritized replay by supplying a buffer type that supports
+            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
+            "prioritized_replay": DEPRECATED_VALUE,
+            "capacity": 1000000,
+            "learning_starts": 10000,
+            "worker_side_prioritization": False,
+        }
+
+        # .exploration()
+        # TD3 uses Gaussian Noise by default.
+        self.exploration_config = {
             # TD3 uses simple Gaussian noise on top of deterministic NN-output
             # actions (after a possible pure random phase of n timesteps).
             "type": "GaussianNoise",
@@ -34,40 +93,30 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
             "initial_scale": 1.0,
             "final_scale": 1.0,
             "scale_timesteps": 1,
-        },
-        # other changes & things we want to keep fixed:
-        # larger actor learning rate, no l2 regularisation, no Huber loss, etc.
-        "actor_hiddens": [400, 300],
-        "critic_hiddens": [400, 300],
-        "n_step": 1,
-        "gamma": 0.99,
-        "actor_lr": 1e-3,
-        "critic_lr": 1e-3,
-        "l2_reg": 0.0,
-        "tau": 5e-3,
-        "train_batch_size": 100,
-        "use_huber": False,
-        # Update the target network every `target_network_update_freq` sample timesteps.
-        "target_network_update_freq": 0,
-        "num_workers": 0,
-        "num_gpus_per_worker": 0,
-        "clip_rewards": False,
-        "use_state_preprocessor": False,
-        "replay_buffer_config": {
-            "type": "MultiAgentReplayBuffer",
-            # Specify prioritized replay by supplying a buffer type that supports
-            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
-            "prioritized_replay": DEPRECATED_VALUE,
-            "capacity": 1000000,
-            "learning_starts": 10000,
-            "worker_side_prioritization": False,
-        },
-    },
-)
+        }
+        # __sphinx_doc_end__
+        # fmt: on


 class TD3Trainer(DDPGTrainer):
     @classmethod
     @override(DDPGTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return TD3_DEFAULT_CONFIG
+        return TD3Config().to_dict()
+
+
+# Deprecated: Use ray.rllib.algorithms.ddpg..td3.TD3Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(TD3Config().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.ddpg.td3.TD3Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+TD3_DEFAULT_CONFIG = _deprecated_default_config()

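Taken together, the new module replaces the merged TD3_DEFAULT_CONFIG dict with a TD3Config class whose constructor bakes in TD3's deviations from DDPG (twin Q, policy delay, target smoothing, no l2 regularization). A usage sketch of the replacement API (the lr override is illustrative; only methods that appear in this commit's own docstrings are used):

from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3_DEFAULT_CONFIG

config = (
    TD3Config()
    .environment(env="Pendulum-v1")
    .training(lr=0.001)  # illustrative override; TD3's other defaults stay in place
    .resources(num_gpus=0)
)
trainer = config.build()
print(trainer.train())

# The old dict still exists, but it is now the deprecation wrapper defined above:
# indexing into it returns the same values and logs a warning pointing to TD3Config.
print(TD3_DEFAULT_CONFIG["train_batch_size"])  # -> 100
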
@@ -25,12 +25,11 @@ class TestTD3(unittest.TestCase):

     def test_td3_compilation(self):
         """Test whether a TD3Trainer can be built with both frameworks."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config()

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = td3.TD3Trainer(config=config, env="Pendulum-v1")
+            trainer = config.build(env="Pendulum-v1")
             num_iterations = 1
             for i in range(num_iterations):
                 results = trainer.train()
@@ -41,15 +40,23 @@ class TestTD3(unittest.TestCase):

     def test_td3_exploration_and_with_random_prerun(self):
         """Tests TD3's Exploration (w/ random actions for n timesteps)."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config().environment(env="Pendulum-v1")
+        no_random_init = config.exploration_config.copy()
+        random_init = {
+            # Act randomly at beginning ...
+            "random_timesteps": 30,
+            # Then act very closely to deterministic actions thereafter.
+            "stddev": 0.001,
+            "initial_scale": 0.001,
+            "final_scale": 0.001,
+        }
         obs = np.array([0.0, 0.1, -0.1])

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            lcl_config = config.copy()
+            config.exploration(exploration_config=no_random_init)
             # Default GaussianNoise setup.
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            trainer = config.build()
             # Setting explore=False should always return the same action.
             a_ = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
@@ -66,15 +73,8 @@ class TestTD3(unittest.TestCase):
             trainer.stop()

             # Check randomness at beginning.
-            lcl_config["exploration_config"] = {
-                # Act randomly at beginning ...
-                "random_timesteps": 30,
-                # Then act very closely to deterministic actions thereafter.
-                "stddev": 0.001,
-                "initial_scale": 0.001,
-                "final_scale": 0.001,
-            }
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            config.exploration(exploration_config=random_init)
+            trainer = config.build()
             # ts=0 (get a deterministic action as per explore=False).
             deterministic_action = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)

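Outside the unittest harness, the determinism check the rewritten test relies on looks roughly like this sketch (check comes from RLlib's test utilities, as in the test file itself):

import numpy as np

from ray.rllib.algorithms.ddpg.td3 import TD3Config
from ray.rllib.utils.test_utils import check

config = TD3Config().environment(env="Pendulum-v1")
trainer = config.build()
obs = np.array([0.0, 0.1, -0.1])

# explore=False must be deterministic: repeated calls yield the identical action.
check(
    trainer.compute_single_action(obs, explore=False),
    trainer.compute_single_action(obs, explore=False),
)

# explore=True samples from the GaussianNoise exploration, so actions may differ.
noisy_actions = [trainer.compute_single_action(obs, explore=True) for _ in range(5)]
trainer.stop()
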
@@ -29,6 +29,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
@@ -238,7 +239,16 @@ class SimpleQConfig(TrainerConfig):
         if target_network_update_freq is not None:
             self.target_network_update_freq = target_network_update_freq
         if replay_buffer_config is not None:
-            self.replay_buffer_config = replay_buffer_config
+            # Override entire `replay_buffer_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_replay_buffer_config = deep_update(
+                {"replay_buffer_config": self.replay_buffer_config},
+                {"replay_buffer_config": replay_buffer_config},
+                False,
+                ["replay_buffer_config"],
+                ["replay_buffer_config"],
+            )
+            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
         if store_buffer_in_checkpoints is not None:
             self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
         if lr_schedule is not None:

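The practical effect, sketched with the TD3 config from this PR (this assumes, as elsewhere in the new config classes, that the subclass training() overrides forward replay_buffer_config on to SimpleQConfig.training()): a partial replay_buffer_config now merges into the existing one instead of replacing it, unless a different "type" is supplied.

from ray.rllib.algorithms.ddpg.td3 import TD3Config

config = TD3Config()

# Only "capacity" is given; "type", "learning_starts", etc. keep TD3's defaults.
config.training(replay_buffer_config={"capacity": 50000})
assert config.replay_buffer_config["type"] == "MultiAgentReplayBuffer"
assert config.replay_buffer_config["learning_starts"] == 10000

# Supplying a different "type" overrides the entire replay_buffer_config.
config.training(
    replay_buffer_config={
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 50000,
    }
)
assert "learning_starts" not in config.replay_buffer_config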