mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
# fmt: off
# __doc_import_begin__
from ray import serve

import os
import tempfile
import numpy as np
from starlette.requests import Request
from typing import Dict, Optional

import tensorflow as tf
# __doc_import_end__
# fmt: on
|
|
|
|
# __doc_train_model_begin__
# File in the system temporary directory where the trained Keras model is cached.
TRAINED_MODEL_PATH = os.path.join(
    tempfile.gettempdir(),
    "mnist_model.h5",
)
|
|
|
|
|
|
def train_and_save_model(model_path: Optional[str] = None, epochs: int = 1):
    """Train a small dense classifier on MNIST and save it to disk.

    Args:
        model_path: Destination file for the saved model. Defaults to
            ``TRAINED_MODEL_PATH``. Keras picks the save format from the file
            extension (``.h5`` -> HDF5).
        epochs: Number of passes over the training set.
    """
    if model_path is None:
        model_path = TRAINED_MODEL_PATH

    # Load mnist dataset and scale the uint8 pixel values into [0, 1].
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Train a simple neural net model. The final Dense layer has no activation,
    # so it emits raw logits -- hence ``from_logits=True`` in the loss.
    model = tf.keras.models.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10),
        ]
    )
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
    model.fit(x_train, y_train, epochs=epochs)

    model.evaluate(x_test, y_test, verbose=2)
    model.summary()

    # Save to a process-unique temporary file next to the target, then rename it
    # into place. os.replace is atomic on POSIX, so a crashed or concurrent run
    # never leaves a half-written file that the "does the model exist?" check
    # would mistake for a finished model. The original extension is kept so
    # Keras still selects the same save format.
    root, ext = os.path.splitext(model_path)
    tmp_path = f"{root}.{os.getpid()}.tmp{ext}"
    try:
        model.save(tmp_path)
        os.replace(tmp_path, model_path)
    finally:
        # Only still present if saving or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
|
|
# Train only when no model has been saved yet; later runs reuse the cached file.
_model_already_saved = os.path.exists(TRAINED_MODEL_PATH)
if not _model_already_saved:
    train_and_save_model()
# __doc_train_model_end__
|
|
|
|
|
|
# __doc_define_servable_begin__
@serve.deployment
class TFMnistModel:
    """Ray Serve deployment that runs a saved Keras MNIST model on HTTP requests.

    Request body: a JSON object whose ``"array"`` field holds a (possibly
    nested) numeric array whose size is a multiple of 28 * 28 -- one image or a
    batch of images.

    Response: ``{"prediction": <model outputs, one row per image>,
    "file": <path the model was loaded from>}``.
    """

    def __init__(self, model_path: str):
        import tensorflow as tf

        self.model_path = model_path
        self.model = tf.keras.models.load_model(model_path)

    async def __call__(self, starlette_request: Request) -> Dict:
        # Step 1: transform HTTP request -> tensorflow input
        # Here we define the request schema to be a json array.
        payload = await starlette_request.json()
        # Force float32 (Keras' default float dtype) so an all-integer JSON
        # array is not turned into an int64 tensor. Let the leading axis be
        # inferred so a batch of 28x28 images works as well as a single one
        # (a single image still becomes shape (1, 28, 28)).
        input_array = np.asarray(payload["array"], dtype=np.float32)
        reshaped_array = input_array.reshape((-1, 28, 28))

        # Step 2: tensorflow input -> tensorflow output
        prediction = self.model(reshaped_array)

        # Step 3: tensorflow output -> web output
        return {"prediction": prediction.numpy().tolist(), "file": self.model_path}
# __doc_define_servable_end__
|
|
|
|
|
|
# __doc_deploy_begin__
# Build the Serve application: TRAINED_MODEL_PATH is bound as the
# ``model_path`` constructor argument of TFMnistModel.
mnist_model = TFMnistModel.bind(
    TRAINED_MODEL_PATH,
)
# __doc_deploy_end__
|