[rllib] Add torch_distributed_backend flag for DDPPO (#11362) (#11425)

Philsik Chang 2020-10-22 10:30:42 +09:00 committed by GitHub
parent a4b418d30c
commit ede9347127

@@ -62,6 +62,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
         # shouldn't need to adjust them. ***
         # DDPPO requires PyTorch distributed.
         "framework": "torch",
+        # The communication backend for PyTorch distributed.
+        "torch_distributed_backend": "gloo",
         # Learning is no longer done on the driver process, so
         # giving GPUs to the driver does not make sense!
         "num_gpus": 0,
@@ -106,6 +108,9 @@ def validate_config(config):
     if config["framework"] != "torch":
         raise ValueError(
             "Distributed data parallel is only supported for PyTorch")
+    if config["torch_distributed_backend"] not in ("gloo", "mpi", "nccl"):
+        raise ValueError("Only gloo, mpi, or nccl is supported for "
+                         "the backend of PyTorch distributed.")
     # `num_gpus` must be 0/None, since all optimization happens on Workers.
     if config["num_gpus"]:
         raise ValueError(
@@ -148,7 +153,10 @@ def execution_plan(workers: WorkerSet,
     # Get setup tasks in order to throw errors on failure.
     ray.get([
         worker.setup_torch_data_parallel.remote(
-            address, i, len(workers.remote_workers()), backend="gloo")
+            address,
+            i,
+            len(workers.remote_workers()),
+            backend=config["torch_distributed_backend"])
         for i, worker in enumerate(workers.remote_workers())
     ])
     logger.info("Torch process group init completed")