[sgd] support operator.device (#12056)

Amog Kamsetty 2020-11-16 21:44:27 -08:00 committed by GitHub
parent 380df89069
commit f10cef93c7
2 changed files with 17 additions and 13 deletions


@@ -74,16 +74,16 @@ class DistributedTorchRunner(TorchRunner):
         This helps avoid timeouts due to creator functions (perhaps
         downloading data or models).
         """
-        device_ids = None
+        device = torch.device("cpu")
         if self.use_gpu and torch.cuda.is_available():
-            device_ids = self.get_device_ids()
+            device = self.get_device()
         self.training_operator = self.training_operator_cls(
             self.config,
             world_rank=self.world_rank,
             local_rank=self.local_rank,
             is_distributed=True,
-            device_ids=device_ids,
+            device=device,
             use_gpu=self.use_gpu,
             use_fp16=self.use_fp16,
             use_tqdm=self.use_tqdm,
@@ -91,9 +91,9 @@ class DistributedTorchRunner(TorchRunner):
             add_dist_sampler=self.add_dist_sampler,
             scheduler_step_freq=self.scheduler_step_freq)

-    def get_device_ids(self):
+    def get_device(self):
         """Needed for SyncBatchNorm, which needs 1 GPU per process."""
-        return [0]
+        return torch.device("cuda:0")

     def train_epoch(self,
                     num_steps=None,
@@ -296,8 +296,9 @@ class LocalDistributedRunner(DistributedTorchRunner):
             logger.error("Failed to set local CUDA device.")
             raise

-    def get_device_ids(self):
-        return [int(self.local_cuda_device)]
+    def get_device(self):
+        device_str = "cuda:" + self.local_cuda_device
+        return torch.device(device_str)

     def shutdown(self, cleanup=True):
         super(LocalDistributedRunner, self).shutdown()
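For context, the runner now hands the operator a single torch.device instead of a list of device IDs. Below is a minimal sketch (not part of this commit) of how such a device value might be consumed when placing a model for distributed training; the place_model helper is hypothetical and assumes the process group has already been initialized:

from torch.nn.parallel import DistributedDataParallel


def place_model(model, device):
    # Move parameters onto the device chosen by the runner, e.g.
    # torch.device("cuda:0") on a GPU worker or torch.device("cpu").
    model = model.to(device)
    if device.type == "cuda":
        # DistributedDataParallel wants the local GPU index when running
        # on GPU; torch.device("cuda:0").index == 0.
        return DistributedDataParallel(model, device_ids=[device.index])
    return DistributedDataParallel(model)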


@@ -121,7 +121,7 @@ class TrainingOperator:
                  world_rank,
                  local_rank,
                  is_distributed=False,
-                 device_ids=None,
+                 device=None,
                  use_gpu=False,
                  use_fp16=False,
                  use_tqdm=False,
@@ -134,9 +134,8 @@ class TrainingOperator:
         self._config = config
         self._is_distributed = is_distributed
         self._use_fp16 = use_fp16
-        self._device_ids = device_ids
+        self._device = device
         self._use_gpu = use_gpu and torch.cuda.is_available()
-        self._device = torch.device("cuda" if self._use_gpu else "cpu")
         if tqdm is None and use_tqdm:
             raise ValueError("tqdm must be installed to use tqdm in training.")
         self._use_tqdm = use_tqdm
@@ -874,7 +873,8 @@ class TrainingOperator:
     @property
     def device(self):
-        """torch.device: The appropriate torch device, at your convenience."""
+        """torch.device: The appropriate torch device, at your
+        convenience."""
         return self._device

     @property
@@ -909,11 +909,14 @@ class TrainingOperator:
     @property
     def device_ids(self):
-        """List[int]: Device IDs for the model.
+        """Optional[List[int]]: Device IDs for the model.

         This is useful for using batch norm with DistributedDataParallel.
         Not applicable if not using GPU.
         """
-        return self._device_ids
+        if not self.use_gpu:
+            return None
+        return [self.device.index]

     @property
     def scheduler_step_freq(self):
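On the operator side, user code can now move data with self.device instead of inspecting self.device_ids. A hedged sketch of a custom operator using the property; the subclass name and the (features, target) batch layout are illustrative and assume the Ray 1.x import path:

from ray.util.sgd.torch import TrainingOperator


class MyTrainingOperator(TrainingOperator):
    def train_batch(self, batch, batch_info):
        features, target = batch
        # self.device is a torch.device provided by the runner:
        # cuda:<local index> on GPU workers, cpu otherwise.
        features = features.to(self.device)
        target = target.to(self.device)
        return super(MyTrainingOperator, self).train_batch(
            (features, target), batch_info)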