ray/doc/examples/cython/tests/test_cython.py
Barak Michener 68f3fec744
*: Centralize requirements.txt and unify dependency versions (#9759)
* python_test: fix cython_examples in doc/ and tests/

* update setup.py to parse the bazel version string better

* all: centralize all python deps into stackable requirements files in python/

* format

* Move cython test into the proper package

* Add cross-reference dependency comments for requirements and setup.py

* re-enable version pinning on CI, fix formatting

* fix up torchvision version

* fix case in shell
2020-07-30 11:22:56 -07:00

59 lines
1.6 KiB
Python

from __future__ import absolute_import
from __future__ import print_function
import math
import numpy as np
import unittest
import ray
import cython_examples as cyth
def get_ray_result(cython_func, *args):
    """Execute ``cython_func`` as a Ray task and return its result.

    ``cython_func`` is wrapped with ``ray.remote`` on each call, invoked with
    ``args`` as positional arguments, and the call blocks on ``ray.get``
    until the value is available.
    """
    remote_task = ray.remote(cython_func)
    pending = remote_task.remote(*args)
    return ray.get(pending)
class CythonTest(unittest.TestCase):
    """Checks that Cython-compiled functions and classes run under Ray.

    Every test gets its own Ray instance: ``setUp`` starts Ray with a
    150 MiB object store and ``tearDown`` shuts it down again.
    """

    def setUp(self):
        ray.init(object_store_memory=150 * 1024 * 1024)

    def tearDown(self):
        ray.shutdown()

    def assertEqualHelper(self, cython_func, expected, *args):
        """Assert that running ``cython_func(*args)`` via Ray returns ``expected``."""
        # Use the TestCase assertion instead of a bare ``assert``: bare asserts
        # are removed under ``python -O`` (making the test pass vacuously) and
        # do not report the mismatching values outside pytest's rewriting.
        self.assertEqual(get_ray_result(cython_func, *args), expected)

    def test_simple_func(self):
        self.assertEqualHelper(cyth.simple_func, 6, 1, 2, 3)
        # The same expectation (input 10 -> 55) is checked for each of the
        # fib, fib_int, fib_cpdef and fib_cdef entry points.
        self.assertEqualHelper(cyth.fib, 55, 10)
        self.assertEqualHelper(cyth.fib_int, 55, 10)
        self.assertEqualHelper(cyth.fib_cpdef, 55, 10)
        self.assertEqualHelper(cyth.fib_cdef, 55, 10)

    def test_simple_class(self):
        # Two separate remote instances: a1's first increment and a2's first
        # increment both yield 1, and a2's second yields 2, i.e. the counter
        # state is kept per instance.
        cls = ray.remote(cyth.simple_class)
        a1 = cls.remote()
        a2 = cls.remote()
        result1 = ray.get(a1.increment.remote())
        result2 = ray.get(a2.increment.remote())
        result3 = ray.get(a2.increment.remote())
        self.assertEqual(result1, 1)
        self.assertEqual(result2, 1)
        self.assertEqual(result3, 2)

    def test_numpy(self):
        array = np.array([-1.0, 0.0, 1.0, 2.0])
        # Expected output: -inf for non-positive entries, natural log otherwise.
        answer = [float("-inf") if x <= 0 else math.log(x) for x in array]
        result = get_ray_result(cyth.masked_log, array)
        np.testing.assert_array_equal(answer, result)
if __name__ == "__main__":
    # Allow running this file directly; exit with pytest's status code.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)