diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 542f183c5..af920c48a 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -53,7 +53,7 @@ if [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "linux" ]]; then opencv-python-headless pyyaml pandas==0.24.2 requests \ feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \ uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \ - blist scikit-learn + blist scikit-learn numba elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then # Install miniconda. wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv @@ -64,7 +64,7 @@ elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then opencv-python-headless pyyaml pandas==0.24.2 requests \ feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \ uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \ - blist scikit-learn + blist scikit-learn numba elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y build-essential curl unzip diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 448176474..96a844c00 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -383,3 +383,11 @@ py_test( tags = ["exclusive"], deps = ["//:ray_lib"], ) + +py_test( + name = "test_numba", + size = "small", + srcs = ["test_numba.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/tests/test_numba.py b/python/ray/tests/test_numba.py new file mode 100644 index 000000000..d9a87ffb6 --- /dev/null +++ b/python/ray/tests/test_numba.py @@ -0,0 +1,37 @@ +import unittest + +from numba import njit +import numpy as np + +import ray + + +@njit(fastmath=True) +def centroid(x, y): + return ((x / x.sum()) * y).sum() + + +# Define a wrapper to call centroid function +@ray.remote +def centroid_wrapper(x, y): + return centroid(x, y) + + +class NumbaTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=1) + + def tearDown(self): + ray.shutdown() + + def test_numba_njit(self): + x = np.random.random(10) + y = np.random.random(1) + result = ray.get(centroid_wrapper.remote(x, y)) + assert result == centroid(x, y) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))