from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import (
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
    # See also the options in ddpg.py, which are also supported.
    DDPGConfig().to_dict(),
    {
        "optimizer": {
            "max_weight_sync_delay": 400,
            "num_replay_buffer_shards": 4,
            "debug": False,
        },
        "exploration_config": {"type": "PerWorkerOrnsteinUhlenbeckNoise"},
        "n_step": 3,
        "num_gpus": 0,
        "num_workers": 32,
        "replay_buffer_config": {
            "capacity": 2000000,
            "no_local_replay_buffer": True,
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            "learning_starts": 50000,
            # Whether all shards of the replay buffer must be co-located
            # with the learner process (running the execution plan).
            # This is preferred b/c the learner process should have quick
            # access to the data from the buffer shards, avoiding network
            # traffic each time samples from the buffer(s) are drawn.
            # Set this to False for relaxing this constraint and allowing
            # replay shards to be created on node(s) other than the one
            # on which the learner is located.
            "replay_buffer_shards_colocated_with_driver": True,
            "worker_side_prioritization": True,
        },
        "train_batch_size": 512,
        "rollout_fragment_length": 50,
        # Update the target network every `target_network_update_freq` sample
        # timesteps.
        "target_network_update_freq": 500000,
        "min_sample_timesteps_per_reporting": 25000,
        "min_time_s_per_reporting": 30,
    },
    _allow_unknown_configs=True,
)
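
# The merged dict above layers the Ape-X-specific settings (many rollout
# workers, sharded worker-side-prioritized replay, and per-worker exploration
# noise) on top of the plain DDPG defaults. As an illustration only (the env
# name and values below are examples, not recommendations), any entry can be
# overridden via the config dict passed at trainer construction time, e.g.:
#
#     trainer = ApexDDPGTrainer(
#         env="Pendulum-v1",
#         config={"num_workers": 8, "n_step": 1, "train_batch_size": 256},
#     )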


class ApexDDPGTrainer(DDPGTrainer, ApexTrainer):
    @classmethod
    @override(DDPGTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return APEX_DDPG_DEFAULT_CONFIG

    @override(DDPGTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        return ApexTrainer.setup(self, config)

    @override(DDPGTrainer)
    def training_iteration(self) -> ResultDict:
        """Use APEX-DQN's training iteration function."""
        return ApexTrainer.training_iteration(self)

    @staticmethod
    @override(DDPGTrainer)
    def execution_plan(
        workers: WorkerSet, config: dict, **kwargs
    ) -> LocalIterator[dict]:
        """Use APEX-DQN's execution plan."""
        return ApexTrainer.execution_plan(workers, config, **kwargs)
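

if __name__ == "__main__":
    # Minimal, illustrative sketch only: run a single Ape-X DDPG training
    # iteration on a small continuous-control task. The env name and the
    # reduced worker count are example values (the defaults above assume a
    # cluster, not a single machine), and this assumes a few free local CPUs.
    import ray

    ray.init()
    trainer = ApexDDPGTrainer(
        env="Pendulum-v1",
        config={
            "num_workers": 3,
            "num_gpus": 0,
            "min_sample_timesteps_per_reporting": 1000,
        },
    )
    result = trainer.train()
    print(result["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()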