mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ML] TensorflowPredictor implementation (#23146)
Implementation for TensorflowPredictor.
This commit is contained in:
parent
5ecd88e2e0
commit
447a98eed1
3 changed files with 134 additions and 9 deletions
|
@ -20,6 +20,14 @@ py_test(
|
|||
deps = [":ml_lib"]
|
||||
)
|
||||
|
||||
# Test target for tests/test_tensorflow_predictor.py.
py_test(
    name = "test_tensorflow_predictor",
    size = "small",
    srcs = ["tests/test_tensorflow_predictor.py"],
    tags = [
        "team:ml",
        "exclusive",
    ],
    deps = [":ml_lib"],
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_predictor",
|
||||
size = "small",
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from typing import Callable, Optional, Union, List, Type
|
||||
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
|
||||
from ray.ml.predictor import Predictor, DataBatchType
|
||||
from ray.ml.preprocessor import Preprocessor
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
|
||||
# TensorFlow model objects cannot be pickled, therefore we use
|
||||
# a callable that returns the model, instead of a model object
|
||||
# itself.
|
||||
from ray.ml.constants import MODEL_KEY, PREPROCESSOR_KEY
|
||||
|
||||
|
||||
class TensorflowPredictor(Predictor):
|
||||
|
@ -25,10 +23,12 @@ class TensorflowPredictor(Predictor):
|
|||
def __init__(
    self,
    model_definition: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]],
    preprocessor: Optional[Preprocessor] = None,
    model_weights: Optional[list] = None,
):
    """Store the pieces needed to rebuild the model at prediction time.

    The model itself is not instantiated here; only its definition and
    weights are kept on the instance.

    Args:
        model_definition: A callable (or Keras model class) that returns a
            TensorFlow Keras model when called with no arguments.
        preprocessor: Optional preprocessor whose ``transform_batch`` is
            applied to the data before inference. ``None`` disables
            preprocessing.
        model_weights: Optional list of weights to load into the freshly
            built model via ``set_weights``. ``None`` keeps the weights
            produced by ``model_definition``.
    """
    self.model_definition = model_definition
    self.model_weights = model_weights
    self.preprocessor = preprocessor
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
|
@ -47,7 +47,20 @@ class TensorflowPredictor(Predictor):
|
|||
model_definition: A callable that returns a TensorFlow Keras model
|
||||
to use. Model weights will be loaded from the checkpoint.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
checkpoint_dict = checkpoint.to_dict()
|
||||
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
|
||||
if MODEL_KEY not in checkpoint_dict:
|
||||
raise RuntimeError(
|
||||
f"No item with key: {MODEL_KEY} is found in the "
|
||||
f"Checkpoint. Make sure this key exists when saving the "
|
||||
f"checkpoint in ``TensorflowTrainer``."
|
||||
)
|
||||
model_weights = checkpoint_dict[MODEL_KEY]
|
||||
return TensorflowPredictor(
|
||||
model_definition=model_definition,
|
||||
model_weights=model_weights,
|
||||
preprocessor=preprocessor,
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
|
@ -57,7 +70,7 @@ class TensorflowPredictor(Predictor):
|
|||
) -> DataBatchType:
|
||||
"""Run inference on data batch.
|
||||
|
||||
The data is converted into a torch Tensor before being inputted to
|
||||
The data is converted into a TensorFlow Tensor before being inputted to
|
||||
the model.
|
||||
|
||||
Args:
|
||||
|
@ -121,4 +134,24 @@ class TensorflowPredictor(Predictor):
|
|||
Returns:
|
||||
DataBatchType: Prediction result.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if self.preprocessor:
|
||||
data = self.preprocessor.transform_batch(data)
|
||||
|
||||
if isinstance(data, pd.DataFrame):
|
||||
if feature_columns:
|
||||
data = data[feature_columns]
|
||||
data = data.values
|
||||
else:
|
||||
data = data[:, feature_columns]
|
||||
|
||||
tensor = tf.convert_to_tensor(data, dtype=dtype)
|
||||
|
||||
# TensorFlow model objects cannot be pickled, therefore we use
|
||||
# a callable that returns the model and initialize it here,
|
||||
# instead of having an initialized model object as an attribute.
|
||||
model = self.model_definition()
|
||||
if self.model_weights:
|
||||
model.set_weights(self.model_weights)
|
||||
|
||||
prediction = model(tensor).numpy().ravel()
|
||||
return pd.DataFrame(prediction, columns=["predictions"])
|
||||
|
|
84
python/ray/ml/tests/test_tensorflow_predictor.py
Normal file
84
python/ray/ml/tests/test_tensorflow_predictor.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
|
||||
from ray.ml.preprocessor import Preprocessor
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.constants import PREPROCESSOR_KEY, MODEL_KEY
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class DummyPreprocessor(Preprocessor):
    """Test preprocessor that doubles its input and records that it ran."""

    def transform_batch(self, df):
        """Flag the instance as used and return ``df`` multiplied by two."""
        # The flag lets tests check (via ``hasattr``) that the predictor
        # actually invoked the preprocessor.
        self._batch_transformed = True
        doubled = df * 2
        return doubled
|
||||
|
||||
|
||||
def build_model() -> tf.keras.Model:
    """Return a minimal Keras model: one scalar input into one Dense unit."""
    layers = [
        tf.keras.layers.InputLayer(input_shape=(1,)),
        tf.keras.layers.Dense(1),
    ]
    return tf.keras.Sequential(layers)
|
||||
|
||||
|
||||
# Weights loaded into the model from ``build_model``: a 1x1 array of 1.0 and
# a length-1 array of 0.0 (presumably the Dense kernel and bias -- with these
# values the tests expect the model to return its input unchanged).
weights = [
    np.array([[1.0]]),
    np.array([0.0]),
]
|
||||
|
||||
|
||||
def test_init():
    """A predictor built from a checkpoint matches one built directly."""
    preprocessor = DummyPreprocessor()
    predictor = TensorflowPredictor(
        model_definition=build_model, preprocessor=preprocessor, model_weights=weights
    )

    checkpoint = {MODEL_KEY: weights, PREPROCESSOR_KEY: preprocessor}
    checkpoint_predictor = TensorflowPredictor.from_checkpoint(
        Checkpoint.from_dict(checkpoint), build_model
    )

    assert checkpoint_predictor.model_definition == predictor.model_definition

    # Compare the weight arrays element by element. ``==`` on lists of NumPy
    # arrays only works through the identity shortcut or single-element
    # truthiness and raises ValueError for arrays with more than one element.
    restored_weights = checkpoint_predictor.model_weights
    expected_weights = predictor.model_weights
    assert len(restored_weights) == len(expected_weights)
    assert all(
        np.array_equal(restored, expected)
        for restored, expected in zip(restored_weights, expected_weights)
    )

    assert checkpoint_predictor.preprocessor == predictor.preprocessor
|
||||
|
||||
|
||||
def test_predict():
    """Predicting on a NumPy batch applies the preprocessor, then the model.

    ``DummyPreprocessor`` doubles the input, and the expected output below
    equals that doubled input.
    """
    predictor = TensorflowPredictor(
        model_definition=build_model,
        preprocessor=DummyPreprocessor(),
        model_weights=weights,
    )

    batch = np.array([[1], [2], [3]])
    result = predictor.predict(batch)

    assert len(result) == 3
    rounded = result.to_numpy().flatten().round().tolist()
    assert rounded == [2, 4, 6]
    # The dummy preprocessor sets this attribute only when it is invoked.
    assert hasattr(predictor.preprocessor, "_batch_transformed")
|
||||
|
||||
|
||||
def test_predict_feature_columns():
    """``feature_columns`` restricts inference to the selected columns.

    Only column 0 of the two-column batch is requested, so the expected
    output is that column doubled by ``DummyPreprocessor``.
    """
    predictor = TensorflowPredictor(
        model_definition=build_model,
        preprocessor=DummyPreprocessor(),
        model_weights=weights,
    )

    batch = np.array([[1, 4], [2, 5], [3, 6]])
    result = predictor.predict(batch, feature_columns=[0])

    assert len(result) == 3
    rounded = result.to_numpy().flatten().round().tolist()
    assert rounded == [2, 4, 6]
    # The dummy preprocessor sets this attribute only when it is invoked.
    assert hasattr(predictor.preprocessor, "_batch_transformed")
|
||||
|
||||
|
||||
def test_predict_no_preprocessor():
    """A checkpoint without a preprocessor still yields a working predictor.

    The checkpoint holds only the model weights, and the expected output
    equals the unmodified input.
    """
    weights_only_checkpoint = Checkpoint.from_dict({MODEL_KEY: weights})
    predictor = TensorflowPredictor.from_checkpoint(
        weights_only_checkpoint, build_model
    )

    batch = np.array([[1], [2], [3]])
    result = predictor.predict(batch)

    assert len(result) == 3
    flattened = result.to_numpy().flatten().tolist()
    assert flattened == [1, 2, 3]
|
Loading…
Add table
Reference in a new issue