Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00
[sgd] support operator.device (#12056)

commit f10cef93c7 (parent 380df89069)
2 changed files with 17 additions and 13 deletions
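
For orientation before the diff: the change threads a single torch.device from the SGD runners into TrainingOperator, replacing the old device_ids list. A minimal sketch of what that looks like from user code, using only the operator.device property that appears in the diff (the helper function here is hypothetical, not part of this commit):

import torch


def move_batch_to_operator_device(operator, batch):
    # operator.device is a torch.device such as torch.device("cuda:0") or
    # torch.device("cpu"); after this commit it is supplied by the runner
    # through the new device= argument instead of being derived internally.
    features, target = batch
    return features.to(operator.device), target.to(operator.device)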
@@ -74,16 +74,16 @@ class DistributedTorchRunner(TorchRunner):
         This helps avoid timeouts due to creator functions (perhaps
         downloading data or models).
         """
-        device_ids = None
+        device = torch.device("cpu")
         if self.use_gpu and torch.cuda.is_available():
-            device_ids = self.get_device_ids()
+            device = self.get_device()

         self.training_operator = self.training_operator_cls(
             self.config,
             world_rank=self.world_rank,
             local_rank=self.local_rank,
             is_distributed=True,
-            device_ids=device_ids,
+            device=device,
             use_gpu=self.use_gpu,
             use_fp16=self.use_fp16,
             use_tqdm=self.use_tqdm,
@@ -91,9 +91,9 @@ class DistributedTorchRunner(TorchRunner):
             add_dist_sampler=self.add_dist_sampler,
             scheduler_step_freq=self.scheduler_step_freq)

-    def get_device_ids(self):
+    def get_device(self):
         """Needed for SyncBatchNorm, which needs 1 GPU per process."""
-        return [0]
+        return torch.device("cuda:0")

     def train_epoch(self,
                     num_steps=None,
@@ -296,8 +296,9 @@ class LocalDistributedRunner(DistributedTorchRunner):
             logger.error("Failed to set local CUDA device.")
             raise

-    def get_device_ids(self):
-        return [int(self.local_cuda_device)]
+    def get_device(self):
+        device_str = "cuda:" + self.local_cuda_device
+        return torch.device(device_str)

     def shutdown(self, cleanup=True):
         super(LocalDistributedRunner, self).shutdown()
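
The get_device docstring above mentions SyncBatchNorm, which needs one GPU per process. As a rough illustration (standard PyTorch calls, not code from this commit) of how a single per-process torch.device still yields the one-element device_ids list that DistributedDataParallel expects:

import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_model(model, device):
    # Assumes torch.distributed has already been initialized by the runner.
    # One GPU per process: move the model to the per-process device, convert
    # BatchNorm layers to SyncBatchNorm, and derive the single-element
    # device_ids list from the torch.device.
    model = model.to(device)
    if device.type == "cuda":
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        return DistributedDataParallel(model, device_ids=[device.index])
    return DistributedDataParallel(model)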
@@ -121,7 +121,7 @@ class TrainingOperator:
                  world_rank,
                  local_rank,
                  is_distributed=False,
-                 device_ids=None,
+                 device=None,
                  use_gpu=False,
                  use_fp16=False,
                  use_tqdm=False,
@@ -134,9 +134,8 @@ class TrainingOperator:
         self._config = config
         self._is_distributed = is_distributed
         self._use_fp16 = use_fp16
-        self._device_ids = device_ids
+        self._device = device
         self._use_gpu = use_gpu and torch.cuda.is_available()
-        self._device = torch.device("cuda" if self._use_gpu else "cpu")
         if tqdm is None and use_tqdm:
             raise ValueError("tqdm must be installed to use tqdm in training.")
         self._use_tqdm = use_tqdm
@@ -874,7 +873,8 @@ class TrainingOperator:

     @property
     def device(self):
-        """torch.device: The appropriate torch device, at your convenience."""
+        """torch.device: The appropriate torch device, at your
+        convenience."""
         return self._device

     @property
@@ -909,11 +909,14 @@ class TrainingOperator:

     @property
     def device_ids(self):
-        """List[int]: Device IDs for the model.
+        """Optional[List[int]]: Device IDs for the model.

         This is useful for using batch norm with DistributedDataParallel.
+        Not applicable if not using GPU.
         """
-        return self._device_ids
+        if not self.use_gpu:
+            return None
+        return [self.device.index]

     @property
     def scheduler_step_freq(self):
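
The final hunk turns device_ids from a stored list into a value derived from the device. A self-contained restatement of that property's new logic, written as a standalone function so it can be checked without the class (assumed equivalent, not copied verbatim from the final file):

import torch


def device_ids_from(device, use_gpu):
    # Mirrors the updated TrainingOperator.device_ids property: None when not
    # using GPU, otherwise a one-element list holding the device index.
    if not use_gpu:
        return None
    return [device.index]


assert device_ids_from(torch.device("cuda:0"), use_gpu=True) == [0]
assert device_ids_from(torch.device("cpu"), use_gpu=False) is None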