# ray/doc/source/serve/doc_code/tutorial_pytorch.py

# fmt: off
# __doc_import_begin__
from ray import serve
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
from torchvision.models import resnet18
# __doc_import_end__
# fmt: on
# __doc_define_servable_begin__
@serve.deployment(route_prefix="/image_predict")
class ImageModel:
    def __init__(self):
        # Load a pretrained ResNet-18 and put it in inference mode.
        self.model = resnet18(pretrained=True).eval()
        # Standard ImageNet preprocessing: resize, crop, and normalize.
        self.preprocessor = transforms.Compose(
            [
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Lambda(lambda t: t[:3, ...]),  # remove the alpha channel
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    async def __call__(self, starlette_request):
        # The request body contains the raw image bytes sent by the client.
        image_payload_bytes = await starlette_request.body()
        pil_image = Image.open(BytesIO(image_payload_bytes))
        print("[1/3] Parsed image data: {}".format(pil_image))

        pil_images = [pil_image]  # Our current batch size is one.
        input_tensor = torch.cat(
            [self.preprocessor(i).unsqueeze(0) for i in pil_images]
        )
        print("[2/3] Images transformed, tensor shape {}".format(input_tensor.shape))

        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        print("[3/3] Inference done!")
        return {"class_index": int(torch.argmax(output_tensor[0]))}
# __doc_define_servable_end__
# __doc_deploy_begin__
app = ImageModel.bind()
# __doc_deploy_end__
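

# Not part of the original doc_code: a minimal local-testing sketch. It assumes
# Ray Serve exposes the deployment on the default HTTP port 8000 under the
# /image_predict route, and "example.jpg" is a hypothetical image file path.
if __name__ == "__main__":
    import requests

    # Start Serve and run the application defined above.
    serve.run(app)

    with open("example.jpg", "rb") as f:
        image_bytes = f.read()

    # POST the raw image bytes and print the predicted ImageNet class index.
    resp = requests.post("http://localhost:8000/image_predict", data=image_bytes)
    print(resp.json())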