parent a4b418d30c
commit ede9347127

1 changed file with 9 additions and 1 deletion
@@ -62,6 +62,8 @@ DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
         # shouldn't need to adjust them. ***
         # DDPPO requires PyTorch distributed.
         "framework": "torch",
+        # The communication backend for PyTorch distributed.
+        "torch_distributed_backend": "gloo",
         # Learning is no longer done on the driver process, so
         # giving GPUs to the driver does not make sense!
         "num_gpus": 0,
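The new torch_distributed_backend key defaults to "gloo", so existing DDPPO configs keep working unchanged. A minimal usage sketch of the option (assuming the DDPPOTrainer API in ray.rllib.agents.ppo; the environment and worker count are illustrative):

import ray
from ray.rllib.agents.ppo import ddppo

config = ddppo.DEFAULT_CONFIG.copy()
config["env"] = "CartPole-v0"   # illustrative environment
config["num_workers"] = 2       # DDPPO learns on the rollout workers
# The new option; "gloo" remains the default, "nccl" requires GPUs on the
# workers, and "mpi" requires an MPI-enabled PyTorch build.
config["torch_distributed_backend"] = "gloo"

ray.init()
trainer = ddppo.DDPPOTrainer(config=config)
print(trainer.train())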
@@ -106,6 +108,9 @@ def validate_config(config):
     if config["framework"] != "torch":
         raise ValueError(
             "Distributed data parallel is only supported for PyTorch")
+    if config["torch_distributed_backend"] not in ("gloo", "mpi", "nccl"):
+        raise ValueError("Only gloo, mpi, or nccl is supported for "
+                         "the backend of PyTorch distributed.")
     # `num_gpus` must be 0/None, since all optimization happens on Workers.
     if config["num_gpus"]:
         raise ValueError(
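With this guard, an unsupported backend fails fast at config validation instead of surfacing later inside torch.distributed. A standalone illustration of the check (not the RLlib code path itself; the helper name is made up for the example):

def check_torch_distributed_backend(backend):
    # Mirrors the validation added above.
    if backend not in ("gloo", "mpi", "nccl"):
        raise ValueError("Only gloo, mpi, or nccl is supported for "
                         "the backend of PyTorch distributed.")

check_torch_distributed_backend("nccl")  # passes silently
check_torch_distributed_backend("ucx")   # raises ValueError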
@@ -148,7 +153,10 @@ def execution_plan(workers: WorkerSet,
     # Get setup tasks in order to throw errors on failure.
     ray.get([
         worker.setup_torch_data_parallel.remote(
-            address, i, len(workers.remote_workers()), backend="gloo")
+            address,
+            i,
+            len(workers.remote_workers()),
+            backend=config["torch_distributed_backend"])
         for i, worker in enumerate(workers.remote_workers())
     ])
     logger.info("Torch process group init completed")
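Each remote setup_torch_data_parallel call now receives the configured backend instead of the hard-coded "gloo". Roughly, the per-worker setup amounts to initializing a torch.distributed process group; a sketch under that assumption (illustrative only, the function name here is hypothetical and the real logic lives on the RolloutWorker):

import torch.distributed as dist

def setup_process_group(url, world_rank, world_size, backend):
    # backend is config["torch_distributed_backend"], validated above to be
    # one of "gloo", "mpi", or "nccl".
    dist.init_process_group(
        backend=backend,
        init_method=url,  # e.g. a tcp:// address shared by all workers
        rank=world_rank,
        world_size=world_size)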