[Train] Fix train.torch.get_device() for fractional GPU or multiple GPU per worker case (#23763)
Using the local rank as the device id only works if there is exactly 1 GPU per worker. Instead, we should use ray.get_gpu_ids() to determine which GPU device to use for the worker.
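For context, a minimal sketch (not part of this commit) of the mechanism the fix relies on: Ray assigns each worker a set of GPU ids, and ray.get_gpu_ids() returns exactly those ids inside the worker. The resource numbers below are illustrative, and the exact element type returned (ints vs. strings) depends on the Ray version.

# Illustrative sketch only -- not part of this commit. Assumes a machine (or a
# Ray instance started) with a single GPU, shared by two fractional-GPU tasks.
import ray

ray.init(num_gpus=1)

@ray.remote(num_gpus=0.5)
def assigned_gpus():
    # Inside a Ray worker this returns the GPU ids Ray assigned to it,
    # e.g. [0] (or ["0"], depending on Ray version) for both tasks sharing
    # the single physical GPU.
    return ray.get_gpu_ids()

print(ray.get([assigned_gpus.remote() for _ in range(2)]))
# Expected (illustrative): [[0], [0]] -- both workers map to cuda:0, whereas
# local ranks 0 and 1 would incorrectly send the second worker to cuda:1.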
parent 615bb7a503
commit 029517a037
2 changed files with 34 additions and 2 deletions
@@ -38,6 +38,32 @@ def ray_start_1_cpu_1_gpu():
     ray.shutdown()
 
 
+@pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
+def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
+    def train_fn():
+        return train.torch.get_device().index
+
+    trainer = Trainer(
+        "torch",
+        num_workers=2,
+        use_gpu=True,
+        resources_per_worker={"GPU": num_gpus_per_worker},
+    )
+    trainer.start()
+    devices = trainer.run(train_fn)
+    trainer.shutdown()
+
+    if num_gpus_per_worker == 0.5:
+        assert devices == [0, 0]
+    elif num_gpus_per_worker == 1:
+        assert devices == [0, 1]
+    else:
+        raise RuntimeError(
+            "New parameter for this test has been added without checking that the "
+            "correct devices have been returned."
+        )
+
+
 def test_torch_prepare_model(ray_start_4_cpus_2_gpus):
     """Tests if ``prepare_model`` correctly wraps in DDP."""
 
@@ -241,8 +241,14 @@ class TorchAccelerator(Accelerator):
     def get_device(self) -> torch.device:
         """Gets the correct torch device to use for training."""
         if torch.cuda.is_available():
-            rank = train.local_rank()
-            device = torch.device(f"cuda:{rank}")
+            gpu_ids = ray.get_gpu_ids()
+            if len(gpu_ids) > 0:
+                device_id = gpu_ids[0]
+            else:
+                # If called on the driver or outside of Ray Train, return the
+                # 0th device.
+                device_id = 0
+            device = torch.device(f"cuda:{device_id}")
         else:
             device = torch.device("cpu")
 
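For reference, a hypothetical worker-side sketch (not from this commit) of how a training function would consume the device returned by train.torch.get_device(); the model, shapes, and data below are placeholders.

# Hypothetical usage sketch -- model, shapes, and data are placeholders.
import torch
from ray import train

def train_fn():
    device = train.torch.get_device()  # cuda:<assigned id>, or cpu if no GPU
    model = torch.nn.Linear(4, 1).to(device)
    batch = torch.randn(8, 4, device=device)
    loss = model(batch).sum()
    return loss.item()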