ray/doc/source/ray-air/doc_code/use_pretrained_model.py
xwjiang2010 76b34d4a03
[air] Add to_air_checkpoint method for inference-only workloads. (#25444)
Follow-up to our last discussion on supporting users who adopt AIR piecemeal.
Only TensorFlow is covered for now; I want to collect feedback on API naming, package structure, etc., and will add the other frameworks afterwards.
2022-06-07 14:50:39 -07:00


# flake8: noqa
# __use_pretrained_model_start__
import ray
import tensorflow as tf
from ray.air.batch_predictor import BatchPredictor
from ray.air.predictors.integrations.tensorflow import (
    to_air_checkpoint,
    TensorflowPredictor,
)


# Build a small Keras model to stand in for an already-trained (pretrained) one.
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.Dense(1),
        ]
    )
    return model


model = build_model()
# Wrap the existing model in an AIR Checkpoint without running any training.
checkpoint = to_air_checkpoint(model)
# The predictor uses `model_definition` to reconstruct the Keras model
# when it loads the checkpoint.
batch_predictor = BatchPredictor(
    checkpoint, TensorflowPredictor, model_definition=build_model
)
# Run batch inference over a small Ray Dataset (the integers 0, 1, 2).
predict_dataset = ray.data.range(3)
predictions = batch_predictor.predict(predict_dataset)

# __use_pretrained_model_end__
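The example above hands `BatchPredictor` a checkpoint together with a `model_definition` callable, so the model can be rebuilt at prediction time. The following is a minimal, framework-free sketch of that pattern, assuming the checkpoint carries only the model's trained state (its weights) while the architecture comes from `model_definition`. `LinearModel`, `to_checkpoint`, and `batch_predict` are hypothetical names chosen for illustration; they are not Ray AIR APIs and are not meant to mirror Ray's internals.

```python
from typing import Callable, Dict, Iterable, List


class LinearModel:
    """Hypothetical stand-in for a framework model computing y = w * x + b."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]  # [w, b], untrained defaults

    def get_weights(self) -> List[float]:
        return list(self.weights)  # return a copy, not the live list

    def set_weights(self, weights: List[float]) -> None:
        self.weights = list(weights)

    def predict(self, batch: List[float]) -> List[float]:
        w, b = self.weights
        return [w * x + b for x in batch]


def to_checkpoint(model: LinearModel) -> Dict[str, List[float]]:
    """Hypothetical: capture only the model's trained state, not its code."""
    return {"model": model.get_weights()}


def batch_predict(
    checkpoint: Dict[str, List[float]],
    model_definition: Callable[[], LinearModel],
    data: Iterable[float],
    batch_size: int = 2,
) -> List[float]:
    """Rebuild the architecture, load the saved state, then predict per batch."""
    model = model_definition()  # fresh, untrained instance
    model.set_weights(checkpoint["model"])  # restore the pretrained state
    items = list(data)
    predictions: List[float] = []
    for start in range(0, len(items), batch_size):
        predictions.extend(model.predict(items[start:start + batch_size]))
    return predictions


pretrained = LinearModel()
pretrained.set_weights([2.0, 1.0])  # pretend these came from earlier training
ckpt = to_checkpoint(pretrained)
preds = batch_predict(ckpt, LinearModel, range(3))
print(preds)  # [1.0, 3.0, 5.0]
```

Keeping the architecture in code and only the state in the checkpoint means the same checkpoint can be loaded by any process that can import the model-building function.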