ray/rllib/algorithms/ddpg/apex.py
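
"""Ape-X variant of DDPG (distributed prioritized experience replay).

Layers the Ape-X execution logic from ApexTrainer (rllib/agents/dqn/apex.py) on
top of DDPGTrainer, providing the ApexDDPGTrainer class and its default config.
"""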

from typing import List
from ray.actor import ActorHandle
from ray.rllib.agents import Trainer
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
    DDPGConfig().to_dict(),  # see the options in ddpg.py; they are also supported
    {
        "optimizer": {
            "max_weight_sync_delay": 400,
            "num_replay_buffer_shards": 4,
            "debug": False,
        },
        "exploration_config": {"type": "PerWorkerOrnsteinUhlenbeckNoise"},
        "n_step": 3,
        "num_gpus": 0,
        "num_workers": 32,
        "replay_buffer_config": {
            "capacity": 2000000,
            "no_local_replay_buffer": True,
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer
            # (see the sketch below this config dict).
            "prioritized_replay": DEPRECATED_VALUE,
            "learning_starts": 50000,
            # Whether all shards of the replay buffer must be co-located
            # with the learner process (running the execution plan).
            # This is preferred b/c the learner process should have quick
            # access to the data from the buffer shards, avoiding network
            # traffic each time samples from the buffer(s) are drawn.
            # Set this to False to relax this constraint and allow replay
            # shards to be created on node(s) other than the one on which
            # the learner is located.
            "replay_buffer_shards_colocated_with_driver": True,
            "worker_side_prioritization": True,
        },
        "train_batch_size": 512,
        "rollout_fragment_length": 50,
        # Update the target network every `target_network_update_freq` sample timesteps.
        "target_network_update_freq": 500000,
        "min_sample_timesteps_per_reporting": 25000,
        "min_time_s_per_reporting": 30,
        "training_intensity": 1,
        # Max number of in-flight requests to each sampling worker; see the
        # AsyncRequestsManager class for more details.
        # Tuning these values is important when running experiments with large
        # sample batches: if the sample batches are large, the object store may
        # fill up and start spilling objects to disk, which can make any
        # asynchronous requests, and thus your experiment, very slow. You can
        # inspect the object store during your experiment by running `ray memory`
        # on your head node, and by using the Ray dashboard. If the object store
        # is filling up, turn down the number of remote requests in flight, or
        # enable compression in your experiment.
        "max_requests_in_flight_per_sampler_worker": 2,
        "max_requests_in_flight_per_replay_worker": float("inf"),
        "timeout_s_sampler_manager": 0.0,
        "timeout_s_replay_manager": 0.0,
    },
    _allow_unknown_configs=True,
)
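
# A minimal override sketch (illustrative, not part of the original defaults):
# per-key overrides passed to the trainer are merged on top of
# APEX_DDPG_DEFAULT_CONFIG, so prioritized replay can be selected by supplying
# a buffer type that supports it, e.g.:
#
#     config = {
#         "num_workers": 8,
#         "replay_buffer_config": {
#             "type": "MultiAgentPrioritizedReplayBuffer",
#             "capacity": 1000000,
#         },
#     }
#     trainer = ApexDDPGTrainer(env="Pendulum-v1", config=config)
#
# The environment name and the chosen values above are assumptions for
# illustration only.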


class ApexDDPGTrainer(DDPGTrainer, ApexTrainer):
    @classmethod
    @override(DDPGTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return APEX_DDPG_DEFAULT_CONFIG

    @override(DDPGTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        return ApexTrainer.setup(self, config)

    @override(DDPGTrainer)
    def training_iteration(self) -> ResultDict:
        """Use APEX-DQN's training iteration function."""
        return ApexTrainer.training_iteration(self)

    @override(Trainer)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle the failures of remote sampling workers.

        Args:
            removed_workers: Removed worker ids.
            new_workers: Ids of newly created workers.
        """
        self._sampling_actor_manager.remove_workers(removed_workers)
        self._sampling_actor_manager.add_workers(new_workers)

    @staticmethod
    @override(DDPGTrainer)
    def execution_plan(
        workers: WorkerSet, config: dict, **kwargs
    ) -> LocalIterator[dict]:
        """Use APEX-DQN's execution plan."""
        return ApexTrainer.execution_plan(workers, config, **kwargs)
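

# Hedged usage sketch (not part of the original module): run this file directly
# to smoke-test ApexDDPGTrainer on a small continuous-control task. The
# environment name ("Pendulum-v1"), worker count, and iteration count below are
# illustrative assumptions, not values taken from the original file.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = ApexDDPGTrainer(
        env="Pendulum-v1",
        config={
            # Scale down from the 32-worker default for a single-machine run.
            "num_workers": 4,
            "replay_buffer_config": {"learning_starts": 1000},
        },
    )
    for i in range(3):
        result = trainer.train()
        print(f"iter {i}: episode_reward_mean={result.get('episode_reward_mean')}")
    trainer.stop()
    ray.shutdown()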