From f10cef93c755e76a122912504ebf5b564356bbd3 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Mon, 16 Nov 2020 21:44:27 -0800
Subject: [PATCH] [sgd] support operator.device (#12056)

---
 .../util/sgd/torch/distributed_torch_runner.py | 15 ++++++++-------
 python/ray/util/sgd/torch/training_operator.py | 15 +++++++++------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py
index e09045072..309b6a36f 100644
--- a/python/ray/util/sgd/torch/distributed_torch_runner.py
+++ b/python/ray/util/sgd/torch/distributed_torch_runner.py
@@ -74,16 +74,16 @@ class DistributedTorchRunner(TorchRunner):
         This helps avoid timeouts due to creator functions (perhaps
         downloading data or models).
         """
-        device_ids = None
+        device = torch.device("cpu")
         if self.use_gpu and torch.cuda.is_available():
-            device_ids = self.get_device_ids()
+            device = self.get_device()
 
         self.training_operator = self.training_operator_cls(
             self.config,
             world_rank=self.world_rank,
             local_rank=self.local_rank,
             is_distributed=True,
-            device_ids=device_ids,
+            device=device,
             use_gpu=self.use_gpu,
             use_fp16=self.use_fp16,
             use_tqdm=self.use_tqdm,
@@ -91,9 +91,9 @@ class DistributedTorchRunner(TorchRunner):
             add_dist_sampler=self.add_dist_sampler,
             scheduler_step_freq=self.scheduler_step_freq)
 
-    def get_device_ids(self):
+    def get_device(self):
         """Needed for SyncBatchNorm, which needs 1 GPU per process."""
-        return [0]
+        return torch.device("cuda:0")
 
     def train_epoch(self,
                     num_steps=None,
@@ -296,8 +296,9 @@ class LocalDistributedRunner(DistributedTorchRunner):
             logger.error("Failed to set local CUDA device.")
             raise
 
-    def get_device_ids(self):
-        return [int(self.local_cuda_device)]
+    def get_device(self):
+        device_str = "cuda:" + self.local_cuda_device
+        return torch.device(device_str)
 
     def shutdown(self, cleanup=True):
         super(LocalDistributedRunner, self).shutdown()
diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py
index b8ca0e785..595d7adf0 100644
--- a/python/ray/util/sgd/torch/training_operator.py
+++ b/python/ray/util/sgd/torch/training_operator.py
@@ -121,7 +121,7 @@ class TrainingOperator:
                  world_rank,
                  local_rank,
                  is_distributed=False,
-                 device_ids=None,
+                 device=None,
                  use_gpu=False,
                  use_fp16=False,
                  use_tqdm=False,
@@ -134,9 +134,8 @@
         self._config = config
         self._is_distributed = is_distributed
         self._use_fp16 = use_fp16
-        self._device_ids = device_ids
+        self._device = device
         self._use_gpu = use_gpu and torch.cuda.is_available()
-        self._device = torch.device("cuda" if self._use_gpu else "cpu")
         if tqdm is None and use_tqdm:
             raise ValueError("tqdm must be installed to use tqdm in training.")
         self._use_tqdm = use_tqdm
@@ -874,7 +873,8 @@
 
     @property
     def device(self):
-        """torch.device: The appropriate torch device, at your convenience."""
+        """torch.device: The appropriate torch device, at your
+        convenience."""
         return self._device
 
     @property
@@ -909,11 +909,14 @@
 
     @property
    def device_ids(self):
-        """List[int]: Device IDs for the model.
+        """Optional[List[int]]: Device IDs for the model.
 
         This is useful for using batch norm with DistributedDataParallel.
+        Not applicable if not using GPU.
         """
-        return self._device_ids
+        if not self.use_gpu:
+            return None
+        return [self.device.index]
 
     @property
     def scheduler_step_freq(self):
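
Usage note (not part of the patch): a minimal sketch of how an operator
subclass might use the new `device` property instead of indexing
`device_ids` after this change. The subclass name, its `setup` hook, and
the model/optimizer here are illustrative assumptions, not code from this PR.

    import torch
    import torch.nn as nn
    from ray.util.sgd.torch import TrainingOperator

    class MyOperator(TrainingOperator):
        def setup(self, config):
            # The runner now hands each worker a concrete torch.device:
            # torch.device("cpu") when GPUs are unused, or e.g. "cuda:0"
            # (the worker's visible GPU) otherwise.
            model = nn.Linear(4, 2).to(self.device)
            optimizer = torch.optim.SGD(
                model.parameters(), lr=config.get("lr", 0.01))
            # `self.device_ids` is still exposed (derived from `device`, or
            # None on CPU) for APIs such as DistributedDataParallel that
            # expect a list of integer GPU ids.
            ...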