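"""Ape-X variant of DDPG.

Combines DDPG's policies and losses with APEX-DQN's distributed execution
plan: many rollout workers feeding sharded, prioritized replay buffers.
"""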
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
    DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
    DDPG_CONFIG,  # see also the options in ddpg.py, which are also supported
    {
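        # Settings for the distributed Ape-X optimizer: roughly, how many
        # steps a rollout worker may sample before its weights are re-synced
        # from the learner, and how many replay-buffer shard actors to run.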
"optimizer": {
|
|
"max_weight_sync_delay": 400,
|
|
"num_replay_buffer_shards": 4,
|
|
"debug": False
|
|
},
|
|
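        # Per-worker Ornstein-Uhlenbeck noise: each rollout worker explores
        # with a different noise scale, in the style of the Ape-X paper.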
"exploration_config": {
|
|
"type": "PerWorkerOrnsteinUhlenbeckNoise"
|
|
},
|
|
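        # Ape-X-style scale-out: multi-step returns, many rollout workers,
        # and a large shared replay buffer.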
"n_step": 3,
|
|
"num_gpus": 0,
|
|
"num_workers": 32,
|
|
"buffer_size": 2000000,
|
|
        # TODO(jungong) : update once Apex supports replay_buffer_config.
        "replay_buffer_config": None,
        "learning_starts": 50000,
        "train_batch_size": 512,
        "rollout_fragment_length": 50,
        "target_network_update_freq": 500000,
        "timesteps_per_iteration": 25000,
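        # Compute per-sample replay priorities on the rollout workers
        # rather than on the learner, as in Ape-X.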
"worker_side_prioritization": True,
|
|
"min_iter_time_s": 30,
|
|
},
|
|
)


class ApexDDPGTrainer(DDPGTrainer):
    @classmethod
    @override(DDPGTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return APEX_DDPG_DEFAULT_CONFIG

    @staticmethod
    @override(DDPGTrainer)
    def execution_plan(workers: WorkerSet, config: dict,
                       **kwargs) -> LocalIterator[dict]:
        """Use APEX-DQN's execution plan."""
        return ApexTrainer.execution_plan(workers, config, **kwargs)
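
# Minimal usage sketch (hypothetical; assumes Ray is installed and a
# continuous-action Gym environment such as "Pendulum-v1" is available):
#
#     import ray
#     ray.init()
#     trainer = ApexDDPGTrainer(env="Pendulum-v1",
#                               config={"num_workers": 4})
#     print(trainer.train())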