[RLlib] TD3 config objects. (#25065)

Sven Mika 2022-05-23 10:07:13 +02:00 committed by GitHub
parent 09886d7ab8
commit baf8c2fa1e
6 changed files with 133 additions and 63 deletions

View file

@@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils import deep_update
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
EnvConfigDict,
@@ -754,7 +755,16 @@ class TrainerConfig:
if explore is not None:
self.explore = explore
if exploration_config is not None:
self.exploration_config = exploration_config
# Override the entire `exploration_config` if the `type` key changes.
# Update (merge) only the given keys if `type` stays the same or is unspecified.
new_exploration_config = deep_update(
{"exploration_config": self.exploration_config},
{"exploration_config": exploration_config},
False,
["exploration_config"],
["exploration_config"],
)
self.exploration_config = new_exploration_config["exploration_config"]
return self
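The effect of this type-aware merge can be sketched with the TD3 config class added later in this commit (a minimal sketch, assuming ray[rllib] with this patch applied; the key names and default values are taken from the TD3 defaults shown below):

# --- usage sketch (not part of the diff) ---
from ray.rllib.algorithms.ddpg.td3 import TD3Config

config = TD3Config()

# Same (or unspecified) "type": the partial dict is merged into the existing
# GaussianNoise config; untouched keys keep their defaults.
config.exploration(exploration_config={"final_scale": 0.02})
assert config.exploration_config["type"] == "GaussianNoise"  # kept
assert config.exploration_config["final_scale"] == 0.02      # updated
assert config.exploration_config["initial_scale"] == 1.0     # default kept

# Different "type": the entire exploration_config dict is replaced.
config.exploration(exploration_config={"type": "StochasticSampling"})
assert "initial_scale" not in config.exploration_config
# --- end sketch ---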

View file

@@ -1,6 +1,6 @@
from ray.rllib.algorithms.ddpg.apex import ApexDDPGTrainer
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.ddpg.td3 import TD3Trainer
from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3Trainer
__all__ = [
@@ -8,5 +8,6 @@ __all__ = [
"DDPGConfig",
"DDPGTrainer",
"DEFAULT_CONFIG",
"TD3Config",
"TD3Trainer",
]
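Since TD3Config is now exported from the package __init__, it can be imported at the package level (a small usage sketch, assuming this patch is applied):

# --- usage sketch (not part of the diff) ---
from ray.rllib.algorithms.ddpg import TD3Config, TD3Trainer

config = TD3Config().environment(env="Pendulum-v1")
trainer = config.build()
assert isinstance(trainer, TD3Trainer)
trainer.stop()
# --- end sketch ---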

View file

@@ -21,7 +21,7 @@ class DDPGConfig(SimpleQConfig):
>>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run one training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer = config.build(env="Pendulum-v1")
>>> trainer.train()
Example:
@@ -34,7 +34,7 @@ class DDPGConfig(SimpleQConfig):
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> config.environment(env="Pendulum-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(

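The docstring examples switch from CartPole-v1 to Pendulum-v1 because DDPG-family algorithms need a continuous (Box) action space, which CartPole does not offer. A quick check, assuming gym is installed:

# --- usage sketch (not part of the diff) ---
import gym

print(gym.make("CartPole-v1").action_space)  # Discrete(2) -> not supported by DDPG/TD3
print(gym.make("Pendulum-v1").action_space)  # Box(-2.0, 2.0, (1,), float32) -> supported
# --- end sketch ---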
View file

@@ -5,20 +5,79 @@ TD3 paper.
"""
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
DDPGConfig().to_dict(),
{
# largest changes: twin Q functions, delayed policy updates, and target
# smoothing
"twin_q": True,
"policy_delay": 2,
"smooth_target_policy": True,
"target_noise": 0.2,
"target_noise_clip": 0.5,
"exploration_config": {
class TD3Config(DDPGConfig):
"""Defines a configuration class from which a TD3Trainer can be built.
Example:
>>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
>>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run one training iteration.
>>> trainer = config.build(env="Pendulum-v1")
>>> trainer.train()
Example:
>>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
>>> from ray import tune
>>> config = TD3Config()
>>> # Print out some default values.
>>> print(config.lr)
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config.environment(env="Pendulum-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "TD3",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self, trainer_class=None):
"""Initializes a TD3Config instance."""
super().__init__(trainer_class=trainer_class or TD3Trainer)
# fmt: off
# __sphinx_doc_begin__
# Override some of DDPG/SimpleQ/Trainer's default values with TD3-specific
# values.
# .training()
# largest changes: twin Q functions, delayed policy updates, target
# smoothing, no l2-regularization.
self.twin_q = True
self.policy_delay = 2
self.smooth_target_policy = True
self.l2_reg = 0.0
# Different tau (affecting target network update).
self.tau = 5e-3
# Different batch size.
self.train_batch_size = 100
# No prioritized replay by default (we may want to change this at some
# point).
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
}
# .exploration()
# TD3 uses Gaussian Noise by default.
self.exploration_config = {
# TD3 uses simple Gaussian noise on top of deterministic NN-output
# actions (after a possible pure random phase of n timesteps).
"type": "GaussianNoise",
@@ -34,40 +34,30 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"initial_scale": 1.0,
"final_scale": 1.0,
"scale_timesteps": 1,
},
# other changes & things we want to keep fixed:
# larger actor learning rate, no l2 regularisation, no Huber loss, etc.
"actor_hiddens": [400, 300],
"critic_hiddens": [400, 300],
"n_step": 1,
"gamma": 0.99,
"actor_lr": 1e-3,
"critic_lr": 1e-3,
"l2_reg": 0.0,
"tau": 5e-3,
"train_batch_size": 100,
"use_huber": False,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 0,
"num_workers": 0,
"num_gpus_per_worker": 0,
"clip_rewards": False,
"use_state_preprocessor": False,
"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
},
},
)
}
# __sphinx_doc_end__
# fmt: on
class TD3Trainer(DDPGTrainer):
@classmethod
@override(DDPGTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return TD3_DEFAULT_CONFIG
return TD3Config().to_dict()
# Deprecated: Use ray.rllib.algorithms.ddpg.td3.TD3Config instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(TD3Config().to_dict())
@Deprecated(
old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
new="ray.rllib.algorithms.ddpg.td3.TD3Config(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
TD3_DEFAULT_CONFIG = _deprecated_default_config()
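A short sketch of the new config class next to the now-deprecated dict (assuming ray[rllib] with this patch applied and gym's Pendulum-v1 available):

# --- usage sketch (not part of the diff) ---
from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3_DEFAULT_CONFIG

# New style: configure and build the trainer directly from the config object.
config = TD3Config().training(lr=0.001).environment(env="Pendulum-v1")
trainer = config.build()
result = trainer.train()
trainer.stop()

# Old style still works, but item access logs a deprecation warning.
print(TD3_DEFAULT_CONFIG["twin_q"])  # True
# --- end sketch ---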

View file

@@ -25,12 +25,11 @@ class TestTD3(unittest.TestCase):
def test_td3_compilation(self):
"""Test whether a TD3Trainer can be built with both frameworks."""
config = td3.TD3_DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config = td3.TD3Config()
# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = td3.TD3Trainer(config=config, env="Pendulum-v1")
trainer = config.build(env="Pendulum-v1")
num_iterations = 1
for i in range(num_iterations):
results = trainer.train()
@@ -41,15 +40,23 @@ class TestTD3(unittest.TestCase):
def test_td3_exploration_and_with_random_prerun(self):
"""Tests TD3's Exploration (w/ random actions for n timesteps)."""
config = td3.TD3_DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config = td3.TD3Config().environment(env="Pendulum-v1")
no_random_init = config.exploration_config.copy()
random_init = {
# Act randomly at beginning ...
"random_timesteps": 30,
# Then act very closely to deterministic actions thereafter.
"stddev": 0.001,
"initial_scale": 0.001,
"final_scale": 0.001,
}
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
lcl_config = config.copy()
config.exploration(exploration_config=no_random_init)
# Default GaussianNoise setup.
trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
trainer = config.build()
# Setting explore=False should always return the same action.
a_ = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
@@ -66,15 +73,8 @@ class TestTD3(unittest.TestCase):
trainer.stop()
# Check randomness at beginning.
lcl_config["exploration_config"] = {
# Act randomly at beginning ...
"random_timesteps": 30,
# Then act very closely to deterministic actions thereafter.
"stddev": 0.001,
"initial_scale": 0.001,
"final_scale": 0.001,
}
trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
config.exploration(exploration_config=random_init)
trainer = config.build()
# ts=0 (get a deterministic action as per explore=False).
deterministic_action = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
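The removed config["num_workers"] = 0 line forced local-only sampling; with the config object the equivalent would be a rollouts() call. A sketch under the assumption that the setter and its argument are named rollouts()/num_rollout_workers at this commit:

# --- usage sketch (not part of the diff) ---
from ray.rllib.algorithms.ddpg.td3 import TD3Config

config = TD3Config().environment(env="Pendulum-v1")
# Assumed API: sample only in the local worker, no remote rollout workers.
config.rollouts(num_rollout_workers=0)
trainer = config.build()
# --- end sketch ---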

View file

@@ -29,6 +29,7 @@ from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
@@ -238,7 +239,16 @@ class SimpleQConfig(TrainerConfig):
if target_network_update_freq is not None:
self.target_network_update_freq = target_network_update_freq
if replay_buffer_config is not None:
self.replay_buffer_config = replay_buffer_config
# Override the entire `replay_buffer_config` if the `type` key changes.
# Update (merge) only the given keys if `type` stays the same or is unspecified.
new_replay_buffer_config = deep_update(
{"replay_buffer_config": self.replay_buffer_config},
{"replay_buffer_config": replay_buffer_config},
False,
["replay_buffer_config"],
["replay_buffer_config"],
)
self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
if store_buffer_in_checkpoints is not None:
self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
if lr_schedule is not None:
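replay_buffer_config now gets the same type-aware merge as exploration_config above. A standalone sketch of the underlying helper, using the deep_update utility this hunk imports (buffer values mirror the TD3 defaults shown earlier):

# --- usage sketch (not part of the diff) ---
import copy
from ray.rllib.utils import deep_update

default = {
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": 1000000,
        "learning_starts": 10000,
    }
}

# Same / unspecified "type": only the given keys are overridden.
merged = deep_update(
    copy.deepcopy(default),
    {"replay_buffer_config": {"capacity": 50000}},
    False,
    ["replay_buffer_config"],
    ["replay_buffer_config"],
)
assert merged["replay_buffer_config"]["learning_starts"] == 10000  # kept
assert merged["replay_buffer_config"]["capacity"] == 50000         # updated

# Changed "type": the entire sub-dict is replaced.
replaced = deep_update(
    copy.deepcopy(default),
    {"replay_buffer_config": {"type": "MultiAgentPrioritizedReplayBuffer"}},
    False,
    ["replay_buffer_config"],
    ["replay_buffer_config"],
)
assert replaced["replay_buffer_config"] == {"type": "MultiAgentPrioritizedReplayBuffer"}
# --- end sketch ---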