"""A more stable successor to TD3.
|
|
|
|
By default, this uses a near-identical configuration to that reported in the
|
|
TD3 paper.
|
|
"""

from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE


class TD3Config(DDPGConfig):
    """Defines a configuration class from which a TD3 Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run one training iteration.
        >>> algo = config.build(env="Pendulum-v1")
        >>> algo.train()

    Example:
        >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
        >>> from ray import tune
        >>> config = TD3Config()
        >>> # Print out some default values.
        >>> print(config.lr)
        0.0004
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="Pendulum-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "TD3",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, algo_class=None):
        """Initializes a TD3Config instance."""
        super().__init__(algo_class=algo_class or TD3)

        # fmt: off
        # __sphinx_doc_begin__

        # Override some of DDPG/SimpleQ/Algorithm's default values with TD3-specific
        # values.

        # .training()

        # Largest changes: twin Q functions, delayed policy updates, target
        # smoothing, no l2-regularization.
        self.twin_q = True
        self.policy_delay = 2
        self.smooth_target_policy = True
        self.l2_reg = 0.0
        # Different tau (affecting target network update).
        self.tau = 5e-3
        # Different batch size.
        self.train_batch_size = 100
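        # Illustrative sketch (an assumption, not part of this file): since these
        # are ordinary DDPGConfig training options, a user who wants vanilla-DDPG
        # behavior back could presumably flip them via the fluent setter, e.g.:
        #   config = TD3Config().training(
        #       twin_q=False, policy_delay=1, smooth_target_policy=False
        #   )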
        # No prioritized replay by default (we may want to change this at some
        # point).
        self.replay_buffer_config = {
            "type": "MultiAgentReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            "capacity": 1000000,
            "worker_side_prioritization": False,
        }
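        # Illustrative sketch (hedged; the key names follow RLlib's usual
        # prioritized-replay options and may differ by version): opting into
        # prioritized replay could look like:
        #   config = TD3Config().training(
        #       replay_buffer_config={
        #           "type": "MultiAgentPrioritizedReplayBuffer",
        #           "capacity": 1000000,
        #           "prioritized_replay_alpha": 0.6,
        #           "prioritized_replay_beta": 0.4,
        #       }
        #   )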
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in agent
        # steps or environment steps depends on config["multiagent"]["count_steps_by"].
        self.num_steps_sampled_before_learning_starts = 10000
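        # Illustrative sketch (assumes the standard `multi_agent(count_steps_by=...)`
        # setter referenced above): counting in agent steps instead of environment
        # steps could be requested via:
        #   config = TD3Config().multi_agent(count_steps_by="agent_steps")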

        # .exploration()
        # TD3 uses Gaussian Noise by default.
        self.exploration_config = {
            # TD3 uses simple Gaussian noise on top of deterministic NN-output
            # actions (after a possible pure random phase of n timesteps).
            "type": "GaussianNoise",
            # For how many timesteps should we return completely random
            # actions, before we start adding (scaled) noise?
            "random_timesteps": 10000,
            # Gaussian stddev of action noise for exploration.
            "stddev": 0.1,
            # Scaling settings by which the Gaussian noise is scaled before
            # being added to the actions. NOTE: The scale timesteps start only
            # after(!) any random steps have been finished.
            # By default, do not anneal over time (fixed 1.0).
            "initial_scale": 1.0,
            "final_scale": 1.0,
            "scale_timesteps": 1,
        }
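        # Illustrative sketch (hedged; uses only the GaussianNoise fields listed
        # above): annealing the noise scale from 1.0 down to 0.1 over the first
        # 100k timesteps could look like:
        #   config = TD3Config().exploration(
        #       exploration_config={
        #           "type": "GaussianNoise",
        #           "random_timesteps": 10000,
        #           "stddev": 0.1,
        #           "initial_scale": 1.0,
        #           "final_scale": 0.1,
        #           "scale_timesteps": 100000,
        #       }
        #   )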
        # __sphinx_doc_end__
        # fmt: on


class TD3(DDPG):
    @classmethod
    @override(DDPG)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return TD3Config().to_dict()


# Deprecated: Use ray.rllib.algorithms.ddpg.td3.TD3Config instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(TD3Config().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.ddpg.td3::TD3_DEFAULT_CONFIG",
        new="ray.rllib.algorithms.td3.td3::TD3Config(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


TD3_DEFAULT_CONFIG = _deprecated_default_config()
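# Usage note (illustrative, not authoritative): dict-style access such as
# `TD3_DEFAULT_CONFIG["train_batch_size"]` still works but triggers the deprecation
# warning declared above; the preferred replacement is the config object, e.g.
# `TD3Config().training(train_batch_size=100)`.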