[Train] Add support for automatic mixed precision (#22227)
Closes #20643 Co-authored-by: Ubuntu <ubuntu@ip-172-31-58-19.us-west-2.compute.internal>
This commit is contained in:
parent 77090144a2
commit 83986a4d83

5 changed files with 275 additions and 5 deletions
@@ -138,9 +138,11 @@ MOCK_MODULES = [
     "tensorflow.python.client",
     "tensorflow.python.util",
     "torch",
+    "torch.cuda.amp",
     "torch.distributed",
     "torch.nn",
     "torch.nn.parallel",
+    "torch.optim",
     "torch.profiler",
     "torch.utils.data",
     "torch.utils.data.distributed",
@@ -200,6 +200,18 @@ train.torch.prepare_data_loader
 
 .. autofunction:: ray.train.torch.prepare_data_loader
 
+train.torch.prepare_optimizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.prepare_optimizer
+
+
+train.torch.backward
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.backward
+
+
 train.torch.get_device
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -212,6 +224,11 @@ train.torch.enable_reproducibility
 
 .. _train-api-torch-worker-profiler:
 
+train.torch.accelerate
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: ray.train.torch.accelerate
+
 train.torch.TorchWorkerProfiler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -977,6 +977,54 @@ the disk that from which your script was executed from.
     # View the PyTorch Profiler traces.
     $ open http://localhost:6006/#pytorch_profiler
 
+.. _torch-amp:
+
+Automatic Mixed Precision
+-------------------------
+
+Automatic mixed precision (AMP) lets you train your models faster by using a lower
+precision datatype for operations like linear layers and convolutions.
+
+.. tabbed:: PyTorch
+
+    You can train your Torch model with AMP by:
+
+    1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
+    2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
+    3. Replacing your backward call with ``train.torch.backward``.
+
+    .. code-block:: diff
+
+         def train_func():
+        +    train.torch.accelerate(amp=True)
+
+             model = NeuralNetwork()
+             model = train.torch.prepare_model(model)
+
+             data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
+             data_loader = train.torch.prepare_data_loader(data_loader)
+
+             optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        +    optimizer = train.torch.prepare_optimizer(optimizer)
+
+             model.train()
+             for epoch in range(90):
+                 for images, targets in dataloader:
+                     optimizer.zero_grad()
+
+                     outputs = model(images)
+                     loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+        -            loss.backward()
+        +            train.torch.backward(loss)
+                     optimizer.step()
+             ...
+
+
+.. note:: The performance of AMP varies based on GPU architecture, model type,
+          and data shape. For certain workflows, AMP may perform worse than
+          full-precision training.
+
 .. _train-reproducibility:
 
 Reproducibility
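For reference, the three steps documented above can be assembled into a complete script. The sketch below uses the ``Trainer`` API that this commit's tests exercise; the toy model, dataset, and hyperparameters are illustrative assumptions rather than part of the change:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ray import train
from ray.train import Trainer
import ray.train.torch  # makes the train.torch utilities importable


def train_func():
    # 1. Enable AMP before any other train.torch utility is called.
    train.torch.accelerate(amp=True)

    # A toy model and dataset stand in for a real workload here.
    model = nn.Linear(10, 2)
    model = train.torch.prepare_model(model)

    dataset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
    data_loader = DataLoader(dataset, batch_size=32)
    data_loader = train.torch.prepare_data_loader(data_loader)

    # 2. Wrap the optimizer so gradient scaling is handled automatically.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    optimizer = train.torch.prepare_optimizer(optimizer)

    model.train()
    for _ in range(2):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            # 3. Replace loss.backward() so the loss is scaled under AMP.
            train.torch.backward(loss)
            optimizer.step()


trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()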
@@ -1,5 +1,6 @@
 import os
 import pytest
+from timeit import default_timer as timer
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
@@ -127,6 +128,65 @@ def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
     assert result1 == result2
 
 
+def test_torch_amp(ray_start_4_cpus_2_gpus):
+    def train_func(config):
+        train.torch.accelerate(amp=config["amp"])
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        dataset_length = 1000
+        dataset = torch.utils.data.TensorDataset(
+            torch.randn(dataset_length, 3, 224, 224),
+            torch.randint(low=0, high=1000, size=(dataset_length,)),
+        )
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
+        dataloader = train.torch.prepare_data_loader(dataloader)
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        optimizer = train.torch.prepare_optimizer(optimizer)
+
+        model.train()
+        for epoch in range(1):
+            for images, targets in dataloader:
+                optimizer.zero_grad()
+
+                outputs = model(images)
+                loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+                train.torch.backward(loss)
+                optimizer.step()
+
+    def latency(amp: bool) -> float:
+        trainer = Trainer("torch", num_workers=2, use_gpu=True)
+        trainer.start()
+        start_time = timer()
+        trainer.run(train_func, {"amp": amp})
+        end_time = timer()
+        trainer.shutdown()
+        return end_time - start_time
+
+    # Training should be at least 5% faster with AMP.
+    assert 1.05 * latency(amp=True) < latency(amp=False)
+
+
+def test_checkpoint_torch_model_with_amp(ray_start_4_cpus_2_gpus):
+    """Test that model with AMP is serializable."""
+
+    def train_func():
+        train.torch.accelerate(amp=True)
+
+        model = torchvision.models.resnet101()
+        model = train.torch.prepare_model(model)
+
+        train.save_checkpoint(model=model)
+
+    trainer = Trainer("torch", num_workers=1, use_gpu=True)
+    trainer.start()
+    trainer.run(train_func)
+    trainer.shutdown()
+
+
 def test_torch_auto_gpu_to_cpu(ray_start_4_cpus_2_gpus):
     """Tests if GPU tensors are auto converted to CPU on driver."""
 
@@ -1,9 +1,11 @@
 import tempfile
 from dataclasses import dataclass
+import functools
 import io
 import logging
 import os
 import random
+import types
 
 from datetime import timedelta
 from pathlib import Path
@@ -14,6 +16,7 @@ from ray import train
 from ray.train.accelerator import Accelerator
 from ray.train.backend import BackendConfig, Backend, EncodedData
 from ray.train.constants import PYTORCH_PROFILER_KEY
+from torch.optim import Optimizer
 from ray.train.session import get_accelerator, set_accelerator
 from ray.train.worker_group import WorkerGroup
 from ray.train.utils import get_address_and_port
@@ -21,6 +24,7 @@ from ray.util import PublicAPI
 
 import numpy as np
 import torch
+from torch.cuda.amp import autocast, GradScaler
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import (
@@ -40,9 +44,16 @@ logger = logging.getLogger(__name__)
 
 
 class TorchAccelerator(Accelerator):
-    """A utility that implements methods to accelerate PyTorch training."""
+    """A utility that implements methods to accelerate PyTorch training.
 
-    def __init__(self):
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+    """
+
+    def __init__(self, amp: bool = False):
+        self.amp_is_enabled = amp
+        self.scaler = GradScaler() if amp else None
         self._seed = None
 
     def prepare_model(
@@ -80,6 +91,37 @@ class TorchAccelerator(Accelerator):
         if move_to_device:
             logger.info(f"Moving model to device: {device}")
             model = model.to(device)
+
+        def wrap_forward(forward):
+            @functools.wraps(forward)
+            def wrapper(*args, **kwargs):
+                with autocast():
+                    outputs = forward(*args, **kwargs)
+                assert isinstance(outputs, torch.Tensor)
+                return outputs.float()
+
+            return wrapper
+
+        def model_get_state(self):
+            # `__getstate__` is a special method that informs pickle which attributes
+            # to serialize. This custom implementation ensures that the wrapped forward
+            # method and custom `__getstate__` method aren't serialized.
+            state = self.__dict__.copy()
+            state["forward"] = state["_unwrapped_forward"]
+            del state["_unwrapped_forward"]
+            del state["__getstate__"]
+            return state
+
+        if self.amp_is_enabled:
+            # Pickle cannot serialize the wrapped forward method. As a workaround,
+            # define a custom `__getstate__` method that unwraps the forward method.
+            model._unwrapped_forward = model.forward
+            model.forward = wrap_forward(model.forward)
+            # `__getstate__` must be a bound method rather than a callable attribute.
+            # See https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance. # noqa: E501
+            assert not hasattr(model, "__getstate__")
+            model.__getstate__ = types.MethodType(model_get_state, model)
+
         if wrap_ddp and train.world_size() > 1:
             logger.info("Wrapping provided model in DDP.")
             if torch.cuda.is_available():
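The forward-wrapping and pickling workaround added above can be hard to follow inside the diff. Below is a standalone sketch of the same pattern applied to a plain ``nn.Module``; the helper name ``wrap_forward_with_autocast`` is hypothetical and only illustrates the technique (run the forward pass under autocast, then restore the original ``forward`` in a custom ``__getstate__`` so the module still pickles):

import functools
import pickle
import types

import torch.nn as nn
from torch.cuda.amp import autocast


def wrap_forward_with_autocast(model: nn.Module) -> nn.Module:
    """Run ``model.forward`` under autocast while keeping the module picklable."""

    def wrap_forward(forward):
        @functools.wraps(forward)
        def wrapper(*args, **kwargs):
            with autocast():
                outputs = forward(*args, **kwargs)
            # Autocast may produce float16 outputs; upcast so callers see float32.
            return outputs.float()

        return wrapper

    def model_get_state(self):
        # Drop the unpicklable wrapper closure and restore the original forward.
        state = self.__dict__.copy()
        state["forward"] = state["_unwrapped_forward"]
        del state["_unwrapped_forward"]
        del state["__getstate__"]
        return state

    model._unwrapped_forward = model.forward
    model.forward = wrap_forward(model.forward)
    model.__getstate__ = types.MethodType(model_get_state, model)
    return model


model = wrap_forward_with_autocast(nn.Linear(4, 2))
restored = pickle.loads(pickle.dumps(model))  # round-trips without error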
@@ -206,6 +248,28 @@ class TorchAccelerator(Accelerator):
 
         return device
 
+    def prepare_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Wraps optimizer to support automatic mixed precision.
+
+        Args:
+            optimizer (torch.optim.Optimizer): The optimizer to wrap.
+
+        Returns:
+            A wrapped optimizer.
+        """
+        return _WrappedOptimizer(optimizer, scaler=self.scaler)
+
+    def backward(self, tensor: torch.Tensor) -> None:
+        """Computes the gradient of the specified tensor w.r.t. graph leaves.
+
+        Args:
+            tensor (torch.Tensor): Tensor of which the derivative will be computed.
+        """
+        if self.amp_is_enabled:
+            self.scaler.scale(tensor).backward()
+        else:
+            tensor.backward()
+
     def enable_reproducibility(self, seed: int = 0) -> None:
         """Limits sources of nondeterministic behavior."""
         self._seed = seed
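For context, ``TorchAccelerator.backward`` and the ``_WrappedOptimizer`` added in the next hunk encapsulate the standard ``torch.cuda.amp`` gradient-scaling loop. The plain-PyTorch sketch below shows that underlying pattern directly; it is standard PyTorch usage rather than part of this change, and the CPU fallback exists only to keep the snippet runnable:

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

# Fall back to full precision on CPU-only machines so the sketch stays runnable.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scaler = GradScaler(enabled=use_amp)

for _ in range(10):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad()
    with autocast(enabled=use_amp):   # forward pass in mixed precision
        loss = nn.functional.cross_entropy(model(inputs), targets)

    scaler.scale(loss).backward()     # what TorchAccelerator.backward() does
    scaler.step(optimizer)            # what the wrapped optimizer's step() does
    scaler.update()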
@@ -451,6 +515,55 @@ class _WrappedDataLoader(DataLoader):
         return next_batch
 
 
+class _WrappedOptimizer(Optimizer):
+    def __init__(self, optimizer: Optimizer, scaler: Optional[GradScaler] = None):
+        self.optimizer = optimizer
+        self.scaler = scaler
+
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self.optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
+    @param_groups.setter
+    def param_groups(self, param_groups):
+        self.optimizer.param_groups = param_groups
+
+    @property
+    def defaults(self):
+        return self.optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self.optimizer.defaults = defaults
+
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+
+    def load_state_dict(self, state_dict):
+        self.optimizer.load_state_dict(state_dict)
+
+    def state_dict(self):
+        return self.optimizer.state_dict()
+
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+
+    def step(self, closure=None):
+        if self.scaler is not None:
+            self.scaler.step(self.optimizer, closure)
+            self.scaler.update()
+        else:
+            self.optimizer.step(closure)
+
+
 @PublicAPI(stability="beta")
 def get_device() -> torch.device:
     """Gets the correct torch device to use for training."""
@@ -523,10 +636,18 @@ def prepare_data_loader(
 
 
 @PublicAPI(stability="beta")
-def accelerate() -> None:
-    """Enables training optimizations."""
+def accelerate(amp: bool = False) -> None:
+    """Enables training optimizations.
+
+    Arguments:
+        amp (bool): If true, perform training with automatic mixed precision.
+            Otherwise, use full precision.
+
+    .. warning:: ``train.torch.accelerate`` cannot be called more than once, and it
+        must be called before any other ``train.torch`` utility function.
+    """
     try:
-        set_accelerator(TorchAccelerator())
+        set_accelerator(TorchAccelerator(amp=amp))
     except RuntimeError:
         raise RuntimeError(
             "An accelerator has already been set. Make sure "
|
@ -536,6 +657,28 @@ def accelerate() -> None:
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI(stability="beta")
|
@PublicAPI(stability="beta")
|
||||||
|
def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
|
||||||
|
"""Wraps optimizer to support automatic mixed precision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (torch.optim.Optimizer): The DataLoader to prepare.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A wrapped optimizer.
|
||||||
|
"""
|
||||||
|
return get_accelerator(TorchAccelerator).prepare_optimizer(optimizer)
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="beta")
|
||||||
|
def backward(tensor: torch.Tensor) -> None:
|
||||||
|
"""Computes the gradient of the specified tensor w.r.t. graph leaves.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): Tensor of which the derivative will be computed.
|
||||||
|
"""
|
||||||
|
get_accelerator(TorchAccelerator).backward(tensor)
|
||||||
|
|
||||||
|
|
||||||
def enable_reproducibility(seed: int = 0) -> None:
|
def enable_reproducibility(seed: int = 0) -> None:
|
||||||
"""Limits sources of nondeterministic behavior.
|
"""Limits sources of nondeterministic behavior.
|
||||||
|
|
||||||
|
|