mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[AIR] Build model in TensorflowPredictor.predict
(#25136)
`TensorflowPredictor.predict` doesn't work right now. For more information, see #25125. Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
This commit is contained in:
parent
087e356613
commit
f623c607f2
2 changed files with 21 additions and 1 deletions
|
@ -144,7 +144,14 @@ class TensorflowPredictor(Predictor):
|
|||
# a callable that returns the model and initialize it here,
|
||||
# instead of having an initialized model object as an attribute.
|
||||
model = self.model_definition()
|
||||
if self.model_weights:
|
||||
|
||||
if self.model_weights is not None:
|
||||
input_shape = list(tensor.shape)
|
||||
# The batch axis can contain varying number of elements, so we set
|
||||
# the shape along the axis to `None`.
|
||||
input_shape[0] = None
|
||||
|
||||
model.build(input_shape=input_shape)
|
||||
model.set_weights(self.model_weights)
|
||||
|
||||
prediction = list(model(tensor).numpy())
|
||||
|
|
|
@ -57,6 +57,19 @@ def test_predict():
|
|||
assert hasattr(predictor.preprocessor, "_batch_transformed")
|
||||
|
||||
|
||||
def test_predict_array_with_input_shape_unspecified():
|
||||
def model_definition():
|
||||
return tf.keras.models.Sequential(tf.keras.layers.Lambda(lambda tensor: tensor))
|
||||
|
||||
predictor = TensorflowPredictor(model_definition=model_definition, model_weights=[])
|
||||
|
||||
data_batch = np.array([[1], [2], [3]])
|
||||
predictions = predictor.predict(data_batch)
|
||||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().tolist() == [1, 2, 3]
|
||||
|
||||
|
||||
def test_predict_feature_columns():
|
||||
preprocessor = DummyPreprocessor()
|
||||
predictor = TensorflowPredictor(
|
||||
|
|
Loading…
Add table
Reference in a new issue