[Train] Add torch data prefetch benchmark example (#22974)
Add a benchmark example for the auto pipeline functionality for host to device data transfer.
This commit is contained in:
parent c4b52d34ca
commit 8fff665455
5 changed files with 213 additions and 0 deletions
@ -195,6 +195,8 @@ train.torch.prepare_model
.. autofunction:: ray.train.torch.prepare_model

.. _train-api-torch-prepare-data-loader:

train.torch.prepare_data_loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@ -91,3 +91,10 @@ Ray Tune Integration Examples

------

* Example training on Vision model.


Benchmarks
----------

* :doc:`/train/examples/torch_data_prefetch_benchmark/benchmark_example`:
  Benchmark example for the PyTorch data transfer auto pipeline.

@ -0,0 +1,52 @@
:orphan:

Torch Data Prefetching Benchmark
================================

We provide a benchmark example to show how the auto pipeline for host-to-device data transfer speeds up training on GPUs.
This functionality can be enabled by setting ``auto_transfer=True`` in :ref:`train.torch.prepare_data_loader() <train-api-torch-prepare-data-loader>`.

.. code-block:: python

    from torch.utils.data import DataLoader
    from ray import train

    ...

    data_loader = DataLoader(my_dataset, batch_size)
    train_loader = train.torch.prepare_data_loader(
        data_loader=data_loader, move_to_device=True, auto_transfer=True
    )

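The idea behind the auto pipeline is to copy the next batch to the GPU while the current batch is still being processed, so the host-to-device transfer overlaps with computation.
The snippet below is a minimal sketch of that general prefetching pattern, not Ray's actual implementation; ``PrefetchingLoader`` is a hypothetical name introduced only for illustration, and the sketch assumes the wrapped ``DataLoader`` uses pinned memory (``pin_memory=True``) so the asynchronous copies can actually overlap.

.. code-block:: python

    import torch


    class PrefetchingLoader:
        """Hypothetical wrapper that prefetches batches onto the GPU (sketch only)."""

        def __init__(self, loader, device):
            self.loader = loader
            self.device = device
            # Side stream used only for host-to-device copies.
            self.copy_stream = torch.cuda.Stream(device)

        def __iter__(self):
            ready_batch = None
            for batch in self.loader:
                # Enqueue the copy of this batch on the side stream so it can
                # overlap with compute running on the default stream.
                with torch.cuda.stream(self.copy_stream):
                    moved = tuple(
                        t.to(self.device, non_blocking=True) for t in batch
                    )
                if ready_batch is not None:
                    yield ready_batch
                # Make later work on the default stream wait for the copy.
                torch.cuda.current_stream(self.device).wait_stream(self.copy_stream)
                ready_batch = moved
            if ready_batch is not None:
                yield ready_batch
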
Running the following command gives the runtimes of training a small model with and without the auto pipeline functionality.
The experiment size can be modified by setting different values for ``epochs`` and ``num_hidden_layers``, e.g.,

.. code-block:: bash

    python auto_pipeline_for_host_to_device_data_transfer.py --epochs 2 --num_hidden_layers 2

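The script also accepts an optional ``--address`` argument, which is forwarded to ``ray.init()``; use it to run the benchmark against an existing Ray cluster instead of starting a local one, e.g. (assuming a cluster is already reachable from this node):

.. code-block:: bash

    python auto_pipeline_for_host_to_device_data_transfer.py --address auto --epochs 2 --num_hidden_layers 2
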

The table below displays the runtime in seconds (excluding preparation work) under different configurations.
The two rightmost columns report the runtime without and with the auto pipeline (``auto_transfer=False`` and ``auto_transfer=True``, respectively).
These experiments were done on an NVIDIA 2080 Ti.
The auto pipeline functionality offers more speed improvement as the model size and the number of epochs get larger.
(The actual runtimes may vary when these experiments are run locally or on different hardware.)

========== =================== ======================== ========================
`epochs`   `num_hidden_layers` `auto_transfer=False`    `auto_transfer=True`
========== =================== ======================== ========================
1          1                   2.69                     2.52
1          4                   7.21                     6.85
1          8                   13.54                    13.05
5          1                   12.88                    12.14
5          4                   36.48                    34.33
5          8                   69.12                    66.38
50         1                   132.88                   123.12
50         4                   381.67                   369.42
50         8                   736.17                   693.52
========== =================== ======================== ========================

.. literalinclude:: /../../python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py
@ -0,0 +1,152 @@
# The PyTorch data transfer benchmark script.
import argparse
import warnings

import numpy as np
import torch
import torch.nn as nn

import ray.train as train
from ray.train import Trainer


class Net(nn.Module):
    def __init__(self, in_d, hidden):
        # output dim = 1
        super(Net, self).__init__()
        dims = [in_d] + hidden + [1]
        self.layers = nn.ModuleList(
            [nn.Linear(dims[i - 1], dims[i]) for i in range(1, len(dims))]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class BenchmarkDataset(torch.utils.data.Dataset):
    """Create a naive dataset for the benchmark"""

    def __init__(self, dim, size=1000):
        self.x = torch.from_numpy(np.random.normal(size=(size, dim))).float()
        self.y = torch.from_numpy(np.random.normal(size=(size, 1))).float()
        self.size = size

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]

    def __len__(self):
        return self.size


def train_epoch(dataloader, model, loss_fn, optimizer):
    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train_func(config):
    data_size = config.get("data_size", 4096 * 50)
    batch_size = config.get("batch_size", 4096)
    hidden_size = config.get("hidden_size", 1)
    use_auto_transfer = config.get("use_auto_transfer", False)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 10)

    train_dataset = BenchmarkDataset(4096, size=data_size)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)

    train_loader = train.torch.prepare_data_loader(
        data_loader=train_loader, move_to_device=True, auto_transfer=use_auto_transfer
    )

    model = Net(in_d=4096, hidden=[4096] * hidden_size)
    model = train.torch.prepare_model(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Use CUDA events to time the training loop on the GPU.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    choice = "with" if use_auto_transfer else "without"
    print(f"Starting the torch data prefetch benchmark {choice} auto pipeline...")

    torch.cuda.synchronize()
    start.record()
    for _ in range(epochs):
        train_epoch(train_loader, model, loss_fn, optimizer)
    end.record()
    torch.cuda.synchronize()

    print(
        f"Finished the torch data prefetch benchmark {choice} "
        f"auto pipeline: {start.elapsed_time(end)} ms."
    )

    return "Experiment done."


def train_linear(num_workers=1, num_hidden_layers=1, use_auto_transfer=True, epochs=3):
    trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=True)
    config = {
        "lr": 1e-2,
        "hidden_size": num_hidden_layers,
        "batch_size": 4096,
        "epochs": epochs,
        "use_auto_transfer": use_auto_transfer,
    }
    trainer.start()
    results = trainer.run(train_func, config)
    trainer.shutdown()

    print(results)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--num_hidden_layers",
        type=int,
        default=1,
        help="Number of hidden layers in the model.",
    )

    args, _ = parser.parse_known_args()

    import ray

    ray.init(address=args.address)

    if not torch.cuda.is_available():
        warnings.warn("GPU is not available. Skipping the test using the auto pipeline.")
    else:
        # Run the benchmark once with the auto pipeline and once without it.
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=True,
            epochs=args.epochs,
        )

        torch.cuda.empty_cache()
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=False,
            epochs=args.epochs,
        )

    ray.shutdown()