# ray/rllib/algorithms/apex_ddpg/apex_ddpg.py

from typing import List, Optional
from ray.actor import ActorHandle
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    PartialAlgorithmConfigDict,
    ResultDict,
)
from ray.util.iter import LocalIterator


class ApexDDPGConfig(DDPGConfig):
"""Defines a configuration class from which an ApexDDPG Trainer can be built.
Example:
>>> from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig
>>> config = ApexDDPGConfig().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run one training iteration.
>>> trainer = config.build(env="Pendulum-v1")
>>> trainer.train()
Example:
>>> from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig
>>> from ray import tune
>>> config = ApexDDPGConfig()
>>> # Print out some default values.
>>> print(config.lr) # doctest: +SKIP
0.0004
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config.environment(env="Pendulum-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "APEX_DDPG",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""

    def __init__(self, algo_class=None):
        """Initializes an ApexDDPGConfig instance."""
        super().__init__(algo_class=algo_class or ApexDDPG)
        # fmt: off
        # __sphinx_doc_begin__
        # ApexDDPG-specific settings.
        self.optimizer = {
            "max_weight_sync_delay": 400,
            "num_replay_buffer_shards": 4,
            "debug": False,
        }
        # Max. number of in-flight requests to each sampling worker.
        self.max_requests_in_flight_per_sampler_worker = 2
        # Max. number of in-flight requests to each replay (shard) worker.
        self.max_requests_in_flight_per_replay_worker = float("inf")
        # Timeouts (in sec) for retrieving ready sample- and replay results.
        self.timeout_s_sampler_manager = 0.0
        self.timeout_s_replay_manager = 0.0
        # Override some of Trainer/DDPG's default values with ApexDDPG-specific
        # values.
        self.n_step = 3
        self.exploration_config = {"type": "PerWorkerOrnsteinUhlenbeckNoise"}
        self.num_gpus = 0
        self.num_workers = 32
        self.min_sample_timesteps_per_iteration = 25000
        self.min_time_s_per_iteration = 30
        self.train_batch_size = 512
        self.rollout_fragment_length = 50
        self.replay_buffer_config = {
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 2000000,
            "no_local_replay_buffer": True,
            # Alpha parameter for prioritized replay buffer.
            "prioritized_replay_alpha": 0.6,
            # Beta parameter for sampling from prioritized replay buffer.
            "prioritized_replay_beta": 0.4,
            # Epsilon to add to the TD errors when updating priorities.
            "prioritized_replay_eps": 1e-6,
            # Whether all shards of the replay buffer must be co-located
            # with the learner process (running the execution plan).
            # This is preferred b/c the learner process should have quick
            # access to the data from the buffer shards, avoiding network
            # traffic each time samples from the buffer(s) are drawn.
            # Set this to False for relaxing this constraint and allowing
            # replay shards to be created on node(s) other than the one
            # on which the learner is located.
            "replay_buffer_shards_colocated_with_driver": True,
            # Whether to compute priorities on workers.
            "worker_side_prioritization": True,
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
        }
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in
        # agent steps or environment steps depends on
        # config["multiagent"]["count_steps_by"].
        self.num_steps_sampled_before_learning_starts = 50000
        self.target_network_update_freq = 500000
        self.training_intensity = 1
        # __sphinx_doc_end__
        # fmt: on
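
    # A minimal sketch of overriding the replay buffer defaults above through
    # `.training()`, assuming the `replay_buffer_config` argument is accepted
    # there as in the parent DDPGConfig (values are illustrative, not tuned):
    #
    #     config = ApexDDPGConfig().training(
    #         replay_buffer_config={
    #             "type": "MultiAgentPrioritizedReplayBuffer",
    #             "capacity": 1000000,
    #             "prioritized_replay_alpha": 0.5,
    #         }
    #     )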

    @override(DDPGConfig)
    def training(
        self,
        *,
        optimizer: Optional[dict] = None,
        max_requests_in_flight_per_sampler_worker: Optional[int] = None,
        max_requests_in_flight_per_replay_worker: Optional[int] = None,
        timeout_s_sampler_manager: Optional[float] = None,
        timeout_s_replay_manager: Optional[float] = None,
        **kwargs,
    ) -> "ApexDDPGConfig":
"""Sets the training related configuration.
Args:
optimizer: Apex-DDPG optimizer settings (dict). Set the number of reply
buffer shards in here via the `num_replay_buffer_shards` key
(default=4).
max_requests_in_flight_per_sampler_worker: Max number of inflight requests
to each sampling worker. See the AsyncRequestsManager class for more
details. Tuning these values is important when running experimens with
large sample batches, where there is the risk that the object store may
fill up, causing spilling of objects to disk. This can cause any
asynchronous requests to become very slow, making your experiment run
slow as well. You can inspect the object store during your experiment
via a call to ray memory on your headnode, and by using the ray
dashboard. If you're seeing that the object store is filling up,
turn down the number of remote requests in flight, or enable compression
in your experiment of timesteps.
max_requests_in_flight_per_replay_worker: Max number of inflight requests
to each replay (shard) worker. See the AsyncRequestsManager class for
more details. Tuning these values is important when running experimens
with large sample batches, where there is the risk that the object store
may fill up, causing spilling of objects to disk. This can cause any
asynchronous requests to become very slow, making your experiment run
slow as well. You can inspect the object store during your experiment
via a call to ray memory on your headnode, and by using the ray
dashboard. If you're seeing that the object store is filling up,
turn down the number of remote requests in flight, or enable compression
in your experiment of timesteps.
timeout_s_sampler_manager: The timeout for waiting for sampling results
for workers -- typically if this is too low, the manager won't be able
to retrieve ready sampling results.
timeout_s_replay_manager: The timeout for waiting for replay worker
results -- typically if this is too low, the manager won't be able to
retrieve ready replay requests.
Returns:
This updated ApexDDPGConfig object.
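
        Example:
            >>> # A minimal sketch of the ApexDDPG-specific settings above
            >>> # (values are illustrative only, not tuned recommendations).
            >>> config = ApexDDPGConfig().training(  # doctest: +SKIP
            ...     optimizer={"num_replay_buffer_shards": 2},
            ...     max_requests_in_flight_per_sampler_worker=4,
            ...     timeout_s_sampler_manager=0.1,
            ...     timeout_s_replay_manager=0.1,
            ... )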
"""
        super().training(**kwargs)

        if optimizer is not None:
            self.optimizer = optimizer
        if max_requests_in_flight_per_sampler_worker is not None:
            self.max_requests_in_flight_per_sampler_worker = (
                max_requests_in_flight_per_sampler_worker
            )
        if max_requests_in_flight_per_replay_worker is not None:
            self.max_requests_in_flight_per_replay_worker = (
                max_requests_in_flight_per_replay_worker
            )
        if timeout_s_sampler_manager is not None:
            self.timeout_s_sampler_manager = timeout_s_sampler_manager
        if timeout_s_replay_manager is not None:
            self.timeout_s_replay_manager = timeout_s_replay_manager

        return self


class ApexDDPG(DDPG, ApexDQN):
    @classmethod
    @override(DDPG)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ApexDDPGConfig().to_dict()

    @override(DDPG)
    def setup(self, config: PartialAlgorithmConfigDict):
        return ApexDQN.setup(self, config)

    @override(DDPG)
    def training_step(self) -> ResultDict:
        """Use APEX-DQN's training iteration function."""
        return ApexDQN.training_step(self)

    @override(Algorithm)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handles failures of remote sampling workers.

        Args:
            removed_workers: Actor handles of the removed workers.
            new_workers: Actor handles of the newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True
            )
            self._sampling_actor_manager.add_workers(new_workers)

    @staticmethod
    @override(DDPG)
    def execution_plan(
        workers: WorkerSet, config: dict, **kwargs
    ) -> LocalIterator[dict]:
        """Use APEX-DQN's execution plan."""
        return ApexDQN.execution_plan(workers, config, **kwargs)


# Deprecated: Use ray.rllib.algorithms.apex_ddpg.ApexDDPGConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(ApexDDPGConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.ddpg.apex.APEX_DDPG_DEFAULT_CONFIG",
        new="ray.rllib.algorithms.apex_ddpg.apex_ddpg::ApexDDPGConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


APEX_DDPG_DEFAULT_CONFIG = _deprecated_default_config()
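

# A minimal end-to-end usage sketch, assuming a running (or local-mode) Ray
# cluster and that the Pendulum-v1 environment is installed. It uses only the
# config API already shown in the docstrings above:
#
#     from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig
#
#     config = (
#         ApexDDPGConfig()
#         .environment(env="Pendulum-v1")
#         .resources(num_gpus=0)
#         .training(lr=0.0001)
#     )
#     trainer = config.build()
#     print(trainer.train())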