From f623c607f29d6b4de49e85e0c3dbed56b00ca749 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 26 May 2022 16:42:09 -0700 Subject: [PATCH] [AIR] Build model in `TensorflowPredictor.predict` (#25136) `TensorflowPredictor.predict` doesn't work right now. For more information, see #25125. Co-authored-by: Amog Kamsetty --- .../integrations/tensorflow/tensorflow_predictor.py | 9 ++++++++- python/ray/ml/tests/test_tensorflow_predictor.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py b/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py index 2eb125db3..cd69ecdc3 100644 --- a/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py +++ b/python/ray/ml/predictors/integrations/tensorflow/tensorflow_predictor.py @@ -144,7 +144,14 @@ class TensorflowPredictor(Predictor): # a callable that returns the model and initialize it here, # instead of having an initialized model object as an attribute. model = self.model_definition() - if self.model_weights: + + if self.model_weights is not None: + input_shape = list(tensor.shape) + # The batch axis can contain varying number of elements, so we set + # the shape along the axis to `None`. + input_shape[0] = None + + model.build(input_shape=input_shape) model.set_weights(self.model_weights) prediction = list(model(tensor).numpy()) diff --git a/python/ray/ml/tests/test_tensorflow_predictor.py b/python/ray/ml/tests/test_tensorflow_predictor.py index cd0123ff7..9a20f8891 100644 --- a/python/ray/ml/tests/test_tensorflow_predictor.py +++ b/python/ray/ml/tests/test_tensorflow_predictor.py @@ -57,6 +57,19 @@ def test_predict(): assert hasattr(predictor.preprocessor, "_batch_transformed") +def test_predict_array_with_input_shape_unspecified(): + def model_definition(): + return tf.keras.models.Sequential(tf.keras.layers.Lambda(lambda tensor: tensor)) + + predictor = TensorflowPredictor(model_definition=model_definition, model_weights=[]) + + data_batch = np.array([[1], [2], [3]]) + predictions = predictor.predict(data_batch) + + assert len(predictions) == 3 + assert predictions.to_numpy().flatten().tolist() == [1, 2, 3] + + def test_predict_feature_columns(): preprocessor = DummyPreprocessor() predictor = TensorflowPredictor(