From baf8c2fa1eb3f52d5a8b174d1d19e2bb012397b0 Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Mon, 23 May 2022 10:07:13 +0200
Subject: [PATCH] [RLlib] TD3 config objects. (#25065)

---
 rllib/agents/trainer_config.py          |  12 ++-
 rllib/algorithms/ddpg/__init__.py       |   3 +-
 rllib/algorithms/ddpg/ddpg.py           |   4 +-
 rllib/algorithms/ddpg/td3.py            | 133 ++++++++++++++++--------
 rllib/algorithms/ddpg/tests/test_td3.py |  32 +++---
 rllib/algorithms/dqn/simple_q.py        |  12 ++-
 6 files changed, 133 insertions(+), 63 deletions(-)

diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py
index 3fa1618b2..492cd06dc 100644
--- a/rllib/agents/trainer_config.py
+++ b/rllib/agents/trainer_config.py
@@ -19,6 +19,7 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
 from ray.rllib.offline.estimators.weighted_importance_sampling import (
     WeightedImportanceSampling,
 )
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import (
     EnvConfigDict,
@@ -754,7 +755,16 @@ class TrainerConfig:
         if explore is not None:
             self.explore = explore
         if exploration_config is not None:
-            self.exploration_config = exploration_config
+            # Override the entire `exploration_config` if the `type` key changes.
+            # Update (deep-merge), if the `type` key remains the same or is not specified.
+            new_exploration_config = deep_update(
+                {"exploration_config": self.exploration_config},
+                {"exploration_config": exploration_config},
+                False,
+                ["exploration_config"],
+                ["exploration_config"],
+            )
+            self.exploration_config = new_exploration_config["exploration_config"]
         return self

diff --git a/rllib/algorithms/ddpg/__init__.py b/rllib/algorithms/ddpg/__init__.py
index 99b0e36d7..ec473cb41 100644
--- a/rllib/algorithms/ddpg/__init__.py
+++ b/rllib/algorithms/ddpg/__init__.py
@@ -1,6 +1,6 @@
 from ray.rllib.algorithms.ddpg.apex import ApexDDPGTrainer
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer, DEFAULT_CONFIG
-from ray.rllib.algorithms.ddpg.td3 import TD3Trainer
+from ray.rllib.algorithms.ddpg.td3 import TD3Config, TD3Trainer


 __all__ = [
@@ -8,5 +8,6 @@ __all__ = [
     "DDPGConfig",
     "DDPGTrainer",
     "DEFAULT_CONFIG",
+    "TD3Config",
     "TD3Trainer",
 ]
diff --git a/rllib/algorithms/ddpg/ddpg.py b/rllib/algorithms/ddpg/ddpg.py
index caf55902f..8ad2d7b42 100644
--- a/rllib/algorithms/ddpg/ddpg.py
+++ b/rllib/algorithms/ddpg/ddpg.py
@@ -21,7 +21,7 @@ class DDPGConfig(SimpleQConfig):
         >>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1)
         >>> print(config.to_dict())
         >>> # Build a Trainer object from the config and run one training iteration.
-        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer = config.build(env="Pendulum-v1")
         >>> trainer.train()

     Example:
@@ -34,7 +34,7 @@ class DDPGConfig(SimpleQConfig):
         >>> # Update the config object.
         >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
-        >>> config.environment(env="CartPole-v1")
+        >>> config.environment(env="Pendulum-v1")
         >>> # Use to_dict() to get the old-style python config dict
         >>> # when running with tune.
         >>> tune.run(
diff --git a/rllib/algorithms/ddpg/td3.py b/rllib/algorithms/ddpg/td3.py
index 37dfb317e..c86ad419c 100644
--- a/rllib/algorithms/ddpg/td3.py
+++ b/rllib/algorithms/ddpg/td3.py
@@ -5,20 +5,79 @@ TD3 paper.
 """
 from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE

-TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
-    DDPGConfig().to_dict(),
-    {
-        # largest changes: twin Q functions, delayed policy updates, and target
-        # smoothing
-        "twin_q": True,
-        "policy_delay": 2,
-        "smooth_target_policy": True,
-        "target_noise": 0.2,
-        "target_noise_clip": 0.5,
-        "exploration_config": {
+
+class TD3Config(DDPGConfig):
+    """Defines a configuration class from which a TD3Trainer can be built.
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run one training iteration.
+        >>> trainer = config.build(env="Pendulum-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
+        >>> from ray import tune
+        >>> config = TD3Config()
+        >>> # Print out some default values.
+        >>> print(config.lr)
+        0.0004
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="Pendulum-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "TD3",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a TD3Config instance."""
+        super().__init__(trainer_class=trainer_class or TD3Trainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+
+        # Override some of DDPG/SimpleQ/Trainer's default values with TD3-specific
+        # values.
+
+        # .training()
+
+        # Largest changes: twin Q functions, delayed policy updates, target
+        # smoothing, no l2-regularization.
+        self.twin_q = True
+        self.policy_delay = 2
+        self.smooth_target_policy = True
+        self.l2_reg = 0.0
+        # Different tau (affecting target network update).
+        self.tau = 5e-3
+        # Different batch size.
+        self.train_batch_size = 100
+        # No prioritized replay by default (we may want to change this at some
+        # point).
+        self.replay_buffer_config = {
+            "type": "MultiAgentReplayBuffer",
+            # Specify prioritized replay by supplying a buffer type that supports
+            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
+            "prioritized_replay": DEPRECATED_VALUE,
+            "capacity": 1000000,
+            "learning_starts": 10000,
+            "worker_side_prioritization": False,
+        }
+
+        # .exploration()
+        # TD3 uses Gaussian Noise by default.
+        self.exploration_config = {
             # TD3 uses simple Gaussian noise on top of deterministic NN-output
             # actions (after a possible pure random phase of n timesteps).
             "type": "GaussianNoise",
@@ -34,40 +93,30 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
             "initial_scale": 1.0,
             "final_scale": 1.0,
             "scale_timesteps": 1,
-        },
-        # other changes & things we want to keep fixed:
-        # larger actor learning rate, no l2 regularisation, no Huber loss, etc.
-        "actor_hiddens": [400, 300],
-        "critic_hiddens": [400, 300],
-        "n_step": 1,
-        "gamma": 0.99,
-        "actor_lr": 1e-3,
-        "critic_lr": 1e-3,
-        "l2_reg": 0.0,
-        "tau": 5e-3,
-        "train_batch_size": 100,
-        "use_huber": False,
-        # Update the target network every `target_network_update_freq` sample timesteps.
-        "target_network_update_freq": 0,
-        "num_workers": 0,
-        "num_gpus_per_worker": 0,
-        "clip_rewards": False,
-        "use_state_preprocessor": False,
-        "replay_buffer_config": {
-            "type": "MultiAgentReplayBuffer",
-            # Specify prioritized replay by supplying a buffer type that supports
-            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
-            "prioritized_replay": DEPRECATED_VALUE,
-            "capacity": 1000000,
-            "learning_starts": 10000,
-            "worker_side_prioritization": False,
-        },
-    },
-)
+        }
+        # __sphinx_doc_end__
+        # fmt: on


 class TD3Trainer(DDPGTrainer):
     @classmethod
     @override(DDPGTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return TD3_DEFAULT_CONFIG
+        return TD3Config().to_dict()
+
+
+# Deprecated: Use ray.rllib.algorithms.ddpg.td3.TD3Config instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(TD3Config().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.ddpg.td3.TD3Config(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+TD3_DEFAULT_CONFIG = _deprecated_default_config()
diff --git a/rllib/algorithms/ddpg/tests/test_td3.py b/rllib/algorithms/ddpg/tests/test_td3.py
index 700ccfcc7..ad1f123d2 100644
--- a/rllib/algorithms/ddpg/tests/test_td3.py
+++ b/rllib/algorithms/ddpg/tests/test_td3.py
@@ -25,12 +25,11 @@ class TestTD3(unittest.TestCase):
     def test_td3_compilation(self):
         """Test whether a TD3Trainer can be built with both frameworks."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config()

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = td3.TD3Trainer(config=config, env="Pendulum-v1")
+            trainer = config.build(env="Pendulum-v1")
             num_iterations = 1
             for i in range(num_iterations):
                 results = trainer.train()
@@ -41,15 +40,23 @@
     def test_td3_exploration_and_with_random_prerun(self):
         """Tests TD3's Exploration (w/ random actions for n timesteps)."""
-        config = td3.TD3_DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
+        config = td3.TD3Config().environment(env="Pendulum-v1")
+        no_random_init = config.exploration_config.copy()
+        random_init = {
+            # Act randomly at beginning ...
+            "random_timesteps": 30,
+            # Then act very closely to deterministic actions thereafter.
+            "stddev": 0.001,
+            "initial_scale": 0.001,
+            "final_scale": 0.001,
+        }
         obs = np.array([0.0, 0.1, -0.1])

         # Test against all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            lcl_config = config.copy()
+            config.exploration(exploration_config=no_random_init)
             # Default GaussianNoise setup.
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            trainer = config.build()
             # Setting explore=False should always return the same action.
             a_ = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
@@ -66,15 +73,8 @@
             trainer.stop()

             # Check randomness at beginning.
-            lcl_config["exploration_config"] = {
-                # Act randomly at beginning ...
-                "random_timesteps": 30,
-                # Then act very closely to deterministic actions thereafter.
-                "stddev": 0.001,
-                "initial_scale": 0.001,
-                "final_scale": 0.001,
-            }
-            trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v1")
+            config.exploration(exploration_config=random_init)
+            trainer = config.build()
             # ts=0 (get a deterministic action as per explore=False).
             deterministic_action = trainer.compute_single_action(obs, explore=False)
             check(trainer.get_policy().global_timestep, 1)
diff --git a/rllib/algorithms/dqn/simple_q.py b/rllib/algorithms/dqn/simple_q.py
index 41d22cbe0..f50920b8d 100644
--- a/rllib/algorithms/dqn/simple_q.py
+++ b/rllib/algorithms/dqn/simple_q.py
@@ -29,6 +29,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils import deep_update
 from ray.rllib.utils.annotations import ExperimentalAPI, override
 from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
@@ -238,7 +239,16 @@ class SimpleQConfig(TrainerConfig):
         if target_network_update_freq is not None:
             self.target_network_update_freq = target_network_update_freq
         if replay_buffer_config is not None:
-            self.replay_buffer_config = replay_buffer_config
+            # Override the entire `replay_buffer_config` if the `type` key changes.
+            # Update (deep-merge), if the `type` key remains the same or is not specified.
+            new_replay_buffer_config = deep_update(
+                {"replay_buffer_config": self.replay_buffer_config},
+                {"replay_buffer_config": replay_buffer_config},
+                False,
+                ["replay_buffer_config"],
+                ["replay_buffer_config"],
+            )
+            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
         if store_buffer_in_checkpoints is not None:
             self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
         if lr_schedule is not None:
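
Usage note (not part of the patch itself): the sketch below shows how the new TD3Config
object introduced above is meant to be used, and how the deep_update-based merging added
to TrainerConfig.exploration() lets callers override a single key of the nested
exploration config without wiping out the remaining defaults, as long as the `type` key
stays the same. It assumes an RLlib build containing this commit and the Pendulum-v1 Gym
environment; the specific override values are illustrative only.

    from ray.rllib.algorithms.ddpg.td3 import TD3Config

    config = (
        TD3Config()
        .environment(env="Pendulum-v1")
        # Only `random_timesteps` is overridden here. Because the `type` key is
        # unchanged (still "GaussianNoise"), the new deep-merge behavior keeps
        # TD3's other exploration defaults (stddev, initial/final scale, ...)
        # instead of replacing the whole dict.
        .exploration(exploration_config={"random_timesteps": 5000})
    )

    trainer = config.build()
    results = trainer.train()
    trainer.stop()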