[Train] Add support for automatic mixed precision (#22227)

Closes #20643

Co-authored-by: Ubuntu <ubuntu@ip-172-31-58-19.us-west-2.compute.internal>
Balaji Veeramani 2022-03-16 22:53:02 -05:00 committed by GitHub
parent 77090144a2
commit 83986a4d83
5 changed files with 275 additions and 5 deletions


@@ -138,9 +138,11 @@ MOCK_MODULES = [
"tensorflow.python.client",
"tensorflow.python.util",
"torch",
"torch.cuda.amp",
"torch.distributed",
"torch.nn",
"torch.nn.parallel",
"torch.optim",
"torch.profiler",
"torch.utils.data",
"torch.utils.data.distributed",


@@ -200,6 +200,18 @@ train.torch.prepare_data_loader
.. autofunction:: ray.train.torch.prepare_data_loader
train.torch.prepare_optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ray.train.torch.prepare_optimizer
train.torch.backward
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ray.train.torch.backward
train.torch.get_device
~~~~~~~~~~~~~~~~~~~~~~
@@ -212,6 +224,11 @@ train.torch.enable_reproducibility
.. _train-api-torch-worker-profiler:
train.torch.accelerate
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ray.train.torch.accelerate
train.torch.TorchWorkerProfiler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -977,6 +977,54 @@ the disk from which your script was executed.
# View the PyTorch Profiler traces.
$ open http://localhost:6006/#pytorch_profiler
.. _torch-amp:
Automatic Mixed Precision
-------------------------
Automatic mixed precision (AMP) lets you train your models faster by using a lower
precision datatype for operations like linear layers and convolutions.
.. tabbed:: PyTorch
You can train your Torch model with AMP by:
1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
3. Replacing your backward call with ``train.torch.backward``.
.. code-block:: diff
def train_func():
+ train.torch.accelerate(amp=True)
model = NeuralNetwork()
model = train.torch.prepare_model(model)
data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
data_loader = train.torch.prepare_data_loader(data_loader)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ optimizer = train.torch.prepare_optimizer(optimizer)
model.train()
for epoch in range(90):
for images, targets in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = torch.nn.functional.cross_entropy(outputs, targets)
- loss.backward()
+ train.torch.backward(loss)
optimizer.step()
...
.. note:: The performance of AMP varies based on GPU architecture, model type,
and data shape. For certain workflows, AMP may perform worse than
full-precision training.
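For reference, the ``train.torch`` utilities above build on PyTorch's native AMP
primitives, ``torch.cuda.amp.autocast`` and ``torch.cuda.amp.GradScaler``. The
snippet below is a minimal sketch of a comparable raw PyTorch loop, not Ray Train
code; ``NeuralNetwork`` and ``my_dataset`` are the same placeholders used above.
.. code-block:: python

    import torch
    from torch.cuda.amp import GradScaler, autocast
    from torch.utils.data import DataLoader

    model = NeuralNetwork().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    scaler = GradScaler()

    model.train()
    for images, targets in DataLoader(my_dataset, batch_size=64):
        images, targets = images.cuda(), targets.cuda()
        optimizer.zero_grad()
        with autocast():
            # Eligible ops (e.g. linear layers, convolutions) run in float16.
            outputs = model(images)
            loss = torch.nn.functional.cross_entropy(outputs, targets)
        scaler.scale(loss).backward()  # Scale the loss to avoid gradient underflow.
        scaler.step(optimizer)         # Unscales gradients; skips the step on inf/NaN.
        scaler.update()                # Adjusts the scale factor for the next iteration.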
.. _train-reproducibility:
Reproducibility


@@ -1,5 +1,6 @@
import os
import pytest
from timeit import default_timer as timer
import torch
from torch.nn.parallel import DistributedDataParallel
@@ -127,6 +128,65 @@ def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
assert result1 == result2
def test_torch_amp(ray_start_4_cpus_2_gpus):
def train_func(config):
train.torch.accelerate(amp=config["amp"])
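# `accelerate` must be called before any other `train.torch` utility function.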
model = torchvision.models.resnet101()
model = train.torch.prepare_model(model)
dataset_length = 1000
dataset = torch.utils.data.TensorDataset(
torch.randn(dataset_length, 3, 224, 224),
torch.randint(low=0, high=1000, size=(dataset_length,)),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
dataloader = train.torch.prepare_data_loader(dataloader)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = train.torch.prepare_optimizer(optimizer)
model.train()
for epoch in range(1):
for images, targets in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = torch.nn.functional.cross_entropy(outputs, targets)
train.torch.backward(loss)
optimizer.step()
def latency(amp: bool) -> float:
trainer = Trainer("torch", num_workers=2, use_gpu=True)
trainer.start()
start_time = timer()
trainer.run(train_func, {"amp": amp})
end_time = timer()
trainer.shutdown()
return end_time - start_time
# Training should be at least 5% faster with AMP.
assert 1.05 * latency(amp=True) < latency(amp=False)
def test_checkpoint_torch_model_with_amp(ray_start_4_cpus_2_gpus):
"""Test that model with AMP is serializable."""
def train_func():
train.torch.accelerate(amp=True)
model = torchvision.models.resnet101()
model = train.torch.prepare_model(model)
train.save_checkpoint(model=model)
trainer = Trainer("torch", num_workers=1, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
def test_torch_auto_gpu_to_cpu(ray_start_4_cpus_2_gpus):
"""Tests if GPU tensors are auto converted to CPU on driver."""


@@ -1,9 +1,11 @@
import tempfile
from dataclasses import dataclass
import functools
import io
import logging
import os
import random
import types
from datetime import timedelta
from pathlib import Path
@@ -14,6 +16,7 @@ from ray import train
from ray.train.accelerator import Accelerator
from ray.train.backend import BackendConfig, Backend, EncodedData
from ray.train.constants import PYTORCH_PROFILER_KEY
from torch.optim import Optimizer
from ray.train.session import get_accelerator, set_accelerator
from ray.train.worker_group import WorkerGroup
from ray.train.utils import get_address_and_port
@@ -21,6 +24,7 @@ from ray.util import PublicAPI
import numpy as np
import torch
from torch.cuda.amp import autocast, GradScaler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import (
@@ -40,9 +44,16 @@ logger = logging.getLogger(__name__)
class TorchAccelerator(Accelerator):
"""A utility that implements methods to accelerate PyTorch training."""
"""A utility that implements methods to accelerate PyTorch training.
def __init__(self):
Arguments:
amp (bool): If true, perform training with automatic mixed precision.
Otherwise, use full precision.
"""
def __init__(self, amp: bool = False):
self.amp_is_enabled = amp
self.scaler = GradScaler() if amp else None
self._seed = None
def prepare_model(
@@ -80,6 +91,37 @@ class TorchAccelerator(Accelerator):
if move_to_device:
logger.info(f"Moving model to device: {device}")
model = model.to(device)
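# When AMP is enabled, wrap the model's forward pass so that it runs under
# `autocast()` and its outputs are cast back to float32 for downstream code.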
def wrap_forward(forward):
@functools.wraps(forward)
def wrapper(*args, **kwargs):
with autocast():
outputs = forward(*args, **kwargs)
assert isinstance(outputs, torch.Tensor)
return outputs.float()
return wrapper
def model_get_state(self):
# `__getstate__` is a special method that informs pickle which attributes
# to serialize. This custom implementation ensures that the wrapped forward
# method and custom `__getstate__` method aren't serialized.
state = self.__dict__.copy()
state["forward"] = state["_unwrapped_forward"]
del state["_unwrapped_forward"]
del state["__getstate__"]
return state
if self.amp_is_enabled:
# Pickle cannot serialize the wrapped forward method. As a workaround,
# define a custom `__getstate__` method that unwraps the forward method.
model._unwrapped_forward = model.forward
model.forward = wrap_forward(model.forward)
# `__getstate__` must be a bound method rather than a callable attribute.
# See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance. # noqa: E501
assert not hasattr(model, "__getstate__")
model.__getstate__ = types.MethodType(model_get_state, model)
if wrap_ddp and train.world_size() > 1:
logger.info("Wrapping provided model in DDP.")
if torch.cuda.is_available():
@@ -206,6 +248,28 @@
return device
def prepare_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Wraps optimizer to support automatic mixed precision.
Args:
optimizer (torch.optim.Optimizer): The optimizer to prepare.
Returns:
A wrapped optimizer.
"""
return _WrappedOptimizer(optimizer, scaler=self.scaler)
def backward(self, tensor: torch.Tensor) -> None:
"""Computes the gradient of the specified tensor w.r.t. graph leaves.
Args:
tensor (torch.Tensor): Tensor whose derivative will be computed.
"""
if self.amp_is_enabled:
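# Scale the loss before calling backward so that small float16 gradients
# do not underflow to zero.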
self.scaler.scale(tensor).backward()
else:
tensor.backward()
def enable_reproducibility(self, seed: int = 0) -> None:
"""Limits sources of nondeterministic behavior."""
self._seed = seed
@@ -451,6 +515,55 @@ class _WrappedDataLoader(DataLoader):
return next_batch
class _WrappedOptimizer(Optimizer):
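# A thin proxy around the user's optimizer: state, param_groups, and defaults are
# forwarded unchanged, and only `step` differs by routing through the GradScaler
# when AMP is enabled.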
def __init__(self, optimizer: Optimizer, scaler: Optional[GradScaler] = None):
self.optimizer = optimizer
self.scaler = scaler
@property
def state(self):
return self.optimizer.state
@state.setter
def state(self, state):
self.optimizer.state = state
@property
def param_groups(self):
return self.optimizer.param_groups
@param_groups.setter
def param_groups(self, param_groups):
self.optimizer.param_groups = param_groups
@property
def defaults(self):
return self.optimizer.defaults
@defaults.setter
def defaults(self, defaults):
self.optimizer.defaults = defaults
def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)
def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)
def state_dict(self):
return self.optimizer.state_dict()
def zero_grad(self):
self.optimizer.zero_grad()
def step(self, closure=None):
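# `GradScaler.step` unscales the gradients and skips the optimizer step if any
# are inf/NaN; `update` then adjusts the scale factor for the next iteration.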
if self.scaler is not None:
self.scaler.step(self.optimizer, closure)
self.scaler.update()
else:
self.optimizer.step(closure)
@PublicAPI(stability="beta")
def get_device() -> torch.device:
"""Gets the correct torch device to use for training."""
@@ -523,10 +636,18 @@ def prepare_data_loader(
@PublicAPI(stability="beta")
def accelerate() -> None:
"""Enables training optimizations."""
def accelerate(amp: bool = False) -> None:
"""Enables training optimizations.
Arguments:
amp (bool): If true, perform training with automatic mixed precision.
Otherwise, use full precision.
.. warning:: ``train.torch.accelerate`` cannot be called more than once, and it
must be called before any other ``train.torch`` utility function.
"""
try:
set_accelerator(TorchAccelerator())
set_accelerator(TorchAccelerator(amp=amp))
except RuntimeError:
raise RuntimeError(
"An accelerator has already been set. Make sure "
@@ -536,6 +657,28 @@ def accelerate() -> None:
@PublicAPI(stability="beta")
def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
"""Wraps optimizer to support automatic mixed precision.
Args:
optimizer (torch.optim.Optimizer): The optimizer to prepare.
Returns:
A wrapped optimizer.
"""
return get_accelerator(TorchAccelerator).prepare_optimizer(optimizer)
@PublicAPI(stability="beta")
def backward(tensor: torch.Tensor) -> None:
"""Computes the gradient of the specified tensor w.r.t. graph leaves.
Args:
tensor (torch.Tensor): Tensor whose derivative will be computed.
"""
get_accelerator(TorchAccelerator).backward(tensor)
def enable_reproducibility(seed: int = 0) -> None:
"""Limits sources of nondeterministic behavior.