# ray/rllib/algorithms/ddpg/apex.py

from typing import List

from ray.actor import ActorHandle
from ray.rllib.agents import Trainer
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
    DDPGConfig().to_dict(),  # see also the options in ddpg.py, which are also supported
    {
        "optimizer": {
            "max_weight_sync_delay": 400,
            "num_replay_buffer_shards": 4,
            "debug": False,
        },
        "exploration_config": {"type": "PerWorkerOrnsteinUhlenbeckNoise"},
        "n_step": 3,
        "num_gpus": 0,
        "num_workers": 32,
        "replay_buffer_config": {
            "capacity": 2000000,
            "no_local_replay_buffer": True,
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            "learning_starts": 50000,
            # Whether all shards of the replay buffer must be co-located
            # with the learner process (running the execution plan).
            # This is preferred b/c the learner process should have quick
            # access to the data from the buffer shards, avoiding network
            # traffic each time samples from the buffer(s) are drawn.
            # Set this to False for relaxing this constraint and allowing
            # replay shards to be created on node(s) other than the one
            # on which the learner is located.
            "replay_buffer_shards_colocated_with_driver": True,
            "worker_side_prioritization": True,
        },
        "train_batch_size": 512,
        "rollout_fragment_length": 50,
        # Update the target network every `target_network_update_freq` sample timesteps.
        "target_network_update_freq": 500000,
        "min_sample_timesteps_per_reporting": 25000,
        "min_time_s_per_reporting": 30,
        "training_intensity": 1,
        # Maximum number of in-flight requests to each sampling worker; see the
        # AsyncRequestsManager class for more details.
        # Tuning these values is important when running experiments with large sample
        # batches: if the sample batches are large, the object store may fill up and
        # start spilling objects to disk. This can make any asynchronous requests very
        # slow and thereby slow down the whole experiment. You can inspect the object
        # store during your experiment via a call to `ray memory` on your head node,
        # or by using the Ray dashboard. If you see the object store filling up, turn
        # down the number of remote requests in flight, or enable compression of the
        # sample batches in your experiment (see the usage sketch below this config
        # for an example of overriding these values).
"max_requests_in_flight_per_sampler_worker": 2,
"max_requests_in_flight_per_replay_worker": float("inf"),
"timeout_s_sampler_manager": 0.0,
"timeout_s_replay_manager": 0.0,
},
_allow_unknown_configs=True,
)
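
# A minimal usage sketch (illustrative only, not part of the original module): the
# registered algorithm name "APEX_DDPG", the environment "Pendulum-v1", and the
# override values below are assumptions for a small local run. Any key from
# APEX_DDPG_DEFAULT_CONFIG can be overridden in the same way, e.g. to reduce the
# number of in-flight sample requests as discussed above.
#
#     from ray import tune
#
#     tune.run(
#         "APEX_DDPG",
#         stop={"training_iteration": 10},
#         config={
#             "env": "Pendulum-v1",
#             "num_workers": 4,
#             "max_requests_in_flight_per_sampler_worker": 1,
#             "replay_buffer_config": {"learning_starts": 1000},
#         },
#     )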


class ApexDDPGTrainer(DDPGTrainer, ApexTrainer):
    @classmethod
    @override(DDPGTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return APEX_DDPG_DEFAULT_CONFIG

    @override(DDPGTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        return ApexTrainer.setup(self, config)

    @override(DDPGTrainer)
    def training_iteration(self) -> ResultDict:
        """Use APEX-DQN's training iteration function."""
        return ApexTrainer.training_iteration(self)

    @override(Trainer)
    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handle failures of remote sampling workers.

        Args:
            removed_workers: Handles of the removed sampling workers.
            new_workers: Handles of the newly created sampling workers.
        """
        self._sampling_actor_manager.remove_workers(removed_workers)
        self._sampling_actor_manager.add_workers(new_workers)

    @staticmethod
    @override(DDPGTrainer)
    def execution_plan(
        workers: WorkerSet, config: dict, **kwargs
    ) -> LocalIterator[dict]:
        """Use APEX-DQN's execution plan."""
        return ApexTrainer.execution_plan(workers, config, **kwargs)
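

# Minimal runnable sketch for trying out this trainer directly (illustrative only;
# the environment "Pendulum-v1" and the reduced "num_workers" / "learning_starts"
# values below are assumptions for a quick local smoke test, not tuned settings).
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = ApexDDPGTrainer(
        env="Pendulum-v1",
        config={
            "num_workers": 2,
            "replay_buffer_config": {"learning_starts": 1000},
        },
    )
    # Run a single training iteration and print the result dict.
    print(trainer.train())
    trainer.stop()
    ray.shutdown()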