[AIR] Build model in TensorflowPredictor.predict (#25136)

`TensorflowPredictor.predict` doesn't work right now. For more information, see #25125.

Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
This commit is contained in:
Balaji Veeramani 2022-05-26 16:42:09 -07:00 committed by GitHub
parent 087e356613
commit f623c607f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletions

View file

@ -144,7 +144,14 @@ class TensorflowPredictor(Predictor):
# a callable that returns the model and initialize it here,
# instead of having an initialized model object as an attribute.
model = self.model_definition()
if self.model_weights:
if self.model_weights is not None:
input_shape = list(tensor.shape)
# The batch axis can contain varying number of elements, so we set
# the shape along the axis to `None`.
input_shape[0] = None
model.build(input_shape=input_shape)
model.set_weights(self.model_weights)
prediction = list(model(tensor).numpy())

View file

@ -57,6 +57,19 @@ def test_predict():
assert hasattr(predictor.preprocessor, "_batch_transformed")
def test_predict_array_with_input_shape_unspecified():
def model_definition():
return tf.keras.models.Sequential(tf.keras.layers.Lambda(lambda tensor: tensor))
predictor = TensorflowPredictor(model_definition=model_definition, model_weights=[])
data_batch = np.array([[1], [2], [3]])
predictions = predictor.predict(data_batch)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [1, 2, 3]
def test_predict_feature_columns():
preprocessor = DummyPreprocessor()
predictor = TensorflowPredictor(