mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[AIR] Run test_tensorflow_predictors.py
and fix failing tests (#25208)
`test_tensorflow_predictors` wasn't running in CI. This fixes that and also fixes broken tests. Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
This commit is contained in:
parent
d2f0c3b2f6
commit
1ad5e619e1
1 changed files with 9 additions and 2 deletions
|
@ -53,7 +53,7 @@ def test_predict():
|
|||
predictions = predictor.predict(data_batch)
|
||||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().round().tolist() == [2, 4, 6]
|
||||
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
|
||||
assert hasattr(predictor.preprocessor, "_batch_transformed")
|
||||
|
||||
|
||||
|
@ -67,7 +67,7 @@ def test_predict_feature_columns():
|
|||
predictions = predictor.predict(data_batch, feature_columns=[0])
|
||||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().round().tolist() == [2, 4, 6]
|
||||
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
|
||||
assert hasattr(predictor.preprocessor, "_batch_transformed")
|
||||
|
||||
|
||||
|
@ -82,3 +82,10 @@ def test_predict_no_preprocessor():
|
|||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().tolist() == [1, 2, 3]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", "-x", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue