[Train] Add torch data prefetch benchmark example (#22974)

Add a benchmark example for the auto pipeline functionality for host to device data transfer.
Junwen Yao 2022-03-18 13:27:26 -07:00 committed by GitHub
parent c4b52d34ca
commit 8fff665455
5 changed files with 213 additions and 0 deletions


@@ -195,6 +195,8 @@ train.torch.prepare_model
.. autofunction:: ray.train.torch.prepare_model

.. _train-api-torch-prepare-data-loader:

train.torch.prepare_data_loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -91,3 +91,10 @@ Ray Tune Integration Examples
------

* Example training on Vision model.

Benchmarks
----------

* :doc:`/train/examples/torch_data_prefetch_benchmark/benchmark_example`:
  Benchmark example for the PyTorch data transfer auto pipeline.


@@ -0,0 +1,52 @@
:orphan:

Torch Data Prefetching Benchmark
================================

We provide a benchmark example to show how the auto pipeline for host-to-device data transfer speeds up training on GPUs.
This functionality can be enabled by setting ``auto_transfer=True`` in :ref:`train.torch.prepare_data_loader() <train-api-torch-prepare-data-loader>`.

.. code-block:: python

    from torch.utils.data import DataLoader
    from ray import train

    ...

    data_loader = DataLoader(my_dataset, batch_size)
    train_loader = train.torch.prepare_data_loader(
        data_loader=data_loader, move_to_device=True, auto_transfer=True
    )
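
For the comparison runs reported below, the same loader can be prepared with the pipeline turned off; a minimal sketch mirroring the ``auto_transfer=False`` branch of the benchmark script:

.. code-block:: python

    # Baseline for comparison: plain, synchronous host-to-device copies.
    baseline_loader = train.torch.prepare_data_loader(
        data_loader=data_loader, move_to_device=True, auto_transfer=False
    )
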
Running the following command reports the runtime of training a small model with and without the auto pipeline functionality.
The experiment size can be modified by setting different values for ``epochs`` and ``num_hidden_layers``, e.g.,

.. code-block:: bash

    python auto_pipeline_for_host_to_device_data_transfer.py --epochs 2 --num_hidden_layers 2
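
The script also accepts an ``--address`` flag (see its argument parser below) for connecting to an existing Ray cluster; for example, assuming a cluster is already running locally:

.. code-block:: bash

    # "auto" asks ray.init() to discover the running local cluster.
    python auto_pipeline_for_host_to_device_data_transfer.py --address auto --epochs 2 --num_hidden_layers 2
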
The table below displays the runtime in seconds (excluding preparation work) under different configurations.
The last two columns report the runtime without and with the auto pipeline, respectively.
These experiments were done on an NVIDIA 2080 Ti.
The auto pipeline functionality offers a larger speedup as the model size and the number of epochs grow.
(The actual runtime may vary if these experiments are run locally or on different hardware.)

========== =================== ======================== ========================
``epochs`` ``num_of_layers``   ``auto_transfer=False``  ``auto_transfer=True``
========== =================== ======================== ========================
1          1                   2.69                     2.52
1          4                   7.21                     6.85
1          8                   13.54                    13.05
5          1                   12.88                    12.14
5          4                   36.48                    34.33
5          8                   69.12                    66.38
50         1                   132.88                   123.12
50         4                   381.67                   369.42
50         8                   736.17                   693.52
========== =================== ======================== ========================
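
The numbers above are GPU-side timings taken with ``torch.cuda.Event`` around the training loop, as in the full script included below; a minimal sketch of the timing pattern:

.. code-block:: python

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start.record()
    # ... run the training epochs ...
    end.record()
    torch.cuda.synchronize()
    print(f"Elapsed GPU time: {start.elapsed_time(end)} ms")
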
.. literalinclude:: /../../python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py


@@ -0,0 +1,152 @@
# The PyTorch data transfer benchmark script.
import argparse
import warnings

import numpy as np
import torch
import torch.nn as nn

import ray.train as train
from ray.train import Trainer


class Net(nn.Module):
    def __init__(self, in_d, hidden):
        # output dim = 1
        super(Net, self).__init__()
        dims = [in_d] + hidden + [1]
        self.layers = nn.ModuleList(
            [nn.Linear(dims[i - 1], dims[i]) for i in range(1, len(dims))]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class BenchmarkDataset(torch.utils.data.Dataset):
    """Create a naive dataset for the benchmark."""

    def __init__(self, dim, size=1000):
        self.x = torch.from_numpy(np.random.normal(size=(size, dim))).float()
        self.y = torch.from_numpy(np.random.normal(size=(size, 1))).float()
        self.size = size

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]

    def __len__(self):
        return self.size


def train_epoch(dataloader, model, loss_fn, optimizer):
    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train_func(config):
    data_size = config.get("data_size", 4096 * 50)
    batch_size = config.get("batch_size", 4096)
    hidden_size = config.get("hidden_size", 1)
    use_auto_transfer = config.get("use_auto_transfer", False)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 10)

    train_dataset = BenchmarkDataset(4096, size=data_size)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    train_loader = train.torch.prepare_data_loader(
        data_loader=train_loader, move_to_device=True, auto_transfer=use_auto_transfer
    )

    model = Net(in_d=4096, hidden=[4096] * hidden_size)
    model = train.torch.prepare_model(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    choice = "with" if use_auto_transfer else "without"
    print(f"Starting the torch data prefetch benchmark {choice} auto pipeline...")

    torch.cuda.synchronize()
    start.record()
    for _ in range(epochs):
        train_epoch(train_loader, model, loss_fn, optimizer)
    end.record()
    torch.cuda.synchronize()

    print(
        f"Finished the torch data prefetch benchmark {choice} "
        f"auto pipeline: {start.elapsed_time(end)} ms."
    )

    return "Experiment done."


def train_linear(num_workers=1, num_hidden_layers=1, use_auto_transfer=True, epochs=3):
    trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=True)
    config = {
        "lr": 1e-2,
        "hidden_size": num_hidden_layers,
        "batch_size": 4096,
        "epochs": epochs,
        "use_auto_transfer": use_auto_transfer,
    }
    trainer.start()
    results = trainer.run(train_func, config)
    trainer.shutdown()

    print(results)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="The address to use for Ray."
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--num_hidden_layers",
        type=int,
        default=1,
        help="Number of hidden layers in the model.",
    )
    args, _ = parser.parse_known_args()

    import ray

    ray.init(address=args.address)

    if not torch.cuda.is_available():
        warnings.warn("GPU is not available. Skipping the auto pipeline benchmark.")
    else:
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=True,
            epochs=args.epochs,
        )
        torch.cuda.empty_cache()
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=False,
            epochs=args.epochs,
        )

    ray.shutdown()
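
For reference, the same comparison can be driven programmatically instead of through the CLI; a minimal sketch, assuming a CUDA GPU is available and the script above is importable as a module on the Python path:

.. code-block:: python

    import ray
    import torch

    # Hypothetical driver reusing the benchmark's train_linear() helper.
    from auto_pipeline_for_host_to_device_data_transfer import train_linear

    if torch.cuda.is_available():
        ray.init()
        # Compare pipelined vs. plain host-to-device transfer.
        train_linear(num_workers=1, num_hidden_layers=4, use_auto_transfer=True, epochs=5)
        torch.cuda.empty_cache()
        train_linear(num_workers=1, num_hidden_layers=4, use_auto_transfer=False, epochs=5)
        ray.shutdown()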