From 8fff66545570b45fa26f32df63ccb815ae00ee53 Mon Sep 17 00:00:00 2001
From: Junwen Yao
Date: Fri, 18 Mar 2022 13:27:26 -0700
Subject: [PATCH] [Train] Add torch data prefetch benchmark example (#22974)

Add a benchmark example for the auto pipeline functionality for
host-to-device data transfer.
---
 doc/source/train/api.rst                      |   2 +
 doc/source/train/examples.rst                 |   7 +
 .../benchmark_example.rst                     |  52 ++++++
 .../torch_data_prefetch_benchmark/__init__.py |   0
 ...peline_for_host_to_device_data_transfer.py | 152 ++++++++++++++++++
 5 files changed, 213 insertions(+)
 create mode 100644 doc/source/train/examples/torch_data_prefetch_benchmark/benchmark_example.rst
 create mode 100644 python/ray/train/examples/torch_data_prefetch_benchmark/__init__.py
 create mode 100644 python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py

diff --git a/doc/source/train/api.rst b/doc/source/train/api.rst
index d7bad754e..66c875836 100644
--- a/doc/source/train/api.rst
+++ b/doc/source/train/api.rst
@@ -195,6 +195,8 @@ train.torch.prepare_model

 .. autofunction:: ray.train.torch.prepare_model

+.. _train-api-torch-prepare-data-loader:
+
 train.torch.prepare_data_loader
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/doc/source/train/examples.rst b/doc/source/train/examples.rst
index dc691df9c..f13464587 100644
--- a/doc/source/train/examples.rst
+++ b/doc/source/train/examples.rst
@@ -91,3 +91,10 @@ Ray Tune Integration Examples
 ------

 * Example training on Vision model.
+
+Benchmarks
+----------
+
+* :doc:`/train/examples/torch_data_prefetch_benchmark/benchmark_example`:
+  Benchmark example for the PyTorch data transfer auto pipeline.
+

diff --git a/doc/source/train/examples/torch_data_prefetch_benchmark/benchmark_example.rst b/doc/source/train/examples/torch_data_prefetch_benchmark/benchmark_example.rst
new file mode 100644
index 000000000..acb6df2aa
--- /dev/null
+++ b/doc/source/train/examples/torch_data_prefetch_benchmark/benchmark_example.rst
@@ -0,0 +1,52 @@
:orphan:

Torch Data Prefetching Benchmark
================================

We provide a benchmark example to show how the auto pipeline for host-to-device data transfer speeds up training on GPUs.
This functionality can be enabled by setting ``auto_transfer=True`` in :ref:`train.torch.prepare_data_loader() <train-api-torch-prepare-data-loader>`.

.. code-block:: python

    from torch.utils.data import DataLoader
    from ray import train

    ...

    data_loader = DataLoader(my_dataset, batch_size)
    train_loader = train.torch.prepare_data_loader(
        data_loader=data_loader, move_to_device=True, auto_transfer=True
    )


Running the following command reports the runtime of training a small model with and without the auto pipeline functionality.
The experiment size can be adjusted by setting different values for ``epochs`` and ``num_hidden_layers``, e.g.,

.. code-block:: bash

    python auto_pipeline_for_host_to_device_data_transfer.py --epochs 2 --num_hidden_layers 2
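For intuition, here is a minimal sketch of the double-buffering idea behind ``auto_transfer`` (illustrative only, not Ray Train's actual implementation; the class name ``HostToDevicePrefetcher`` is made up for this example): while the model consumes the current batch on its usual CUDA stream, the next batch is copied to the GPU on a side stream.

.. code-block:: python

    # Illustrative sketch only -- Ray Train's real wrapper handles more details.
    import torch


    class HostToDevicePrefetcher:
        """Copy the next batch to the device on a side CUDA stream while
        the current batch is being consumed by the training step."""

        def __init__(self, loader, device):
            self.loader = loader
            self.device = device
            self.stream = torch.cuda.Stream(device)

        def __iter__(self):
            batches = iter(self.loader)
            next_batch = None

            def preload():
                nonlocal next_batch
                try:
                    x, y = next(batches)
                except StopIteration:
                    next_batch = None
                    return
                # The copy overlaps with compute only when the source tensors
                # are in pinned memory (e.g., DataLoader(pin_memory=True)).
                with torch.cuda.stream(self.stream):
                    next_batch = (
                        x.to(self.device, non_blocking=True),
                        y.to(self.device, non_blocking=True),
                    )

            preload()
            while next_batch is not None:
                # Make the consumer's stream wait for the side-stream copy.
                torch.cuda.current_stream(self.device).wait_stream(self.stream)
                batch = next_batch
                preload()  # start copying the following batch right away
                yield batch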
+ + +========== =================== ======================== ======================== + `epochs` `num_of_layers` `auto_transfer=False` `auto_transfer=True` +========== =================== ======================== ======================== + 1 1 2.69 2.52 + 1 4 7.21 6.85 + 1 8 13.54 13.05 + 5 1 12.88 12.14 + 5 4 36.48 34.33 + 5 8 69.12 66.38 + 50 1 132.88 123.12 + 50 4 381.67 369.42 + 50 8 736.17 693.52 +========== =================== ======================== ======================== + + +.. literalinclude:: /../../python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py diff --git a/python/ray/train/examples/torch_data_prefetch_benchmark/__init__.py b/python/ray/train/examples/torch_data_prefetch_benchmark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py b/python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py new file mode 100644 index 000000000..c8cc25b04 --- /dev/null +++ b/python/ray/train/examples/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py @@ -0,0 +1,152 @@ +# The PyTorch data transfer benchmark script. +import argparse +import warnings + +import numpy as np +import torch +import torch.nn as nn +import ray.train as train +from ray.train import Trainer + + +class Net(nn.Module): + def __init__(self, in_d, hidden): + # output dim = 1 + super(Net, self).__init__() + dims = [in_d] + hidden + [1] + self.layers = nn.ModuleList( + [nn.Linear(dims[i - 1], dims[i]) for i in range(len(dims))] + ) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +class BenchmarkDataset(torch.utils.data.Dataset): + """Create a naive dataset for the benchmark""" + + def __init__(self, dim, size=1000): + self.x = torch.from_numpy(np.random.normal(size=(size, dim))).float() + self.y = torch.from_numpy(np.random.normal(size=(size, 1))).float() + self.size = size + + def __getitem__(self, index): + return self.x[index, None], self.y[index, None] + + def __len__(self): + return self.size + + +def train_epoch(dataloader, model, loss_fn, optimizer): + for X, y in dataloader: + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + +def train_func(config): + data_size = config.get("data_size", 4096 * 50) + batch_size = config.get("batch_size", 4096) + hidden_size = config.get("hidden_size", 1) + use_auto_transfer = config.get("use_auto_transfer", False) + lr = config.get("lr", 1e-2) + epochs = config.get("epochs", 10) + + train_dataset = BenchmarkDataset(4096, size=data_size) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) + + train_loader = train.torch.prepare_data_loader( + data_loader=train_loader, move_to_device=True, auto_transfer=use_auto_transfer + ) + + model = Net(in_d=4096, hidden=[4096] * hidden_size) + model = train.torch.prepare_model(model) + + loss_fn = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + choice = "with" if use_auto_transfer else "without" + print(f"Starting the torch data prefetch benchmark {choice} auto pipeline...") + + torch.cuda.synchronize() + start.record() + for _ in range(epochs): + train_epoch(train_loader, model, loss_fn, optimizer) + 
    end.record()
    torch.cuda.synchronize()

    print(
        f"Finished the torch data prefetch benchmark {choice} "
        f"auto pipeline: {start.elapsed_time(end)} ms."
    )

    return "Experiment done."


def train_linear(num_workers=1, num_hidden_layers=1, use_auto_transfer=True, epochs=3):
    trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=True)
    config = {
        "lr": 1e-2,
        "hidden_size": num_hidden_layers,
        "batch_size": 4096,
        "epochs": epochs,
        "use_auto_transfer": use_auto_transfer,
    }
    trainer.start()
    results = trainer.run(train_func, config)
    trainer.shutdown()

    print(results)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="The address to use for Ray."
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--num_hidden_layers",
        type=int,
        default=1,
        help="Number of hidden layers in the model.",
    )

    args, _ = parser.parse_known_args()

    import ray

    ray.init(address=args.address)

    if not torch.cuda.is_available():
        warnings.warn("GPU is not available. Skipping the auto pipeline benchmark.")
    else:
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=True,
            epochs=args.epochs,
        )

        torch.cuda.empty_cache()
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=False,
            epochs=args.epochs,
        )

    ray.shutdown()
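
# Example invocations (flag values are illustrative; a CUDA-capable GPU is
# required for the benchmark to run):
#
#   python auto_pipeline_for_host_to_device_data_transfer.py --epochs 2 --num_hidden_layers 2
#   python auto_pipeline_for_host_to_device_data_transfer.py --address auto --epochs 5 --num_hidden_layers 8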