[AIR] Update Torch benchmarks with documentation (#26631)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Jiao 2022-07-16 17:58:21 -07:00 committed by GitHub
parent ef091c382e
commit 77e2ef2eb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 61 deletions

View file

@ -103,8 +103,67 @@ XGBoost parameters were kept as defaults for xgboost==1.6.1 this task.
- `python xgboost_benchmark.py --size 100GB`
GPU image batch prediction
----------------------------------------------------
This task uses the BatchPredictor module to process different amounts of data
using a pre-trained PyTorch ResNet model.
We test out the performance across different cluster sizes and data sizes.
- `GPU image batch prediction script`_
.. list-table::
* - **Cluster Setup**
- **Data Size**
- **Performance**
- **Command**
* - 1 g3.8xlarge node
- 1 GB (1623 images)
- 72.59 s (22.3 images/sec)
- `python gpu_batch_prediction.py --data-size-gb=1`
* - 1 g3.8xlarge node
- 20 GB (32460 images)
- 1213.48 s (26.76 images/sec)
- `python gpu_batch_prediction.py --data-size-gb=20`
* - 8 g3.8xlarge nodes
- 100 GB (162300 images)
- 784.91 s (206.78 images/sec)
- `python gpu_batch_prediction.py --data-size-gb=100`
GPU image training
------------------------
This task uses the TorchTrainer module to train on different amounts of data
using a PyTorch ResNet model.
We test out the performance across different cluster sizes and data sizes.
- `GPU image training script`_
.. list-table::
* - **Cluster Setup**
- **Data Size**
- **Performance**
- **Command**
* - 1 g3.8xlarge node (1 worker)
- 1 GB (1623 images)
- 79.76 s (2 epochs, 40.7 images/sec)
- `python pytorch_training_e2e.py --data-size-gb=1`
* - 1 g3.8xlarge node (1 worker)
- 20 GB (32460 images)
- 1388.33 s (2 epochs, 46.76 images/sec)
- `python pytorch_training_e2e.py --data-size-gb=20`
.. _`Bulk Ingest Script`: https://github.com/ray-project/ray/blob/a30bdf9ef34a45f973b589993f7707a763df6ebf/release/air_tests/air_benchmarks/workloads/data_benchmark.py#L25-L40
.. _`Bulk Ingest Cluster Configuration`: https://github.com/ray-project/ray/blob/a30bdf9ef34a45f973b589993f7707a763df6ebf/release/air_tests/air_benchmarks/data_20_nodes.yaml#L6-L15
.. _`XGBoost Training Script`: https://github.com/ray-project/ray/blob/a241e6a0f5a630d6ed5b84cce30c51963834d15b/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py#L40-L58
.. _`XGBoost Prediction Script`: https://github.com/ray-project/ray/blob/a241e6a0f5a630d6ed5b84cce30c51963834d15b/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py#L63-L71
.. _`XGBoost Cluster Configuration`: https://github.com/ray-project/ray/blob/a241e6a0f5a630d6ed5b84cce30c51963834d15b/release/air_tests/air_benchmarks/xgboost_compute_tpl.yaml#L6-L24
.. _`GPU image batch prediction script`: https://github.com/ray-project/ray/blob/cec82a1ced631525a4d115e4dc0c283fa4275a7f/release/air_tests/air_benchmarks/workloads/gpu_batch_prediction.py#L18-L49
.. _`GPU image training script`: https://github.com/ray-project/ray/blob/cec82a1ced631525a4d115e4dc0c283fa4275a7f/release/air_tests/air_benchmarks/workloads/pytorch_training_e2e.py#L95-L106

View file

@ -2,12 +2,8 @@ import click
import time
import json
import os
import numpy as np
import pandas as pd
from io import BytesIO
from typing import List
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
@ -16,26 +12,13 @@ from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.train.torch import to_air_checkpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.data.preprocessors import BatchMapper
# TODO(jiaodong): Remove this once ImageFolder #24641 merges
def convert_to_pandas(byte_item_list: List[bytes]) -> pd.DataFrame:
"""
Convert input bytes into pandas DataFrame with image column and value of
TensorArray to prevent serializing ndarray image data.
"""
images = [
Image.open(BytesIO(byte_item)).convert("RGB") for byte_item in byte_item_list
]
images = [np.asarray(image) for image in images]
return pd.DataFrame({"image": TensorArray(images)})
from ray.data.datasource import ImageFolderDatasource
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
"""
User Pytorch code to transform user image. Note we still use pandas as
intermediate format to hold images as shorthand of python dictionary.
User Pytorch code to transform user image. Note we still use TensorArray as
intermediate format to hold images for now.
"""
preprocess = transforms.Compose(
[
@ -45,9 +28,7 @@ def preprocess(df: pd.DataFrame) -> pd.DataFrame:
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
df["image"] = df["image"].map(preprocess)
df["image"] = df["image"].map(lambda x: x.numpy())
df["image"] = TensorArray(df["image"])
df["image"] = TensorArray([preprocess(image.to_numpy()) for image in df["image"]])
return df
@ -57,9 +38,7 @@ def main(data_size_gb: int):
data_url = f"s3://air-example-data-2/{data_size_gb}G-image-data-synthetic-raw"
print(f"Running GPU batch prediction with {data_size_gb}GB data from {data_url}")
start = time.time()
dataset = ray.data.read_binary_files(paths=data_url)
# TODO(jiaodong): Remove this once ImageFolder #24641 merges
dataset = dataset.map_batches(convert_to_pandas)
dataset = ray.data.read_datasource(ImageFolderDatasource(), paths=[data_url])
model = resnet18(pretrained=True)
@ -67,7 +46,7 @@ def main(data_size_gb: int):
ckpt = to_air_checkpoint(model=model, preprocessor=preprocessor)
predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
predictor.predict(dataset, num_gpus_per_worker=1)
predictor.predict(dataset, num_gpus_per_worker=1, feature_columns=["image"])
total_time_s = round(time.time() - start, 2)
# For structured output integration with internal tooling

View file

@ -2,12 +2,8 @@ import click
import time
import json
import os
import numpy as np
import pandas as pd
from io import BytesIO
from typing import List
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
import torch
@ -21,28 +17,13 @@ from ray.data.preprocessors import BatchMapper
from ray import train
from ray.air import session
from ray.train.torch import TorchTrainer
from ray.data.datasource import ImageFolderDatasource
# TODO(jiaodong): Remove this once ImageFolder #24641 merges
def convert_to_pandas(byte_item_list: List[bytes]) -> pd.DataFrame:
def preprocess_image_with_label(df: pd.DataFrame) -> pd.DataFrame:
"""
Convert input bytes into pandas DataFrame with image column and value of
TensorArray to prevent serializing ndarray image data.
"""
images = [
Image.open(BytesIO(byte_item)).convert("RGB") for byte_item in byte_item_list
]
images = [np.asarray(image) for image in images]
# Dummy label since we're only testing training throughput
labels = [1 for _ in range(len(images))]
return pd.DataFrame({"image": TensorArray(images), "label": labels})
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
"""
User Pytorch code to transform user image. Note we still use pandas as
intermediate format to hold images as shorthand of python dictionary.
User Pytorch code to transform user image. Note we still use TensorArray as
intermediate format to hold images for now.
"""
preprocess = transforms.Compose(
[
@ -52,9 +33,9 @@ def preprocess(df: pd.DataFrame) -> pd.DataFrame:
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
df["image"] = df["image"].map(preprocess)
df["image"] = df["image"].map(lambda x: x.numpy())
df["image"] = TensorArray(df["image"])
df["image"] = TensorArray([preprocess(image.to_numpy()) for image in df["image"]])
# Overwrite every label with a fixed synthetic value; this benchmark only measures training throughput
df["label"] = df["label"].map(lambda _: 1)
return df
@ -101,27 +82,26 @@ def train_loop_per_worker(config):
@click.command(help="Run Batch prediction on Pytorch ResNet models.")
@click.option("--data-size-gb", type=int, default=1)
@click.option("--num-epochs", type=int, default=10)
def main(data_size_gb: int, num_epochs=10):
@click.option("--num-epochs", type=int, default=2)
@click.option("--num-workers", type=int, default=1)
def main(data_size_gb: int, num_epochs=2, num_workers=1):
data_url = f"s3://air-example-data-2/{data_size_gb}G-image-data-synthetic-raw"
print(
"Running Pytorch image model training with "
f"{data_size_gb}GB data from {data_url}"
)
print(f"Training for {num_epochs} epochs.")
print(f"Training for {num_epochs} epochs with {num_workers} workers.")
start = time.time()
dataset = ray.data.read_binary_files(paths=data_url)
# TODO(jiaodong): Remove this once ImageFolder #24641 merges
dataset = dataset.map_batches(convert_to_pandas)
dataset = ray.data.read_datasource(ImageFolderDatasource(), paths=[data_url])
preprocessor = BatchMapper(preprocess)
preprocessor = BatchMapper(preprocess_image_with_label)
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={"batch_size": 64, "num_epochs": num_epochs},
datasets={"train": dataset},
preprocessor=preprocessor,
scaling_config={"num_workers": 1, "use_gpu": True},
scaling_config={"num_workers": num_workers, "use_gpu": True},
)
trainer.fit()