"""Decentralized Distributed PPO implementation.

Unlike APPO or PPO, learning is no longer done centralized in the trainer
process. Instead, gradients are computed remotely on each rollout worker and
all-reduced to sync them at each mini-batch. This allows each worker's GPU
to be used both for sampling and for training.

DD-PPO should be used if you have envs that require GPUs to function, or have
a very large model that cannot be effectively optimized with the GPUs available
on a single machine (DD-PPO allows scaling to arbitrary numbers of GPUs across
multiple nodes, unlike PPO/APPO, which are limited to GPUs on a single node).

Paper reference: https://arxiv.org/abs/1911.00357
Note that unlike the paper, we currently do not implement straggler mitigation.
"""

from ray.rllib.agents.ppo import ppo
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.optimizers import TorchDistributedDataParallelOptimizer

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
    # During the sampling phase, each rollout worker will collect a batch
    # `sample_batch_size * num_envs_per_worker` steps in size.
    "sample_batch_size": 100,
    # Vectorize the env (should be enabled by default since each worker has
    # a GPU).
    "num_envs_per_worker": 5,
    # During the SGD phase, workers iterate over minibatches of this size.
    # The effective minibatch size will be `sgd_minibatch_size * num_workers`.
    "sgd_minibatch_size": 50,
    # Number of SGD epochs per optimization round.
    "num_sgd_iter": 10,

    # *** WARNING: configs below are DDPPO overrides over PPO; you
    #     shouldn't need to adjust them. ***
    "use_pytorch": True,  # DDPPO requires PyTorch distributed.
    "num_gpus": 0,  # Learning is no longer done on the driver process, so
                    # giving GPUs to the driver does not make sense!
    "num_gpus_per_worker": 1,  # Each rollout worker gets a GPU.
"truncate_episodes": True, # Require evenly sized batches. Otherwise,
|
||
|
# collective allreduce could fail.
|
||
|
"train_batch_size": -1, # This is auto set based on sample batch size.
|
||
|
})
|
||
|
# __sphinx_doc_end__
|
||
|
# yapf: enable
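
# Worked example of the batch-size arithmetic under the defaults above
# (`num_workers` is not set in this file and is only illustrative here):
#   - each rollout worker samples sample_batch_size * num_envs_per_worker
#     = 100 * 5 = 500 timesteps per round, and `train_batch_size` is
#     auto-set to that value in `validate_config` below;
#   - each worker then runs num_sgd_iter = 10 SGD epochs over minibatches of
#     sgd_minibatch_size = 50, so with e.g. num_workers = 4 the effective
#     (all-reduced) minibatch size is 50 * 4 = 200 timesteps.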


def validate_config(config):
    if config["train_batch_size"] == -1:
        # Auto set.
        config["train_batch_size"] = (
            config["sample_batch_size"] * config["num_envs_per_worker"])
    else:
        raise ValueError(
            "Set sample_batch_size instead of train_batch_size for DDPPO.")
    ppo.validate_config(config)


def make_distributed_allreduce_optimizer(workers, config):
    if not config["use_pytorch"]:
        raise ValueError(
            "Distributed data parallel is only supported for PyTorch")
    if config["num_gpus"]:
        raise ValueError(
            "When using distributed data parallel, you should set "
            "num_gpus=0 since all optimization "
            "is happening on workers. Enable GPUs for workers by setting "
            "num_gpus_per_worker=1.")
    if config["batch_mode"] != "truncate_episodes":
        raise ValueError(
            "Distributed data parallel requires truncate_episodes "
            "batch mode.")

    return TorchDistributedDataParallelOptimizer(
        workers,
        expected_batch_size=config["sample_batch_size"] *
        config["num_envs_per_worker"],
        num_sgd_iter=config["num_sgd_iter"],
        sgd_minibatch_size=config["sgd_minibatch_size"],
        standardize_fields=["advantages"])


DDPPOTrainer = ppo.PPOTrainer.with_updates(
    name="DDPPO",
    default_config=DEFAULT_CONFIG,
    make_policy_optimizer=make_distributed_allreduce_optimizer,
    validate_config=validate_config)
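

# A minimal usage sketch (illustrative; not part of the original module). It
# assumes each rollout worker can actually be given a GPU, as the defaults
# above require; the env name and worker count are arbitrary example values.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = DDPPOTrainer(
        env="CartPole-v0",
        config={
            # Gradients are all-reduced across these two rollout workers.
            "num_workers": 2,
        })
    for _ in range(3):
        print(trainer.train())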