[Train] Add support for automatic mixed precision (#22227)

Closes #20643

Co-authored-by: Ubuntu <ubuntu@ip-172-31-58-19.us-west-2.compute.internal>
Balaji Veeramani 2022-03-16 22:53:02 -05:00 committed by GitHub
parent 77090144a2
commit 83986a4d83
5 changed files with 275 additions and 5 deletions


@@ -138,9 +138,11 @@ MOCK_MODULES = [
     "tensorflow.python.client",
     "tensorflow.python.util",
     "torch",
+    "torch.cuda.amp",
     "torch.distributed",
     "torch.nn",
     "torch.nn.parallel",
+    "torch.optim",
     "torch.profiler",
     "torch.utils.data",
     "torch.utils.data.distributed",


@@ -200,6 +200,18 @@ train.torch.prepare_data_loader
 .. autofunction:: ray.train.torch.prepare_data_loader
 
+train.torch.prepare_optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.prepare_optimizer
+
+train.torch.backward
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.backward
+
 train.torch.get_device
 ~~~~~~~~~~~~~~~~~~~~~~

@@ -212,6 +224,11 @@ train.torch.enable_reproducibility
 .. _train-api-torch-worker-profiler:
 
+train.torch.accelerate
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.accelerate
+
 train.torch.TorchWorkerProfiler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -977,6 +977,54 @@ the disk from which your script was executed.
     # View the PyTorch Profiler traces.
     $ open http://localhost:6006/#pytorch_profiler
 
+.. _torch-amp:
+
+Automatic Mixed Precision
+-------------------------
+
+Automatic mixed precision (AMP) lets you train your models faster by using a lower
+precision datatype for operations like linear layers and convolutions.
+
+.. tabbed:: PyTorch
+
+    You can train your Torch model with AMP by:
+
+    1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
+    2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
+    3. Replacing your backward call with ``train.torch.backward``.
+
+    .. code-block:: diff
+
+         def train_func():
+        +    train.torch.accelerate(amp=True)
+
+             model = NeuralNetwork()
+             model = train.torch.prepare_model(model)
+
+             data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
+             data_loader = train.torch.prepare_data_loader(data_loader)
+
+             optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        +    optimizer = train.torch.prepare_optimizer(optimizer)
+
+             model.train()
+             for epoch in range(90):
+                 for images, targets in dataloader:
+                     optimizer.zero_grad()
+
+                     outputs = model(images)
+                     loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+        -            loss.backward()
+        +            train.torch.backward(loss)
+
+                     optimizer.step()
+
+             ...
+
+.. note:: The performance of AMP varies based on GPU architecture, model type,
+          and data shape. For certain workflows, AMP may perform worse than
+          full-precision training.
+
 .. _train-reproducibility:
 
 Reproducibility
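
A quick way to see what AMP actually changes (a standalone sketch, not part of this diff) is to run an op inside ``torch.cuda.amp.autocast`` and inspect the output dtype: eligible ops such as linear layers execute in half precision inside the context. The snippet assumes a CUDA-capable GPU.

.. code-block:: python

    import torch
    from torch.cuda.amp import autocast

    # Eligible ops (e.g., nn.Linear) run in float16 under autocast, even though
    # the parameters and inputs are float32.
    linear = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(4, 8, device="cuda")

    with autocast():
        y = linear(x)

    print(y.dtype)          # torch.float16 inside the autocast context
    print(linear(x).dtype)  # torch.float32 outside the context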


@@ -1,5 +1,6 @@
 import os
 import pytest
+from timeit import default_timer as timer
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
@@ -127,6 +128,65 @@ def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
     assert result1 == result2
 
 
+def test_torch_amp(ray_start_4_cpus_2_gpus):
+    def train_func(config):
+        train.torch.accelerate(amp=config["amp"])
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        dataset_length = 1000
+        dataset = torch.utils.data.TensorDataset(
+            torch.randn(dataset_length, 3, 224, 224),
+            torch.randint(low=0, high=1000, size=(dataset_length,)),
+        )
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
+        dataloader = train.torch.prepare_data_loader(dataloader)
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        optimizer = train.torch.prepare_optimizer(optimizer)
+
+        model.train()
+        for epoch in range(1):
+            for images, targets in dataloader:
+                optimizer.zero_grad()
+
+                outputs = model(images)
+                loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+                train.torch.backward(loss)
+                optimizer.step()
+
+    def latency(amp: bool) -> float:
+        trainer = Trainer("torch", num_workers=2, use_gpu=True)
+        trainer.start()
+        start_time = timer()
+        trainer.run(train_func, {"amp": amp})
+        end_time = timer()
+        trainer.shutdown()
+        return end_time - start_time
+
+    # Training should be at least 5% faster with AMP.
+    assert 1.05 * latency(amp=True) < latency(amp=False)
+
+
+def test_checkpoint_torch_model_with_amp(ray_start_4_cpus_2_gpus):
+    """Test that a model trained with AMP is serializable."""
+
+    def train_func():
+        train.torch.accelerate(amp=True)
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        train.save_checkpoint(model=model)
+
+    trainer = Trainer("torch", num_workers=1, use_gpu=True)
+    trainer.start()
+    trainer.run(train_func)
+    trainer.shutdown()
+
+
 def test_torch_auto_gpu_to_cpu(ray_start_4_cpus_2_gpus):
     """Tests if GPU tensors are auto converted to CPU on driver."""


@@ -1,9 +1,11 @@
 import tempfile
 from dataclasses import dataclass
+import functools
 import io
 import logging
 import os
 import random
+import types
 from datetime import timedelta
 from pathlib import Path

@@ -14,6 +16,7 @@ from ray import train
 from ray.train.accelerator import Accelerator
 from ray.train.backend import BackendConfig, Backend, EncodedData
 from ray.train.constants import PYTORCH_PROFILER_KEY
+from torch.optim import Optimizer
 from ray.train.session import get_accelerator, set_accelerator
 from ray.train.worker_group import WorkerGroup
 from ray.train.utils import get_address_and_port

@@ -21,6 +24,7 @@ from ray.util import PublicAPI
 import numpy as np
 import torch
+from torch.cuda.amp import autocast, GradScaler
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import (
@@ -40,9 +44,16 @@ logger = logging.getLogger(__name__)
 class TorchAccelerator(Accelerator):
-    """A utility that implements methods to accelerate PyTorch training."""
+    """A utility that implements methods to accelerate PyTorch training.
 
-    def __init__(self):
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+    """
+
+    def __init__(self, amp: bool = False):
+        self.amp_is_enabled = amp
+        self.scaler = GradScaler() if amp else None
         self._seed = None
 
     def prepare_model(
@@ -80,6 +91,37 @@ class TorchAccelerator(Accelerator):
         if move_to_device:
             logger.info(f"Moving model to device: {device}")
             model = model.to(device)
 
+        def wrap_forward(forward):
+            @functools.wraps(forward)
+            def wrapper(*args, **kwargs):
+                with autocast():
+                    outputs = forward(*args, **kwargs)
+                assert isinstance(outputs, torch.Tensor)
+                return outputs.float()
+
+            return wrapper
+
+        def model_get_state(self):
+            # `__getstate__` is a special method that informs pickle which attributes
+            # to serialize. This custom implementation ensures that the wrapped forward
+            # method and custom `__getstate__` method aren't serialized.
+            state = self.__dict__.copy()
+            state["forward"] = state["_unwrapped_forward"]
+            del state["_unwrapped_forward"]
+            del state["__getstate__"]
+            return state
+
+        if self.amp_is_enabled:
+            # Pickle cannot serialize the wrapped forward method. As a workaround,
+            # define a custom `__getstate__` method that unwraps the forward method.
+            model._unwrapped_forward = model.forward
+            model.forward = wrap_forward(model.forward)
+            # `__getstate__` must be a bound method rather than a callable attribute.
+            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance.  # noqa: E501
+            assert not hasattr(model, "__getstate__")
+            model.__getstate__ = types.MethodType(model_get_state, model)
+
         if wrap_ddp and train.world_size() > 1:
             logger.info("Wrapping provided model in DDP.")
             if torch.cuda.is_available():
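
The forward-wrapping and ``__getstate__`` workaround above can be reproduced outside Ray Train. The sketch below uses hypothetical names and a pass-through closure as a stand-in for the autocast wrapper: monkey-patching ``forward`` with a local function would normally make the module unpicklable, and the bound ``__getstate__`` restores pickling by stripping the wrapper.

.. code-block:: python

    import functools
    import pickle
    import types

    import torch


    def wrap_forward_for_pickling(model: torch.nn.Module) -> torch.nn.Module:
        @functools.wraps(model.forward)
        def wrapper(*args, **kwargs):
            # Stand-in for the autocast wrapper; any instance-level closure over
            # the original forward makes the module unpicklable by default.
            return model._unwrapped_forward(*args, **kwargs)

        def model_get_state(self):
            # Hand pickle only the original, serializable attributes: restore the
            # unwrapped forward and drop the wrapper and this method itself.
            state = self.__dict__.copy()
            state["forward"] = state["_unwrapped_forward"]
            del state["_unwrapped_forward"]
            del state["__getstate__"]
            return state

        model._unwrapped_forward = model.forward
        model.forward = wrapper
        model.__getstate__ = types.MethodType(model_get_state, model)
        return model


    model = wrap_forward_for_pickling(torch.nn.Linear(4, 4))
    restored = pickle.loads(pickle.dumps(model))  # round-trips thanks to __getstate__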
@@ -206,6 +248,28 @@ class TorchAccelerator(Accelerator):
 
         return device
 
+    def prepare_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Wraps optimizer to support automatic mixed precision.
+
+        Args:
+            optimizer (torch.optim.Optimizer): The optimizer to prepare.
+
+        Returns:
+            A wrapped optimizer.
+        """
+        return _WrappedOptimizer(optimizer, scaler=self.scaler)
+
+    def backward(self, tensor: torch.Tensor) -> None:
+        """Computes the gradient of the specified tensor w.r.t. graph leaves.
+
+        Args:
+            tensor (torch.Tensor): Tensor of which the derivative will be computed.
+        """
+        if self.amp_is_enabled:
+            self.scaler.scale(tensor).backward()
+        else:
+            tensor.backward()
+
     def enable_reproducibility(self, seed: int = 0) -> None:
         """Limits sources of nondeterministic behavior."""
         self._seed = seed
@@ -451,6 +515,55 @@ class _WrappedDataLoader(DataLoader):
         return next_batch
 
 
+class _WrappedOptimizer(Optimizer):
+    def __init__(self, optimizer: Optimizer, scaler: Optional[GradScaler] = None):
+        self.optimizer = optimizer
+        self.scaler = scaler
+
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self.optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups):
+        self.optimizer.param_groups = param_groups
+
+    @property
+    def defaults(self):
+        return self.optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self.optimizer.defaults = defaults
+
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+
+    def load_state_dict(self, state_dict):
+        self.optimizer.load_state_dict(state_dict)
+
+    def state_dict(self):
+        return self.optimizer.state_dict()
+
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+
+    def step(self, closure=None):
+        if self.scaler is not None:
+            self.scaler.step(self.optimizer, closure)
+            self.scaler.update()
+        else:
+            self.optimizer.step(closure)
+
+
 @PublicAPI(stability="beta")
 def get_device() -> torch.device:
     """Gets the correct torch device to use for training."""
@@ -523,10 +636,18 @@ prepare_data_loader(
 
 @PublicAPI(stability="beta")
-def accelerate() -> None:
-    """Enables training optimizations."""
+def accelerate(amp: bool = False) -> None:
+    """Enables training optimizations.
+
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+
+    .. warning:: ``train.torch.accelerate`` cannot be called more than once, and it
+        must be called before any other ``train.torch`` utility function.
+    """
     try:
-        set_accelerator(TorchAccelerator())
+        set_accelerator(TorchAccelerator(amp=amp))
     except RuntimeError:
         raise RuntimeError(
             "An accelerator has already been set. Make sure "
@@ -536,6 +657,28 @@ def accelerate() -> None:
 
 @PublicAPI(stability="beta")
+def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
+    """Wraps optimizer to support automatic mixed precision.
+
+    Args:
+        optimizer (torch.optim.Optimizer): The optimizer to prepare.
+
+    Returns:
+        A wrapped optimizer.
+    """
+    return get_accelerator(TorchAccelerator).prepare_optimizer(optimizer)
+
+
+@PublicAPI(stability="beta")
+def backward(tensor: torch.Tensor) -> None:
+    """Computes the gradient of the specified tensor w.r.t. graph leaves.
+
+    Args:
+        tensor (torch.Tensor): Tensor of which the derivative will be computed.
+    """
+    get_accelerator(TorchAccelerator).backward(tensor)
+
+
 def enable_reproducibility(seed: int = 0) -> None:
     """Limits sources of nondeterministic behavior.