[ML] TensorflowPredictor implementation (#23146)

Implementation for TensorflowPredictor.
This commit is contained in:
Antoni Baum 2022-03-15 01:02:21 +01:00 committed by GitHub
parent 5ecd88e2e0
commit 447a98eed1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 9 deletions

View file

@ -20,6 +20,14 @@ py_test(
deps = [":ml_lib"]
)
# Test target for tests/test_tensorflow_predictor.py (the TensorflowPredictor tests).
py_test(
name = "test_tensorflow_predictor",
size = "small",
srcs = ["tests/test_tensorflow_predictor.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)
py_test(
name = "test_torch_predictor",
size = "small",

View file

@ -1,14 +1,12 @@
from typing import Callable, Optional, Union, List, Type
import pandas as pd
import tensorflow as tf
from ray.ml.predictor import Predictor, DataBatchType
from ray.ml.preprocessor import Preprocessor
from ray.ml.checkpoint import Checkpoint
# TensorFlow model objects cannot be pickled, therefore we use
# a callable that returns the model, instead of a model object
# itself.
from ray.ml.constants import MODEL_KEY, PREPROCESSOR_KEY
class TensorflowPredictor(Predictor):
@ -25,10 +23,12 @@ class TensorflowPredictor(Predictor):
def __init__(
    self,
    model_definition: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]],
    preprocessor: Optional[Preprocessor] = None,
    model_weights: Optional[list] = None,
):
    """Store everything needed to rebuild the model at prediction time.

    Args:
        model_definition: A callable (or Keras model class) that returns
            a TensorFlow Keras model when called with no arguments.
        preprocessor: Optional preprocessor whose ``transform_batch`` is
            applied to the data before inference. ``None`` disables
            preprocessing.
        model_weights: Optional list of weights that is passed to
            ``model.set_weights`` after the model has been built.
    """
    # TensorFlow model objects cannot be pickled, so only the factory and
    # the weights are kept; the model itself is built inside ``predict``.
    self.model_definition = model_definition
    self.model_weights = model_weights
    self.preprocessor = preprocessor
@classmethod
def from_checkpoint(
@ -47,7 +47,20 @@ class TensorflowPredictor(Predictor):
model_definition: A callable that returns a TensorFlow Keras model
to use. Model weights will be loaded from the checkpoint.
"""
raise NotImplementedError
checkpoint_dict = checkpoint.to_dict()
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
if MODEL_KEY not in checkpoint_dict:
raise RuntimeError(
f"No item with key: {MODEL_KEY} is found in the "
f"Checkpoint. Make sure this key exists when saving the "
f"checkpoint in ``TensorflowTrainer``."
)
model_weights = checkpoint_dict[MODEL_KEY]
return TensorflowPredictor(
model_definition=model_definition,
model_weights=model_weights,
preprocessor=preprocessor,
)
def predict(
self,
@ -57,7 +70,7 @@ class TensorflowPredictor(Predictor):
) -> DataBatchType:
"""Run inference on data batch.
The data is converted into a torch Tensor before being inputted to
The data is converted into a TensorFlow Tensor before being inputted to
the model.
Args:
@ -121,4 +134,24 @@ class TensorflowPredictor(Predictor):
Returns:
DataBatchType: Prediction result.
"""
raise NotImplementedError
if self.preprocessor:
data = self.preprocessor.transform_batch(data)
if isinstance(data, pd.DataFrame):
if feature_columns:
data = data[feature_columns]
data = data.values
else:
data = data[:, feature_columns]
tensor = tf.convert_to_tensor(data, dtype=dtype)
# TensorFlow model objects cannot be pickled, therefore we use
# a callable that returns the model and initialize it here,
# instead of having an initialized model object as an attribute.
model = self.model_definition()
if self.model_weights:
model.set_weights(self.model_weights)
prediction = model(tensor).numpy().ravel()
return pd.DataFrame(prediction, columns=["predictions"])

View file

@ -0,0 +1,84 @@
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.preprocessor import Preprocessor
from ray.ml.checkpoint import Checkpoint
from ray.ml.constants import PREPROCESSOR_KEY, MODEL_KEY
import numpy as np
import tensorflow as tf
class DummyPreprocessor(Preprocessor):
    """Test preprocessor that doubles its input and records that it ran."""

    def transform_batch(self, df):
        """Return ``df * 2`` and set ``_batch_transformed`` on this instance."""
        # The tests look for this attribute to confirm the predictor
        # actually called the preprocessor.
        self._batch_transformed = True
        doubled = df * 2
        return doubled
def build_model() -> tf.keras.Model:
    """Build a sequential Keras model: one input feature into ``Dense(1)``."""
    layers = [
        tf.keras.layers.InputLayer(input_shape=(1,)),
        tf.keras.layers.Dense(1),
    ]
    return tf.keras.Sequential(layers)
# Weights handed to the predictors in the tests below: a (1, 1) array of 1.0
# and a (1,) array of 0.0, so the model from ``build_model`` acts as identity.
weights = [
    np.array([[1.0]]),
    np.array([0.0]),
]
def test_init():
    """A predictor restored from a checkpoint matches one built directly."""
    preprocessor = DummyPreprocessor()
    predictor = TensorflowPredictor(
        model_definition=build_model, preprocessor=preprocessor, model_weights=weights
    )

    checkpoint = {MODEL_KEY: weights, PREPROCESSOR_KEY: preprocessor}
    checkpoint_predictor = TensorflowPredictor.from_checkpoint(
        Checkpoint.from_dict(checkpoint), build_model
    )

    assert checkpoint_predictor.model_definition == predictor.model_definition

    # Compare the weights array by array: ``list == list`` over ndarrays only
    # works when the elements are the same objects or have a single entry,
    # and otherwise raises on the ambiguous truth value of an array.
    restored_weights = checkpoint_predictor.model_weights
    expected_weights = predictor.model_weights
    assert len(restored_weights) == len(expected_weights)
    for restored, expected in zip(restored_weights, expected_weights):
        np.testing.assert_array_equal(restored, expected)

    assert checkpoint_predictor.preprocessor == predictor.preprocessor
def test_predict():
    """Single-column input is doubled by the preprocessor, then predicted."""
    dummy = DummyPreprocessor()
    predictor = TensorflowPredictor(
        model_definition=build_model,
        preprocessor=dummy,
        model_weights=weights,
    )

    batch = np.array([[1], [2], [3]])
    result = predictor.predict(batch)

    assert len(result) == 3
    rounded = result.to_numpy().flatten().round().tolist()
    assert rounded == [2, 4, 6]
    # The dummy preprocessor sets this attribute when it is invoked.
    assert hasattr(predictor.preprocessor, "_batch_transformed")
def test_predict_feature_columns():
    """Only the column selected via ``feature_columns`` reaches the model."""
    dummy = DummyPreprocessor()
    predictor = TensorflowPredictor(
        model_definition=build_model,
        preprocessor=dummy,
        model_weights=weights,
    )

    # Two columns are supplied; column 0 is the one selected below.
    batch = np.array([[1, 4], [2, 5], [3, 6]])
    result = predictor.predict(batch, feature_columns=[0])

    assert len(result) == 3
    rounded = result.to_numpy().flatten().round().tolist()
    assert rounded == [2, 4, 6]
    assert hasattr(predictor.preprocessor, "_batch_transformed")
def test_predict_no_preprocessor():
    """A checkpoint without a preprocessor yields untransformed predictions."""
    # Only the model weights are stored; no PREPROCESSOR_KEY entry.
    checkpoint_data = {MODEL_KEY: weights}
    restored = Checkpoint.from_dict(checkpoint_data)
    predictor = TensorflowPredictor.from_checkpoint(restored, build_model)

    batch = np.array([[1], [2], [3]])
    result = predictor.predict(batch)

    assert len(result) == 3
    flat = result.to_numpy().flatten().tolist()
    assert flat == [1, 2, 3]