[sgd] support operator.device (#12056)

Amog Kamsetty 2020-11-16 21:44:27 -08:00 committed by GitHub
parent 380df89069
commit f10cef93c7
2 changed files with 17 additions and 13 deletions


@@ -74,16 +74,16 @@ class DistributedTorchRunner(TorchRunner):
         This helps avoid timeouts due to creator functions (perhaps
         downloading data or models).
         """
-        device_ids = None
+        device = torch.device("cpu")
         if self.use_gpu and torch.cuda.is_available():
-            device_ids = self.get_device_ids()
+            device = self.get_device()

         self.training_operator = self.training_operator_cls(
             self.config,
             world_rank=self.world_rank,
             local_rank=self.local_rank,
             is_distributed=True,
-            device_ids=device_ids,
+            device=device,
             use_gpu=self.use_gpu,
             use_fp16=self.use_fp16,
             use_tqdm=self.use_tqdm,
@@ -91,9 +91,9 @@ class DistributedTorchRunner(TorchRunner):
             add_dist_sampler=self.add_dist_sampler,
             scheduler_step_freq=self.scheduler_step_freq)

-    def get_device_ids(self):
+    def get_device(self):
         """Needed for SyncBatchNorm, which needs 1 GPU per process."""
-        return [0]
+        return torch.device("cuda:0")

     def train_epoch(self,
                     num_steps=None,
@@ -296,8 +296,9 @@ class LocalDistributedRunner(DistributedTorchRunner):
             logger.error("Failed to set local CUDA device.")
             raise

-    def get_device_ids(self):
-        return [int(self.local_cuda_device)]
+    def get_device(self):
+        device_str = "cuda:" + self.local_cuda_device
+        return torch.device(device_str)

     def shutdown(self, cleanup=True):
         super(LocalDistributedRunner, self).shutdown()
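As context for the two get_device() overrides above: constructing a torch.device from a device string does not require CUDA to be available, and its .index attribute carries the integer that the old get_device_ids() hooks returned as a one-element list. A small illustrative snippet (the "1" value is hypothetical; the runner stores local_cuda_device as a string, which is why the diff concatenates rather than converts to int):

import torch

local_cuda_device = "1"  # hypothetical value, mirroring the runner's string field
device = torch.device("cuda:" + local_cuda_device)

assert device.type == "cuda"
assert device.index == 1  # the integer the old get_device_ids() returned as [1]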


@@ -121,7 +121,7 @@ class TrainingOperator:
                  world_rank,
                  local_rank,
                  is_distributed=False,
-                 device_ids=None,
+                 device=None,
                  use_gpu=False,
                  use_fp16=False,
                  use_tqdm=False,
@@ -134,9 +134,8 @@ class TrainingOperator:
         self._config = config
         self._is_distributed = is_distributed
         self._use_fp16 = use_fp16
-        self._device_ids = device_ids
+        self._device = device
         self._use_gpu = use_gpu and torch.cuda.is_available()
-        self._device = torch.device("cuda" if self._use_gpu else "cpu")
         if tqdm is None and use_tqdm:
             raise ValueError("tqdm must be installed to use tqdm in training.")
         self._use_tqdm = use_tqdm
@@ -874,7 +873,8 @@ class TrainingOperator:
     @property
     def device(self):
-        """torch.device: The appropriate torch device, at your convenience."""
+        """torch.device: The appropriate torch device, at your
+        convenience."""
         return self._device

     @property
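As a usage note for the device property: operator code can treat self.device as the single source of truth for placement, whatever device object the runner handed in. A minimal, self-contained sketch of that pattern in plain PyTorch (the function and variable names here are illustrative, not RaySGD APIs):

import torch
import torch.nn as nn

def train_step(model, batch, optimizer, device):
    # Mirror what operator code would do with self.device: move the inputs
    # and targets to whichever device the runner selected.
    features, target = (t.to(device) for t in batch)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(features), target)
    loss.backward()
    optimizer.step()
    return loss.item()

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch = (torch.randn(8, 4), torch.randn(8, 1))
print(train_step(model, batch, optimizer, device))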
@@ -909,11 +909,14 @@ class TrainingOperator:
     @property
     def device_ids(self):
-        """List[int]: Device IDs for the model.
+        """Optional[List[int]]: Device IDs for the model.

         This is useful for using batch norm with DistributedDataParallel.
+        Not applicable if not using GPU.
         """
-        return self._device_ids
+        if not self.use_gpu:
+            return None
+        return [self.device.index]

     @property
     def scheduler_step_freq(self):
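Finally, on the reworked device_ids property: the ID list is now derived from the device's index instead of being stored separately, and it is consumed the way DistributedDataParallel expects for SyncBatchNorm, i.e. exactly one GPU per process. A hedged sketch using standard PyTorch APIs (the DDP lines are commented out because they need an initialized process group and a visible GPU):

import torch

# One GPU per process: derive the DDP device_ids list from the device's index,
# mirroring the property above.
device = torch.device("cuda:0")
device_ids = [device.index] if device.type == "cuda" else None
assert device_ids == [0]

# Typical consumption (needs an initialized process group and a visible GPU):
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)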