ray/rllib/agents/ddpg/apex.py

59 lines
2.2 KiB
Python
Raw Normal View History

from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
DDPG_CONFIG, # see also the options in ddpg.py, which are also supported
{
"optimizer": {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
},
"exploration_config": {
"type": "PerWorkerOrnsteinUhlenbeckNoise"
},
"n_step": 3,
"num_gpus": 0,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
},
_allow_unknown_configs=True,
)
class ApexDDPGTrainer(DDPGTrainer):
@classmethod
@override(DDPGTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DDPG_DEFAULT_CONFIG
@staticmethod
@override(DDPGTrainer)
def execution_plan(workers: WorkerSet, config: dict,
**kwargs) -> LocalIterator[dict]:
"""Use APEX-DQN's execution plan."""
return ApexTrainer.execution_plan(workers, config, **kwargs)