parent a4b418d30c
commit ede9347127
1 changed file with 9 additions and 1 deletion
@@ -62,6 +62,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
     # shouldn't need to adjust them. ***
     # DDPPO requires PyTorch distributed.
     "framework": "torch",
+    # The communication backend for PyTorch distributed.
+    "torch_distributed_backend": "gloo",
     # Learning is no longer done on the driver process, so
     # giving GPUs to the driver does not make sense!
     "num_gpus": 0,
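Context for the new option above (not part of the diff): a minimal usage sketch of how a user could override `torch_distributed_backend` when building a DDPPO config. The import path and the DDPPOTrainer/DEFAULT_CONFIG names are assumed from the RLlib agents.ppo layout of this era and are illustrative, not authoritative.

# Illustrative only -- module path and names assumed, not taken from this diff.
import ray
from ray.rllib.agents.ppo.ddppo import DDPPOTrainer, DEFAULT_CONFIG

ray.init()
config = DEFAULT_CONFIG.copy()
config["framework"] = "torch"       # DDPPO is torch-only
config["num_gpus"] = 0              # learning happens on the workers, not the driver
config["num_workers"] = 2
# The new setting from this change: "gloo" is the default, "nccl" is the usual
# choice when each worker has a GPU, "mpi" needs an MPI-enabled torch build.
config["torch_distributed_backend"] = "gloo"

trainer = DDPPOTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])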
@@ -106,6 +108,9 @@ def validate_config(config):
     if config["framework"] != "torch":
         raise ValueError(
             "Distributed data parallel is only supported for PyTorch")
+    if config["torch_distributed_backend"] not in ("gloo", "mpi", "nccl"):
+        raise ValueError("Only gloo, mpi, or nccl is supported for "
+                         "the backend of PyTorch distributed.")
     # `num_gpus` must be 0/None, since all optimization happens on Workers.
     if config["num_gpus"]:
         raise ValueError(
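A quick standalone sketch (not the RLlib function itself) of what the added validation branch does: any backend other than gloo, mpi, or nccl is rejected up front, before any workers are launched.

# Standalone illustration of the new check; the function name here is hypothetical.
def check_backend(config):
    if config.get("torch_distributed_backend") not in ("gloo", "mpi", "nccl"):
        raise ValueError("Only gloo, mpi, or nccl is supported for "
                         "the backend of PyTorch distributed.")

check_backend({"torch_distributed_backend": "nccl"})     # passes
try:
    check_backend({"torch_distributed_backend": "tcp"})  # rejected
except ValueError as err:
    print(err)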
@@ -148,7 +153,10 @@ def execution_plan(workers: WorkerSet,
     # Get setup tasks in order to throw errors on failure.
     ray.get([
         worker.setup_torch_data_parallel.remote(
-            address, i, len(workers.remote_workers()), backend="gloo")
+            address,
+            i,
+            len(workers.remote_workers()),
+            backend=config["torch_distributed_backend"])
         for i, worker in enumerate(workers.remote_workers())
     ])
     logger.info("Torch process group init completed")
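For orientation (the worker-side code is not part of this diff): each remote worker's setup_torch_data_parallel presumably uses the backend argument it now receives to initialize a torch.distributed process group, roughly along these lines. This is a hedged sketch, not the actual RolloutWorker implementation.

# Hedged sketch of the worker-side setup; not the real RolloutWorker method.
import torch.distributed as dist

def setup_torch_data_parallel_sketch(url, world_rank, world_size, backend):
    # Join one process group spanning all remote workers so gradients can be
    # all-reduced with the configured backend ("gloo", "mpi", or "nccl").
    dist.init_process_group(
        backend=backend,
        init_method="tcp://" + url,
        rank=world_rank,
        world_size=world_size)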