[sgd] support operator.device (#12056)
parent 380df89069
commit f10cef93c7

2 changed files with 17 additions and 13 deletions
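In effect, the runner now constructs a single torch.device and hands it to the operator as device=..., replacing the old device_ids plumbing; the operator's device_ids property is derived from that device on demand (see the hunks below). As a quick illustration of what this enables on the user side, here is a minimal sketch, not part of this commit, of a custom operator relying on the device property; MyOperator and its batch layout are hypothetical, and the import path follows the Ray SGD API of this period.

import torch
from ray.util.sgd.torch import TrainingOperator


class MyOperator(TrainingOperator):
    """Hypothetical operator illustrating the device property."""

    def train_batch(self, batch, batch_info):
        features, target = batch
        # self.device is the exact torch.device assigned to this worker
        # (e.g. cuda:1 on a multi-GPU host, or cpu), so tensors can be
        # moved without any manual device bookkeeping.
        features = features.to(self.device)
        target = target.to(self.device)
        return super(MyOperator, self).train_batch((features, target),
                                                   batch_info)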
@@ -74,16 +74,16 @@ class DistributedTorchRunner(TorchRunner):
         This helps avoid timeouts due to creator functions (perhaps
         downloading data or models).
         """
-        device_ids = None
+        device = torch.device("cpu")
         if self.use_gpu and torch.cuda.is_available():
-            device_ids = self.get_device_ids()
+            device = self.get_device()

         self.training_operator = self.training_operator_cls(
             self.config,
             world_rank=self.world_rank,
             local_rank=self.local_rank,
             is_distributed=True,
-            device_ids=device_ids,
+            device=device,
             use_gpu=self.use_gpu,
             use_fp16=self.use_fp16,
             use_tqdm=self.use_tqdm,
@@ -91,9 +91,9 @@ class DistributedTorchRunner(TorchRunner):
             add_dist_sampler=self.add_dist_sampler,
             scheduler_step_freq=self.scheduler_step_freq)

-    def get_device_ids(self):
+    def get_device(self):
         """Needed for SyncBatchNorm, which needs 1 GPU per process."""
-        return [0]
+        return torch.device("cuda:0")

     def train_epoch(self,
                     num_steps=None,
@@ -296,8 +296,9 @@ class LocalDistributedRunner(DistributedTorchRunner):
             logger.error("Failed to set local CUDA device.")
             raise

-    def get_device_ids(self):
-        return [int(self.local_cuda_device)]
+    def get_device(self):
+        device_str = "cuda:" + self.local_cuda_device
+        return torch.device(device_str)

     def shutdown(self, cleanup=True):
         super(LocalDistributedRunner, self).shutdown()
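The hunks above rewire the runners: each one now reports a concrete torch.device instead of a list of GPU IDs. For reference, a self-contained sketch of the construction LocalDistributedRunner.get_device performs; the local_cuda_device value here is a hypothetical stand-in for the string ordinal the runner stores.

import torch

local_cuda_device = "1"  # hypothetical ordinal, as the runner stores it
device = torch.device("cuda:" + local_cuda_device)
# Building the device object does not require CUDA to be available.
assert device.type == "cuda"
assert device.index == 1

The remaining hunks adjust TrainingOperator itself: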
@@ -121,7 +121,7 @@ class TrainingOperator:
                  world_rank,
                  local_rank,
                  is_distributed=False,
-                 device_ids=None,
+                 device=None,
                  use_gpu=False,
                  use_fp16=False,
                  use_tqdm=False,
@@ -134,9 +134,8 @@ class TrainingOperator:
         self._config = config
         self._is_distributed = is_distributed
         self._use_fp16 = use_fp16
-        self._device_ids = device_ids
+        self._device = device
         self._use_gpu = use_gpu and torch.cuda.is_available()
-        self._device = torch.device("cuda" if self._use_gpu else "cpu")
         if tqdm is None and use_tqdm:
             raise ValueError("tqdm must be installed to use tqdm in training.")
         self._use_tqdm = use_tqdm
@@ -874,7 +873,8 @@ class TrainingOperator:

     @property
     def device(self):
-        """torch.device: The appropriate torch device, at your convenience."""
+        """torch.device: The appropriate torch device, at your
+        convenience."""
         return self._device

     @property
@@ -909,11 +909,14 @@ class TrainingOperator:

     @property
     def device_ids(self):
-        """List[int]: Device IDs for the model.
+        """Optional[List[int]]: Device IDs for the model.

         This is useful for using batch norm with DistributedDataParallel.
+        Not applicable if not using GPU.
         """
-        return self._device_ids
+        if not self.use_gpu:
+            return None
+        return [self.device.index]

     @property
     def scheduler_step_freq(self):
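The reworked device_ids property no longer stores a list; it derives one from the device, returning None on CPU. A standalone, hedged restatement of that contract, with a toy Operator class standing in for the real TrainingOperator:

from typing import List, Optional

import torch


class Operator:
    """Toy stand-in for TrainingOperator, for illustration only."""

    def __init__(self, device, use_gpu):
        self._device = device
        self._use_gpu = use_gpu

    @property
    def device_ids(self) -> Optional[List[int]]:
        if not self._use_gpu:
            return None  # not applicable on CPU
        # A CUDA torch.device carries its ordinal in .index, so this
        # reproduces the old hard-coded [0] / [int(local_cuda_device)].
        return [self._device.index]


assert Operator(torch.device("cpu"), use_gpu=False).device_ids is None
assert Operator(torch.device("cuda:0"), use_gpu=True).device_ids == [0]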