[Train] Fix train.torch.get_device() for fractional GPU or multiple GPU per worker case (#23763)

Using the local rank as the device id only works if there is exactly 1 GPU per worker. Instead, we should use ray.get_gpu_ids() to determine which GPU device each worker should use (illustrated in the sketch below).
Amog Kamsetty 2022-04-08 14:35:06 -07:00 committed by GitHub
parent 615bb7a503
commit 029517a037
2 changed files with 34 additions and 2 deletions
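
To illustrate the failure mode, here is a minimal sketch (not part of this commit; it assumes a single node with one GPU): when two workers each request 0.5 GPU, Ray places both on the same physical device, so their local ranks (0 and 1) no longer correspond to valid CUDA device indices, while ray.get_gpu_ids() reports the shared device for each worker.

import ray

# Illustrative sketch, not part of this commit: two actors each request
# 0.5 GPU on a node with a single GPU, so Ray packs both onto the same
# physical device.
ray.init(num_gpus=1)


@ray.remote(num_gpus=0.5)
class Worker:
    def assigned_gpus(self):
        # GPU ids Ray assigned to this process, e.g. [0] for both actors.
        # The analogous Ray Train workers would have local ranks 0 and 1,
        # so indexing CUDA devices by local rank would wrongly pick cuda:1.
        return ray.get_gpu_ids()


workers = [Worker.remote() for _ in range(2)]
print(ray.get([w.assigned_gpus.remote() for w in workers]))  # e.g. [[0], [0]]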


@@ -38,6 +38,32 @@ def ray_start_1_cpu_1_gpu():
     ray.shutdown()
 
 
+@pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
+def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
+    def train_fn():
+        return train.torch.get_device().index
+
+    trainer = Trainer(
+        "torch",
+        num_workers=2,
+        use_gpu=True,
+        resources_per_worker={"GPU": num_gpus_per_worker},
+    )
+    trainer.start()
+    devices = trainer.run(train_fn)
+    trainer.shutdown()
+
+    if num_gpus_per_worker == 0.5:
+        assert devices == [0, 0]
+    elif num_gpus_per_worker == 1:
+        assert devices == [0, 1]
+    else:
+        raise RuntimeError(
+            "New parameter for this test has been added without checking that the "
+            "correct devices have been returned."
+        )
+
+
 def test_torch_prepare_model(ray_start_4_cpus_2_gpus):
     """Tests if ``prepare_model`` correctly wraps in DDP."""


@@ -241,8 +241,14 @@ class TorchAccelerator(Accelerator):
     def get_device(self) -> torch.device:
         """Gets the correct torch device to use for training."""
         if torch.cuda.is_available():
-            rank = train.local_rank()
-            device = torch.device(f"cuda:{rank}")
+            gpu_ids = ray.get_gpu_ids()
+            if len(gpu_ids) > 0:
+                device_id = gpu_ids[0]
+            else:
+                # If called on the driver or outside of Ray Train, return the
+                # 0th device.
+                device_id = 0
+            device = torch.device(f"cuda:{device_id}")
         else:
             device = torch.device("cpu")
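
For context, a minimal usage sketch assuming the legacy ray.train Trainer API exercised by the test above (the name train_func and the resource values are illustrative, not part of this commit): each worker calls train.torch.get_device() to find its assigned CUDA device and moves its model there; with this fix, two workers sharing one GPU both receive cuda:0.

import torch
from ray import train
from ray.train import Trainer


def train_func():
    # Resolve the CUDA device Ray assigned to this worker (falls back to CPU
    # when no GPU is available) and move the model onto it.
    device = train.torch.get_device()
    model = torch.nn.Linear(4, 1).to(device)
    return str(device)


# Assumes a node with 1 GPU shared by 2 workers (0.5 GPU each).
trainer = Trainer(
    "torch", num_workers=2, use_gpu=True, resources_per_worker={"GPU": 0.5}
)
trainer.start()
print(trainer.run(train_func))  # e.g. ["cuda:0", "cuda:0"]
trainer.shutdown()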