[Train] Add support for automatic mixed precision (#22227)
Closes #20643
Co-authored-by: Ubuntu <ubuntu@ip-172-31-58-19.us-west-2.compute.internal>
parent 77090144a2
commit 83986a4d83
5 changed files with 275 additions and 5 deletions
@@ -138,9 +138,11 @@ MOCK_MODULES = [
    "tensorflow.python.client",
    "tensorflow.python.util",
    "torch",
    "torch.cuda.amp",
    "torch.distributed",
    "torch.nn",
    "torch.nn.parallel",
    "torch.optim",
    "torch.profiler",
    "torch.utils.data",
    "torch.utils.data.distributed",
@@ -200,6 +200,18 @@ train.torch.prepare_data_loader

.. autofunction:: ray.train.torch.prepare_data_loader

train.torch.prepare_optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_optimizer


train.torch.backward
~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.backward


train.torch.get_device
~~~~~~~~~~~~~~~~~~~~~~
@@ -212,6 +224,11 @@ train.torch.enable_reproducibility

train.torch.accelerate
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.accelerate

.. _train-api-torch-worker-profiler:

train.torch.TorchWorkerProfiler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -977,6 +977,54 @@ the disk that from which your script was executed from.

    # View the PyTorch Profiler traces.
    $ open http://localhost:6006/#pytorch_profiler

.. _torch-amp:

Automatic Mixed Precision
-------------------------

Automatic mixed precision (AMP) lets you train your models faster by using a lower
precision datatype for operations like linear layers and convolutions.

.. tabbed:: PyTorch

    You can train your Torch model with AMP by:

    1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
    2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
    3. Replacing your backward call with ``train.torch.backward``.

    .. code-block:: diff

         def train_func():
        +    train.torch.accelerate(amp=True)

             model = NeuralNetwork()
             model = train.torch.prepare_model(model)

             data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
             data_loader = train.torch.prepare_data_loader(data_loader)

             optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
        +    optimizer = train.torch.prepare_optimizer(optimizer)

             model.train()
             for epoch in range(90):
                 for images, targets in data_loader:
                     optimizer.zero_grad()

                     outputs = model(images)
                     loss = torch.nn.functional.cross_entropy(outputs, targets)

        -            loss.backward()
        +            train.torch.backward(loss)
                     optimizer.step()
             ...

.. note:: The performance of AMP varies based on GPU architecture, model type,
   and data shape. For certain workflows, AMP may perform worse than
   full-precision training.

.. _train-reproducibility:

Reproducibility
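For readers trying this out, the three steps above combine into a complete script along the following lines. This is a minimal sketch, not part of the diff: it assumes the Ray 1.x ``Trainer`` API used elsewhere in this commit, a CUDA-capable cluster (autocast and GradScaler target GPU training), and a toy model with synthetic data as placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset

from ray import train
from ray.train import Trainer


def train_func():
    # Step 1: enable AMP before calling any other train.torch utility.
    train.torch.accelerate(amp=True)

    model = torch.nn.Linear(10, 2)  # placeholder model
    model = train.torch.prepare_model(model)

    dataset = TensorDataset(
        torch.randn(256, 10), torch.randint(0, 2, (256,))  # synthetic data
    )
    data_loader = DataLoader(dataset, batch_size=32)
    data_loader = train.torch.prepare_data_loader(data_loader)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    # Step 2: wrap the optimizer so its step() runs through the GradScaler.
    optimizer = train.torch.prepare_optimizer(optimizer)

    model.train()
    for _ in range(2):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            # Step 3: scaled backward pass instead of loss.backward().
            train.torch.backward(loss)
            optimizer.step()


trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

With ``accelerate(amp=True)`` omitted, the same script runs in full precision, which is exactly the baseline that ``test_torch_amp`` below compares against.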
@@ -1,5 +1,6 @@
import os
import pytest
from timeit import default_timer as timer

import torch
from torch.nn.parallel import DistributedDataParallel
@@ -127,6 +128,65 @@ def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):

    assert result1 == result2


def test_torch_amp(ray_start_4_cpus_2_gpus):
    def train_func(config):
        train.torch.accelerate(amp=config["amp"])

        model = torchvision.models.resnet101()
        model = train.torch.prepare_model(model)

        dataset_length = 1000
        dataset = torch.utils.data.TensorDataset(
            torch.randn(dataset_length, 3, 224, 224),
            torch.randint(low=0, high=1000, size=(dataset_length,)),
        )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
        dataloader = train.torch.prepare_data_loader(dataloader)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
        optimizer = train.torch.prepare_optimizer(optimizer)

        model.train()
        for epoch in range(1):
            for images, targets in dataloader:
                optimizer.zero_grad()

                outputs = model(images)
                loss = torch.nn.functional.cross_entropy(outputs, targets)

                train.torch.backward(loss)
                optimizer.step()

    def latency(amp: bool) -> float:
        trainer = Trainer("torch", num_workers=2, use_gpu=True)
        trainer.start()
        start_time = timer()
        trainer.run(train_func, {"amp": amp})
        end_time = timer()
        trainer.shutdown()
        return end_time - start_time

    # Training should be at least 5% faster with AMP.
    assert 1.05 * latency(amp=True) < latency(amp=False)


def test_checkpoint_torch_model_with_amp(ray_start_4_cpus_2_gpus):
    """Test that model with AMP is serializable."""

    def train_func():
        train.torch.accelerate(amp=True)

        model = torchvision.models.resnet101()
        model = train.torch.prepare_model(model)

        train.save_checkpoint(model=model)

    trainer = Trainer("torch", num_workers=1, use_gpu=True)
    trainer.start()
    trainer.run(train_func)
    trainer.shutdown()


def test_torch_auto_gpu_to_cpu(ray_start_4_cpus_2_gpus):
    """Tests if GPU tensors are auto converted to CPU on driver."""
@@ -1,9 +1,11 @@
import tempfile
from dataclasses import dataclass
import functools
import io
import logging
import os
import random
import types

from datetime import timedelta
from pathlib import Path
@@ -14,6 +16,7 @@ from ray import train
from ray.train.accelerator import Accelerator
from ray.train.backend import BackendConfig, Backend, EncodedData
from ray.train.constants import PYTORCH_PROFILER_KEY
from torch.optim import Optimizer
from ray.train.session import get_accelerator, set_accelerator
from ray.train.worker_group import WorkerGroup
from ray.train.utils import get_address_and_port
@@ -21,6 +24,7 @@ from ray.util import PublicAPI

import numpy as np
import torch
from torch.cuda.amp import autocast, GradScaler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import (
@@ -40,9 +44,16 @@ logger = logging.getLogger(__name__)


 class TorchAccelerator(Accelerator):
-    """A utility that implements methods to accelerate PyTorch training."""
+    """A utility that implements methods to accelerate PyTorch training.
+
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+    """

-    def __init__(self):
+    def __init__(self, amp: bool = False):
+        self.amp_is_enabled = amp
+        self.scaler = GradScaler() if amp else None
         self._seed = None

     def prepare_model(
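As background for the fields initialized here (not part of the diff): ``autocast`` and ``GradScaler`` are the standard ``torch.cuda.amp`` primitives that this class drives. A bare PyTorch loop using them directly looks roughly like the sketch below; the model, data, and iteration count are arbitrary placeholders, and a CUDA device is assumed.

import torch
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda")
model = torch.nn.Linear(10, 2).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scaler = GradScaler()  # same object TorchAccelerator creates when amp=True

inputs = torch.randn(32, 10, device=device)  # placeholder batch
targets = torch.randint(0, 2, (32,), device=device)

for _ in range(10):
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)  # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()  # adjusts the scale factor for the next iteration

The ``prepare_model``, ``prepare_optimizer``, and ``backward`` changes below fold the autocast context, the scaled backward call, and the ``scaler.step``/``scaler.update`` pair into the existing ``train.torch`` utilities so user code keeps its ordinary shape.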
@@ -80,6 +91,37 @@ class TorchAccelerator(Accelerator):

        if move_to_device:
            logger.info(f"Moving model to device: {device}")
            model = model.to(device)

        def wrap_forward(forward):
            @functools.wraps(forward)
            def wrapper(*args, **kwargs):
                with autocast():
                    outputs = forward(*args, **kwargs)
                assert isinstance(outputs, torch.Tensor)
                return outputs.float()

            return wrapper

        def model_get_state(self):
            # `__getstate__` is a special method that informs pickle which attributes
            # to serialize. This custom implementation ensures that the wrapped forward
            # method and custom `__getstate__` method aren't serialized.
            state = self.__dict__.copy()
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]
            del state["__getstate__"]
            return state

        if self.amp_is_enabled:
            # Pickle cannot serialize the wrapped forward method. As a workaround,
            # define a custom `__getstate__` method that unwraps the forward method.
            model._unwrapped_forward = model.forward
            model.forward = wrap_forward(model.forward)
            # `__getstate__` must be a bound method rather than a callable attribute.
            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
            assert not hasattr(model, "__getstate__")
            model.__getstate__ = types.MethodType(model_get_state, model)

        if wrap_ddp and train.world_size() > 1:
            logger.info("Wrapping provided model in DDP.")
            if torch.cuda.is_available():
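The ``__getstate__`` trick above is plain Python rather than anything Torch- or Ray-specific. A standalone sketch of the same pattern (hypothetical ``Net`` class, no Ray involved) shows why the locally defined wrapper has to be stripped before pickling:

import functools
import pickle
import types


class Net:
    def forward(self, x):
        return x + 1


def wrap_forward(forward):
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        # Stand-in for running the original forward under autocast().
        return forward(*args, **kwargs)

    return wrapper


def get_state(self):
    # Mirror of model_get_state above: restore the original forward and drop
    # the unpicklable wrapper and this bound __getstate__ from the state.
    state = self.__dict__.copy()
    state["forward"] = state["_unwrapped_forward"]
    del state["_unwrapped_forward"]
    del state["__getstate__"]
    return state


net = Net()
net._unwrapped_forward = net.forward
net.forward = wrap_forward(net.forward)  # a local closure: pickle cannot serialize it by reference
net.__getstate__ = types.MethodType(get_state, net)

restored = pickle.loads(pickle.dumps(net))  # succeeds; without __getstate__ this raises a pickling error
print(restored.forward(1))  # 2: the unwrapped forward survives the round trip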
@@ -206,6 +248,28 @@ class TorchAccelerator(Accelerator):

        return device

    def prepare_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Wraps optimizer to support automatic mixed precision.

        Args:
            optimizer (torch.optim.Optimizer): The optimizer to prepare.

        Returns:
            A wrapped optimizer.
        """
        return _WrappedOptimizer(optimizer, scaler=self.scaler)

    def backward(self, tensor: torch.Tensor) -> None:
        """Computes the gradient of the specified tensor w.r.t. graph leaves.

        Args:
            tensor (torch.Tensor): Tensor of which the derivative will be computed.
        """
        if self.amp_is_enabled:
            self.scaler.scale(tensor).backward()
        else:
            tensor.backward()

    def enable_reproducibility(self, seed: int = 0) -> None:
        """Limits sources of nondeterministic behavior."""
        self._seed = seed
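A quick aside on why ``backward`` routes through ``scaler.scale`` at all (illustration only, not from the diff): fp16 has a much narrower dynamic range than fp32, so small gradient values vanish unless the loss is scaled up first. The constant below is GradScaler's default initial scale.

import torch

small = torch.tensor(1e-8)  # a small fp32 value, e.g. an intermediate gradient
print(small.half())  # 0.0 in fp16: the value underflows
print((small * 65536.0).half())  # nonzero once scaled by 2**16, GradScaler's default init_scale

``scaler.step`` later divides the gradients by the same factor before the optimizer update, so the effective step size is unchanged.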
@@ -451,6 +515,55 @@ class _WrappedDataLoader(DataLoader):

        return next_batch


class _WrappedOptimizer(Optimizer):
    def __init__(self, optimizer: Optimizer, scaler: Optional[GradScaler] = None):
        self.optimizer = optimizer
        self.scaler = scaler

    @property
    def state(self):
        return self.optimizer.state

    @state.setter
    def state(self, state):
        self.optimizer.state = state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        self.optimizer.param_groups = param_groups

    @property
    def defaults(self):
        return self.optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optimizer.defaults = defaults

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        return self.optimizer.state_dict()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self, closure=None):
        if self.scaler is not None:
            self.scaler.step(self.optimizer, closure)
            self.scaler.update()
        else:
            self.optimizer.step(closure)


@PublicAPI(stability="beta")
def get_device() -> torch.device:
    """Gets the correct torch device to use for training."""
@@ -523,10 +636,18 @@ def prepare_data_loader(


 @PublicAPI(stability="beta")
-def accelerate() -> None:
-    """Enables training optimizations."""
+def accelerate(amp: bool = False) -> None:
+    """Enables training optimizations.
+
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+
+    .. warning:: ``train.torch.accelerate`` cannot be called more than once, and it
+        must be called before any other ``train.torch`` utility function.
+    """
     try:
-        set_accelerator(TorchAccelerator())
+        set_accelerator(TorchAccelerator(amp=amp))
     except RuntimeError:
         raise RuntimeError(
             "An accelerator has already been set. Make sure "
@@ -536,6 +657,28 @@ def accelerate() -> None:


@PublicAPI(stability="beta")
def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
    """Wraps optimizer to support automatic mixed precision.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to prepare.

    Returns:
        A wrapped optimizer.
    """
    return get_accelerator(TorchAccelerator).prepare_optimizer(optimizer)


@PublicAPI(stability="beta")
def backward(tensor: torch.Tensor) -> None:
    """Computes the gradient of the specified tensor w.r.t. graph leaves.

    Args:
        tensor (torch.Tensor): Tensor of which the derivative will be computed.
    """
    get_accelerator(TorchAccelerator).backward(tensor)


def enable_reproducibility(seed: int = 0) -> None:
    """Limits sources of nondeterministic behavior.