[RLlib] TD3 config objects. (#25065)

Sven Mika 2022-05-23 10:07:13 +02:00 committed by GitHub
parent 09886d7ab8
commit baf8c2fa1e
6 changed files with 133 additions and 63 deletions


@@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import (
     EnvConfigDict,
@@ -754,7 +755,16 @@ class TrainerConfig:
         if explore is not None:
             self.explore = explore
         if exploration_config is not None:
-            self.exploration_config = exploration_config
+            # Override entire `exploration_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_exploration_config = deep_update(
+                {"exploration_config": self.exploration_config},
+                {"exploration_config": exploration_config},
+                False,
+                ["exploration_config"],
+                ["exploration_config"],
+            )
+            self.exploration_config = new_exploration_config["exploration_config"]

         return self
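Usage sketch (not part of the diff), illustrating the new merge semantics of `exploration_config` described by the comments above. It assumes the `TD3Config` class introduced later in this PR and the `deep_update` behavior as called here:

    from ray.rllib.algorithms.ddpg.td3 import TD3Config

    config = TD3Config()
    # Partial update: no "type" given, so the remaining GaussianNoise settings
    # (stddev, scales, ...) are kept and only this key is merged in.
    config.exploration(exploration_config={"random_timesteps": 1000})
    assert config.exploration_config["type"] == "GaussianNoise"

    # Changing the "type" key overrides the entire dict instead of merging it,
    # so any settings the new exploration class needs must be re-supplied.
    config.exploration(exploration_config={"type": "OrnsteinUhlenbeckNoise"})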


@@ -1,6 +1,6 @@
 from ray.rllib.algorithms.ddpg.apex import ApexDDPGTrainer
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer, DEFAULT_CONFIG
-from ray.rllib.algorithms.ddpg.td3 import TD3Trainer
+from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3Trainer

 __all__ = [
@@ -8,5 +8,6 @@ __all__ = [
     "DDPGConfig",
     "DDPGTrainer",
     "DEFAULT_CONFIG",
+    "TD3Config",
     "TD3Trainer",
 ]
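With this re-export in place, the new config class can be imported from the package directly (a usage sketch, not part of the diff):

    from ray.rllib.algorithms.ddpg import TD3Config  # now re-exported

    config = TD3Config()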


@@ -21,7 +21,7 @@ class DDPGConfig(SimpleQConfig):
         >>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1)
         >>> print(config.to_dict())
         >>> # Build a Trainer object from the config and run one training iteration.
-        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer = config.build(env="Pendulum-v1")
         >>> trainer.train()

     Example:
@@ -34,7 +34,7 @@ class DDPGConfig(SimpleQConfig):
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
-        >>> config.environment(env="CartPole-v1")
+        >>> config.environment(env="Pendulum-v1")
         >>> # Use to_dict() to get the old-style python config dict
         >>> # when running with tune.
         >>> tune.run(


@@ -5,20 +5,79 @@ TD3 paper.
 """
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE

-TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
-    DDPGConfig().to_dict(),
-    {
-        # largest changes: twin Q functions, delayed policy updates, and target
-        # smoothing
-        "twin_q": True,
-        "policy_delay": 2,
-        "smooth_target_policy": True,
-        "target_noise": 0.2,
-        "target_noise_clip": 0.5,
-        "exploration_config": {
+
+class TD3Config(DDPGConfig):
+    """Defines a configuration class from which a TD3Trainer can be built.
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run one training iteration.
+        >>> trainer = config.build(env="Pendulum-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> from ray import tune
+        >>> config = TD3Config()
+        >>> # Print out some default values.
+        >>> print(config.lr)
+        0.0004
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="Pendulum-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "TD3",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a TD3Config instance."""
+        super().__init__(trainer_class=trainer_class or TD3Trainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+
+        # Override some of DDPG/SimpleQ/Trainer's default values with TD3-specific
+        # values.
+
+        # .training()
+        # largest changes: twin Q functions, delayed policy updates, target
+        # smoothing, no l2-regularization.
+        self.twin_q = True
+        self.policy_delay = 2
+        self.smooth_target_policy = True
+        self.l2_reg = 0.0
+        # Different tau (affecting target network update).
+        self.tau = 5e-3
+        # Different batch size.
+        self.train_batch_size = 100
+        # No prioritized replay by default (we may want to change this at some
+        # point).
+        self.replay_buffer_config = {
+            "type": "MultiAgentReplayBuffer",
+            # Specify prioritized replay by supplying a buffer type that supports
+            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
+            "prioritized_replay": DEPRECATED_VALUE,
+            "capacity": 1000000,
+            "learning_starts": 10000,
+            "worker_side_prioritization": False,
+        }
+
+        # .exploration()
+        # TD3 uses Gaussian Noise by default.
+        self.exploration_config = {
             # TD3 uses simple Gaussian noise on top of deterministic NN-output
             # actions (after a possible pure random phase of n timesteps).
             "type": "GaussianNoise",
@@ -34,40 +93,30 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
             "initial_scale": 1.0,
             "final_scale": 1.0,
             "scale_timesteps": 1,
-        },
-        # other changes & things we want to keep fixed:
-        # larger actor learning rate, no l2 regularisation, no Huber loss, etc.
-        "actor_hiddens": [400, 300],
-        "critic_hiddens": [400, 300],
-        "n_step": 1,
-        "gamma": 0.99,
-        "actor_lr": 1e-3,
-        "critic_lr": 1e-3,
-        "l2_reg": 0.0,
-        "tau": 5e-3,
-        "train_batch_size": 100,
-        "use_huber": False,
-        # Update the target network every `target_network_update_freq` sample timesteps.
-        "target_network_update_freq": 0,
-        "num_workers": 0,
-        "num_gpus_per_worker": 0,
-        "clip_rewards": False,
-        "use_state_preprocessor": False,
-        "replay_buffer_config": {
-            "type": "MultiAgentReplayBuffer",
-            # Specify prioritized replay by supplying a buffer type that supports
-            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
-            "prioritized_replay": DEPRECATED_VALUE,
-            "capacity": 1000000,
-            "learning_starts": 10000,
-            "worker_side_prioritization": False,
-        },
-    },
-)
+        }
+        # __sphinx_doc_end__
+        # fmt: on


 class TD3Trainer(DDPGTrainer):
     @classmethod
     @override(DDPGTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return TD3_DEFAULT_CONFIG
+        return TD3Config().to_dict()
+
+
+# Deprecated: Use ray.rllib.algorithms.ddpg.td3.TD3Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(TD3Config().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.ddpg.td3.TD3Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+TD3_DEFAULT_CONFIG = _deprecated_default_config()
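Migration sketch (not part of the diff): the old top-level dict remains importable, but item access now goes through the @Deprecated-wrapped __getitem__, which with error=False should log a deprecation warning rather than raise; new code builds the trainer from the config object instead:

    from ray.rllib.algorithms.ddpg.td3 import TD3_DEFAULT_CONFIG, TD3Config

    # Old style: still works, but each key lookup is expected to warn and
    # point users at TD3Config.
    batch_size = TD3_DEFAULT_CONFIG["train_batch_size"]  # 100

    # New style, as advertised in the class docstring above.
    config = TD3Config().training(lr=0.01)
    trainer = config.build(env="Pendulum-v1")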


@@ -25,12 +25,11 @@ class TestTD3(unittest.TestCase):
     def test_td3_compilation(self):
         """Test whether a TD3Trainer can be built with both frameworks."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config()

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = td3.TD3Trainer(config=config, env="Pendulum-v1")
+            trainer = config.build(env="Pendulum-v1")
             num_iterations = 1
             for i in range(num_iterations):
                 results = trainer.train()
@@ -41,15 +40,23 @@ class TestTD3(unittest.TestCase):
     def test_td3_exploration_and_with_random_prerun(self):
         """Tests TD3's Exploration (w/ random actions for n timesteps)."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config().environment(env="Pendulum-v1")
+        no_random_init = config.exploration_config.copy()
+        random_init = {
+            # Act randomly at beginning ...
+            "random_timesteps": 30,
+            # Then act very closely to deterministic actions thereafter.
+            "stddev": 0.001,
+            "initial_scale": 0.001,
+            "final_scale": 0.001,
+        }

         obs = np.array([0.0, 0.1, -0.1])

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            lcl_config = config.copy()
+            config.exploration(exploration_config=no_random_init)
             # Default GaussianNoise setup.
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            trainer = config.build()
             # Setting explore=False should always return the same action.
             a_ = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
@@ -66,15 +73,8 @@ class TestTD3(unittest.TestCase):
             trainer.stop()

             # Check randomness at beginning.
-            lcl_config["exploration_config"] = {
-                # Act randomly at beginning ...
-                "random_timesteps": 30,
-                # Then act very closely to deterministic actions thereafter.
-                "stddev": 0.001,
-                "initial_scale": 0.001,
-                "final_scale": 0.001,
-            }
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            config.exploration(exploration_config=random_init)
+            trainer = config.build()

             # ts=0 (get a deterministic action as per explore=False).
             deterministic_action = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)


@@ -29,6 +29,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
@@ -238,7 +239,16 @@ class SimpleQConfig(TrainerConfig):
         if target_network_update_freq is not None:
             self.target_network_update_freq = target_network_update_freq
         if replay_buffer_config is not None:
-            self.replay_buffer_config = replay_buffer_config
+            # Override entire `replay_buffer_config` if `type` key changes.
+            # Update, if `type` key remains the same or is not specified.
+            new_replay_buffer_config = deep_update(
+                {"replay_buffer_config": self.replay_buffer_config},
+                {"replay_buffer_config": replay_buffer_config},
+                False,
+                ["replay_buffer_config"],
+                ["replay_buffer_config"],
+            )
+            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
         if store_buffer_in_checkpoints is not None:
             self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
         if lr_schedule is not None:
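The same merge semantics now apply to `replay_buffer_config` in SimpleQConfig.training(). A usage sketch (not part of the diff), assuming the kwarg is forwarded unchanged through the DDPG/TD3 config subclasses:

    from ray.rllib.algorithms.ddpg.td3 import TD3Config

    config = TD3Config()
    # Same "type": only "capacity" is updated; "learning_starts" etc. are kept.
    config.training(replay_buffer_config={"capacity": 50000})

    # Different "type": the whole dict is replaced, so pass along every key the
    # prioritized buffer setup needs.
    config.training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 50000,
            "learning_starts": 10000,
        }
    )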