ray/doc/source/ray-air/examples/torch_image_batch_pretrained.py
Balaji Veeramani 89f7f2a567
[Datasets] Add size parameter to ImageFolderDatasource (#26975)
If you read a folder with differently-sized images, `ImageFolderDatasource` raises an error. This PR fixes the issue by resizing images to a user-specified size.
2022-07-26 14:57:38 -07:00

43 lines
1.4 KiB
Python

import pandas as pd
from torchvision import transforms
from torchvision.models import resnet18
import ray
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.train.torch import TorchCheckpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.data.preprocessors import BatchMapper
from ray.data.datasource import ImageFolderDatasource
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Apply ImageNet-style preprocessing to the ``image`` column of a batch.

    This is user-supplied PyTorch/torchvision code wrapped by ``BatchMapper``.
    A pandas DataFrame is used as the batch container: each element of
    ``df["image"]`` is converted to a NumPy array, run through a torchvision
    transform pipeline, and the results are stored back as a ``TensorArray``
    column.

    Args:
        df: A batch whose ``image`` column holds one image per row
            (presumably ``TensorArray`` elements exposing ``to_numpy()`` —
            confirm against the datasource).

    Returns:
        The same DataFrame object, with its ``image`` column replaced in place
        by the transformed 224x224 center crops.
    """
    # Named ``transform`` rather than ``preprocess`` so the local does not
    # shadow this function's own name.
    transform = transforms.Compose(
        [
            # ndarray -> float tensor (assumes HxWxC uint8 input; ToTensor
            # then yields CxHxW values scaled to [0, 1]).
            transforms.ToTensor(),
            # A no-op for images already read at 256x256, but keeps the
            # pipeline valid for other input sizes.
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # Per-channel ImageNet mean/std used by the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    df["image"] = TensorArray([transform(image.to_numpy()) for image in df["image"]])
    return df
# Images per inference batch. BatchPredictor.predict's default batch size is
# large (4096 at the time of writing); 4096 float32 3x224x224 inputs alone are
# ~2.5 GB before any ResNet activations, so use a small explicit batch.
# Tune for your hardware.
BATCH_SIZE = 80
# GPUs requested per scoring worker. 0 runs inference on CPU; set to 1 on a
# cluster with GPUs to actually run GPU batch prediction.
# NOTE(review): depending on the Ray version, the predictor may also need
# ``use_gpu=True`` passed through BatchPredictor.from_checkpoint — confirm.
NUM_GPUS_PER_WORKER = 0

data_url = "s3://anonymous@air-example-data-2/1G-image-data-synthetic-raw"
# Report the device that is actually requested instead of always claiming GPU.
device = "GPU" if NUM_GPUS_PER_WORKER > 0 else "CPU"
print(f"Running {device} batch prediction with 1GB data from {data_url}")
# Resize on read so that differently-sized images can share one batch.
dataset = ray.data.read_datasource(
    ImageFolderDatasource(), root=data_url, size=(256, 256)
)
model = resnet18(pretrained=True)
# Bundle the preprocessing with the checkpoint so the predictor can apply it
# to each batch before inference.
preprocessor = BatchMapper(preprocess)
ckpt = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)
predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
predictor.predict(
    dataset,
    feature_columns=["image"],
    batch_size=BATCH_SIZE,
    num_gpus_per_worker=NUM_GPUS_PER_WORKER,
)