mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
115468de2c
commit
beb9b02dbd
3 changed files with 47 additions and 2 deletions
|
@ -53,7 +53,7 @@ if [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "linux" ]]; then
|
||||||
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
||||||
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
||||||
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
||||||
blist scikit-learn
|
blist scikit-learn numba
|
||||||
elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
|
elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
|
||||||
# Install miniconda.
|
# Install miniconda.
|
||||||
wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv
|
wget -q https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv
|
||||||
|
@ -64,7 +64,7 @@ elif [[ "$PYTHON" == "3.6" ]] && [[ "$platform" == "macosx" ]]; then
|
||||||
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
opencv-python-headless pyyaml pandas==0.24.2 requests \
|
||||||
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
feather-format lxml openpyxl xlrd py-spy pytest-timeout networkx tabulate aiohttp \
|
||||||
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
uvicorn dataclasses pygments werkzeug kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio \
|
||||||
blist scikit-learn
|
blist scikit-learn numba
|
||||||
elif [[ "$LINT" == "1" ]]; then
|
elif [[ "$LINT" == "1" ]]; then
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y build-essential curl unzip
|
sudo apt-get install -y build-essential curl unzip
|
||||||
|
|
|
@ -383,3 +383,11 @@ py_test(
|
||||||
tags = ["exclusive"],
|
tags = ["exclusive"],
|
||||||
deps = ["//:ray_lib"],
|
deps = ["//:ray_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_numba",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["test_numba.py"],
|
||||||
|
tags = ["exclusive"],
|
||||||
|
deps = ["//:ray_lib"],
|
||||||
|
)
|
||||||
|
|
37
python/ray/tests/test_numba.py
Normal file
37
python/ray/tests/test_numba.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from numba import njit
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
@njit(fastmath=True)
|
||||||
|
def centroid(x, y):
|
||||||
|
return ((x / x.sum()) * y).sum()
|
||||||
|
|
||||||
|
|
||||||
|
# Define a wrapper to call centroid function
|
||||||
|
@ray.remote
|
||||||
|
def centroid_wrapper(x, y):
|
||||||
|
return centroid(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class NumbaTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
ray.init(num_cpus=1)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
def test_numba_njit(self):
|
||||||
|
x = np.random.random(10)
|
||||||
|
y = np.random.random(1)
|
||||||
|
result = ray.get(centroid_wrapper.remote(x, y))
|
||||||
|
assert result == centroid(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue