mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
115468de2c
commit
beb9b02dbd
3 changed files with 47 additions and 2 deletions
|
@ -53,7 +53,7 @@ if [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "linux" ]]; then
|
|||
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
||||
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
||||
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
||||
blist scikit-learn
|
||||
blist scikit-learn numba
|
||||
elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
|
||||
# Install miniconda.
|
||||
wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv
|
||||
|
@ -64,7 +64,7 @@ elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
|
|||
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
||||
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
||||
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
||||
blist scikit-learn
|
||||
blist scikit-learn numba
|
||||
elif [[ "$LINT" == "1" ]]; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential curl unzip
|
||||
|
|
|
@ -383,3 +383,11 @@ py_test(
|
|||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_numba",
|
||||
size = "small",
|
||||
srcs = ["test_numba.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
|
37
python/ray/tests/test_numba.py
Normal file
37
python/ray/tests/test_numba.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import unittest
|
||||
|
||||
from numba import njit
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
@njit(fastmath=True)
|
||||
def centroid(x, y):
|
||||
return ((x / x.sum()) * y).sum()
|
||||
|
||||
|
||||
# Define a wrapper to call centroid function
|
||||
@ray.remote
|
||||
def centroid_wrapper(x, y):
|
||||
return centroid(x, y)
|
||||
|
||||
|
||||
class NumbaTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def test_numba_njit(self):
|
||||
x = np.random.random(10)
|
||||
y = np.random.random(1)
|
||||
result = ray.get(centroid_wrapper.remote(x, y))
|
||||
assert result == centroid(x, y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue