diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py
index 26f189341..319aca15b 100644
--- a/doc/source/custom_directives.py
+++ b/doc/source/custom_directives.py
@@ -138,9 +138,11 @@ MOCK_MODULES = [
     "tensorflow.python.client",
     "tensorflow.python.util",
     "torch",
+    "torch.cuda.amp",
     "torch.distributed",
     "torch.nn",
     "torch.nn.parallel",
+    "torch.optim",
     "torch.profiler",
     "torch.utils.data",
     "torch.utils.data.distributed",
diff --git a/doc/source/train/api.rst b/doc/source/train/api.rst
index f56b85b3b..d7bad754e 100644
--- a/doc/source/train/api.rst
+++ b/doc/source/train/api.rst
@@ -200,6 +200,18 @@ train.torch.prepare_data_loader
 .. autofunction:: ray.train.torch.prepare_data_loader
 
 
+train.torch.prepare_optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.prepare_optimizer
+
+
+train.torch.backward
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.backward
+
+
 train.torch.get_device
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -212,6 +224,11 @@ train.torch.enable_reproducibility
 
 .. _train-api-torch-worker-profiler:
 
+train.torch.accelerate
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.accelerate
+
 train.torch.TorchWorkerProfiler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/doc/source/train/user_guide.rst b/doc/source/train/user_guide.rst
index 1133421a4..12712469b 100644
--- a/doc/source/train/user_guide.rst
+++ b/doc/source/train/user_guide.rst
@@ -977,6 +977,67 @@ the disk that from which your script was executed from.
     # View the PyTorch Profiler traces.
     $ open http://localhost:6006/#pytorch_profiler
 
+.. _torch-amp:
+
+Automatic Mixed Precision
+-------------------------
+
+Automatic mixed precision (AMP) lets you train your models faster by using a lower
+precision datatype for operations like linear layers and convolutions.
+
+.. tabbed:: PyTorch
+
+    You can train your Torch model with AMP by:
+
+    1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
+    2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
+    3. Replacing your backward call with ``train.torch.backward``.
+
+    .. code-block:: diff
+
+         def train_func():
+        +    train.torch.accelerate(amp=True)
+
+             model = NeuralNetwork()
+             model = train.torch.prepare_model(model)
+
+             data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
+             data_loader = train.torch.prepare_data_loader(data_loader)
+
+             optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        +    optimizer = train.torch.prepare_optimizer(optimizer)
+
+             model.train()
+             for epoch in range(90):
+                 for images, targets in data_loader:
+                     optimizer.zero_grad()
+
+                     outputs = model(images)
+                     loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+        -            loss.backward()
+        +            train.torch.backward(loss)
+                     optimizer.step()
+             ...
+
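+    The rest of the training function does not need to change. For example,
+    assuming a machine with two GPUs, you could run it with a ``Trainer`` just
+    like any other Ray Train training function:
+
+    .. code-block:: python
+
+        from ray.train import Trainer
+
+        # Each of the two GPU workers runs `train_func` with AMP enabled.
+        trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
+        trainer.start()
+        trainer.run(train_func)
+        trainer.shutdown()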
+
+.. note:: The performance of AMP varies based on GPU architecture, model type,
+          and data shape. For certain workflows, AMP may perform worse than
+          full-precision training.
+
 .. _train-reproducibility:
 
 Reproducibility
diff --git a/python/ray/train/tests/test_gpu.py b/python/ray/train/tests/test_gpu.py
index 6d095332b..311cd58a4 100644
--- a/python/ray/train/tests/test_gpu.py
+++ b/python/ray/train/tests/test_gpu.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from timeit import default_timer as timer
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
@@ -127,6 +128,65 @@ def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
     assert result1 == result2
 
 
+def test_torch_amp(ray_start_4_cpus_2_gpus):
+    def train_func(config):
+        train.torch.accelerate(amp=config["amp"])
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        dataset_length = 1000
+        dataset = torch.utils.data.TensorDataset(
+            torch.randn(dataset_length, 3, 224, 224),
+            torch.randint(low=0, high=1000, size=(dataset_length,)),
+        )
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
+        dataloader = train.torch.prepare_data_loader(dataloader)
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        optimizer = train.torch.prepare_optimizer(optimizer)
+
+        model.train()
+        for epoch in range(1):
+            for images, targets in dataloader:
+                optimizer.zero_grad()
+
+                outputs = model(images)
+                loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+                train.torch.backward(loss)
+                optimizer.step()
+
+    def latency(amp: bool) -> float:
+        trainer = Trainer("torch", num_workers=2, use_gpu=True)
+        trainer.start()
+        start_time = timer()
+        trainer.run(train_func, {"amp": amp})
+        end_time = timer()
+        trainer.shutdown()
+        return end_time - start_time
+
+    # Training should be at least 5% faster with AMP.
+    assert 1.05 * latency(amp=True) < latency(amp=False)
+
+
+def test_checkpoint_torch_model_with_amp(ray_start_4_cpus_2_gpus):
+    """Test that model with AMP is serializable."""
+
+    def train_func():
+        train.torch.accelerate(amp=True)
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        train.save_checkpoint(model=model)
+
+    trainer = Trainer("torch", num_workers=1, use_gpu=True)
+    trainer.start()
+    trainer.run(train_func)
+    trainer.shutdown()
+
+
 def test_torch_auto_gpu_to_cpu(ray_start_4_cpus_2_gpus):
     """Tests if GPU tensors are auto converted to CPU on driver."""
 
diff --git a/python/ray/train/torch.py b/python/ray/train/torch.py
index 9c7ac1bdc..204001afa 100644
--- a/python/ray/train/torch.py
+++ b/python/ray/train/torch.py
@@ -1,9 +1,11 @@
 import tempfile
 from dataclasses import dataclass
+import functools
 import io
 import logging
 import os
 import random
+import types
 from datetime import timedelta
 from pathlib import Path
 
@@ -14,6 +16,7 @@ from ray import train
 from ray.train.accelerator import Accelerator
 from ray.train.backend import BackendConfig, Backend, EncodedData
 from ray.train.constants import PYTORCH_PROFILER_KEY
+from torch.optim import Optimizer
 from ray.train.session import get_accelerator, set_accelerator
 from ray.train.worker_group import WorkerGroup
 from ray.train.utils import get_address_and_port
@@ -21,6 +24,7 @@ from ray.util import PublicAPI
 
 import numpy as np
 import torch
+from torch.cuda.amp import autocast, GradScaler
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import (
@@ -40,9 +44,16 @@ logger = logging.getLogger(__name__)
 
 
 class TorchAccelerator(Accelerator):
-    """A utility that implements methods to accelerate PyTorch training."""
+    """A utility that implements methods to accelerate PyTorch training.
 
-    def __init__(self):
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+    """
+
+    def __init__(self, amp: bool = False):
+        self.amp_is_enabled = amp
+        self.scaler = GradScaler() if amp else None
         self._seed = None
 
     def prepare_model(
@@ -80,6 +91,40 @@ class TorchAccelerator(Accelerator):
         if move_to_device:
             logger.info(f"Moving model to device: {device}")
             model = model.to(device)
+
+        def wrap_forward(forward):
+            @functools.wraps(forward)
+            def wrapper(*args, **kwargs):
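+                # Run the forward pass under autocast so that eligible ops use
+                # reduced precision, then cast the output back to float32 for
+                # full-precision computation (e.g. the loss) outside the model.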
+                with autocast():
+                    outputs = forward(*args, **kwargs)
+                assert isinstance(outputs, torch.Tensor)
+                return outputs.float()
+
+            return wrapper
+
+        def model_get_state(self):
+            # `__getstate__` is a special method that informs pickle which attributes
+            # to serialize. This custom implementation ensures that the wrapped forward
+            # method and custom `__getstate__` method aren't serialized.
+            state = self.__dict__.copy()
+            state["forward"] = state["_unwrapped_forward"]
+            del state["_unwrapped_forward"]
+            del state["__getstate__"]
+            return state
+
+        if self.amp_is_enabled:
+            # Pickle cannot serialize the wrapped forward method. As a workaround,
+            # define a custom `__getstate__` method that unwraps the forward method.
+            model._unwrapped_forward = model.forward
+            model.forward = wrap_forward(model.forward)
+            # `__getstate__` must be a bound method rather than a callable attribute.
+            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
+            assert not hasattr(model, "__getstate__")
+            model.__getstate__ = types.MethodType(model_get_state, model)
+
         if wrap_ddp and train.world_size() > 1:
             logger.info("Wrapping provided model in DDP.")
             if torch.cuda.is_available():
@@ -206,6 +251,31 @@ class TorchAccelerator(Accelerator):
 
         return device
 
+    def prepare_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Wraps optimizer to support automatic mixed precision.
+
+        Args:
+            optimizer (torch.optim.Optimizer): The optimizer to prepare.
+
+        Returns:
+            A wrapped optimizer.
+        """
+        return _WrappedOptimizer(optimizer, scaler=self.scaler)
+
+    def backward(self, tensor: torch.Tensor) -> None:
+        """Computes the gradient of the specified tensor w.r.t. graph leaves.
+
+        Args:
+            tensor (torch.Tensor): Tensor of which the derivative will be computed.
+        """
+        if self.amp_is_enabled:
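+            # Scale the loss before computing gradients so that small gradient
+            # values do not underflow in float16; `GradScaler.step` unscales
+            # them again before the optimizer update.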
+ """ + if self.amp_is_enabled: + self.scaler.scale(tensor).backward() + else: + tensor.backward() + def enable_reproducibility(self, seed: int = 0) -> None: """Limits sources of nondeterministic behavior.""" self._seed = seed @@ -451,6 +515,55 @@ class _WrappedDataLoader(DataLoader): return next_batch +class _WrappedOptimizer(Optimizer): + def __init__(self, optimizer: Optimizer, scaler: Optional[GradScaler] = None): + self.optimizer = optimizer + self.scaler = scaler + + @property + def state(self): + return self.optimizer.state + + @state.setter + def state(self, state): + self.optimizer.state = state + + @property + def param_groups(self): + return self.optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self.optimizer.param_groups = param_groups + + @property + def defaults(self): + return self.optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self.optimizer.defaults = defaults + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + + def state_dict(self): + return self.optimizer.state_dict() + + def zero_grad(self): + self.optimizer.zero_grad() + + def step(self, closure=None): + if self.scaler is not None: + self.scaler.step(self.optimizer, closure) + self.scaler.update() + else: + self.optimizer.step(closure) + + @PublicAPI(stability="beta") def get_device() -> torch.device: """Gets the correct torch device to use for training.""" @@ -523,10 +636,18 @@ def prepare_data_loader( @PublicAPI(stability="beta") -def accelerate() -> None: - """Enables training optimizations.""" +def accelerate(amp: bool = False) -> None: + """Enables training optimizations. + + Arguments: + amp (bool): If true, perform training with automatic mixed precision. + Otherwise, use full precision. + + .. warning:: ``train.torch.accelerate`` cannot be called more than once, and it + must be called before any other ``train.torch`` utility function. + """ try: - set_accelerator(TorchAccelerator()) + set_accelerator(TorchAccelerator(amp=amp)) except RuntimeError: raise RuntimeError( "An accelerator has already been set. Make sure " @@ -536,6 +657,28 @@ def accelerate() -> None: @PublicAPI(stability="beta") +def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer: + """Wraps optimizer to support automatic mixed precision. + + Args: + optimizer (torch.optim.Optimizer): The DataLoader to prepare. + + Returns: + A wrapped optimizer. + """ + return get_accelerator(TorchAccelerator).prepare_optimizer(optimizer) + + +@PublicAPI(stability="beta") +def backward(tensor: torch.Tensor) -> None: + """Computes the gradient of the specified tensor w.r.t. graph leaves. + + Args: + tensor (torch.Tensor): Tensor of which the derivative will be computed. + """ + get_accelerator(TorchAccelerator).backward(tensor) + + def enable_reproducibility(seed: int = 0) -> None: """Limits sources of nondeterministic behavior.