diff --git a/python/ray/ml/tests/test_tensorflow_predictor.py b/python/ray/ml/tests/test_tensorflow_predictor.py index 1d9cd9701..cd0123ff7 100644 --- a/python/ray/ml/tests/test_tensorflow_predictor.py +++ b/python/ray/ml/tests/test_tensorflow_predictor.py @@ -53,7 +53,7 @@ def test_predict(): predictions = predictor.predict(data_batch) assert len(predictions) == 3 - assert predictions.to_numpy().flatten().round().tolist() == [2, 4, 6] + assert predictions.to_numpy().flatten().tolist() == [2, 4, 6] assert hasattr(predictor.preprocessor, "_batch_transformed") @@ -67,7 +67,7 @@ def test_predict_feature_columns(): predictions = predictor.predict(data_batch, feature_columns=[0]) assert len(predictions) == 3 - assert predictions.to_numpy().flatten().round().tolist() == [2, 4, 6] + assert predictions.to_numpy().flatten().tolist() == [2, 4, 6] assert hasattr(predictor.preprocessor, "_batch_transformed") @@ -82,3 +82,10 @@ def test_predict_no_preprocessor(): assert len(predictions) == 3 assert predictions.to_numpy().flatten().tolist() == [1, 2, 3] + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__]))