2020-04-28 22:24:55 -07:00
|
|
|
# yapf: disable
|
|
|
|
# __doc_import_begin__
|
|
|
|
from ray import serve
|
|
|
|
|
|
|
|
from io import BytesIO
|
|
|
|
from PIL import Image
|
|
|
|
import requests
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torchvision import transforms
|
|
|
|
from torchvision.models import resnet18
|
|
|
|
# __doc_import_end__
|
|
|
|
# yapf: enable
|
|
|
|
|
|
|
|
|
|
|
|
# __doc_define_servable_begin__
|
|
|
|
class ImageModel:
    """Servable that classifies a single image with a pretrained ResNet-18.

    An instance is callable: it takes a request object, decodes the image
    carried in the request body, runs it through the network and returns
    the index of the highest-scoring output class.
    """

    def __init__(self):
        """Load the pretrained network and build the preprocessing pipeline."""
        # A freshly constructed nn.Module is in training mode, in which
        # BatchNorm layers normalise with per-batch statistics and update
        # their running statistics on every forward pass. torch.no_grad()
        # does not change that, so switch to inference mode explicitly.
        self.model = resnet18(pretrained=True).eval()
        self.preprocessor = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
        ])

    def __call__(self, flask_request):
        """Classify the image contained in the request body.

        Args:
            flask_request: Request object whose ``data`` attribute holds the
                encoded image bytes.

        Returns:
            A dict ``{"class_index": int}`` holding the argmax over the
            model output for the (single) image in the batch.
        """
        image_payload_bytes = flask_request.data
        pil_image = Image.open(BytesIO(image_payload_bytes))
        print("[1/3] Parsed image data: {}".format(pil_image))

        pil_images = [pil_image]  # Our current batch size is one
        # Each preprocessed image gets a leading batch axis, then all are
        # concatenated along that axis into one input tensor.
        input_tensor = torch.cat(
            [self.preprocessor(i).unsqueeze(0) for i in pil_images])
        print("[2/3] Images transformed, tensor shape {}".format(
            input_tensor.shape))

        # Gradients are not needed for inference.
        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        print("[3/3] Inference done!")
        return {"class_index": int(torch.argmax(output_tensor[0]))}
|
|
|
|
|
|
|
|
|
|
|
|
# __doc_define_servable_end__
|
|
|
|
|
|
|
|
# __doc_deploy_begin__
|
|
|
|
# Initialise Ray Serve before registering anything with it.
serve.init()

# Register the ImageModel class as a backend under the tag "resnet18:v0".
serve.create_backend("resnet18:v0", ImageModel)

# Create the "predictor" endpoint, served by that backend and reachable
# through HTTP POST requests to /image_predict.
serve.create_endpoint(
    "predictor",
    route="/image_predict",
    methods=["POST"],
    backend="resnet18:v0",
)
|
2020-04-28 22:24:55 -07:00
|
|
|
# __doc_deploy_end__
|
|
|
|
|
|
|
|
# __doc_query_begin__
|
|
|
|
# Download a sample image (the Ray logo) to use as the request payload.
logo_url = ("https://github.com/ray-project/ray/raw/"
            "master/doc/source/images/ray_header_logo.png")
ray_logo_bytes = requests.get(logo_url).content

# POST the raw image bytes to the prediction route and print the JSON reply.
resp = requests.post(
    "http://localhost:8000/image_predict",
    data=ray_logo_bytes,
)
print(resp.json())
|
|
|
|
# Example output (the exact class index depends on the model weights in use):
# {'class_index': 463}
|
|
|
|
# __doc_query_end__
|