Add numba test (#7298) (#7487)

This commit is contained in:
Landcold7 2020-03-08 03:12:25 +08:00 committed by GitHub
parent 115468de2c
commit beb9b02dbd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 2 deletions

View file

@ -53,7 +53,7 @@ if [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "linux" ]]; then
opencv-python-headless pyyaml pandas==0.24.2 requests \ opencv-python-headless pyyaml pandas==0.24.2 requests \
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \ feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \ uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
blist scikit-learn blist scikit-learn numba
elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
# Install miniconda. # Install miniconda.
wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv
@ -64,7 +64,7 @@ elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
opencv-python-headless pyyaml pandas==0.24.2 requests \ opencv-python-headless pyyaml pandas==0.24.2 requests \
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \ feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \ uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
blist scikit-learn blist scikit-learn numba
elif [[ "$LINT" == "1" ]]; then elif [[ "$LINT" == "1" ]]; then
sudo apt-get update sudo apt-get update
sudo apt-get install -y build-essential curl unzip sudo apt-get install -y build-essential curl unzip

View file

@ -383,3 +383,11 @@ py_test(
tags = ["exclusive"], tags = ["exclusive"],
deps = ["//:ray_lib"], deps = ["//:ray_lib"],
) )
py_test(
name = "test_numba",
size = "small",
srcs = ["test_numba.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)

View file

@ -0,0 +1,37 @@
import unittest
from numba import njit
import numpy as np
import ray
@njit(fastmath=True)
def centroid(x, y):
return ((x / x.sum()) * y).sum()
# Define a wrapper to call centroid function
@ray.remote
def centroid_wrapper(x, y):
return centroid(x, y)
class NumbaTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1)
def tearDown(self):
ray.shutdown()
def test_numba_njit(self):
x = np.random.random(10)
y = np.random.random(1)
result = ray.get(centroid_wrapper.remote(x, y))
assert result == centroid(x, y)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))