mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Add basic functionality for Cython functions and actors * Fix up per @pcmoritz comments * Fixes per @richardliaw comments * Fixes per @robertnishihara comments * Forgot double quotes when updating masked_log * Remove import typing for Python 2 compatibility
57 lines
1.5 KiB
Python
57 lines
1.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
|
|
import math
|
|
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
import cython_examples as cyth
|
|
|
|
|
|
def get_ray_result(cython_func, *args):
|
|
func = ray.remote(cython_func)
|
|
return ray.get(func.remote(*args))
|
|
|
|
|
|
class CythonTest(unittest.TestCase):
|
|
def setUp(self):
|
|
ray.init()
|
|
|
|
def tearDown(self):
|
|
ray.worker.cleanup()
|
|
|
|
def assertEqualHelper(self, cython_func, expected, *args):
|
|
self.assertEqual(get_ray_result(cython_func, *args), expected)
|
|
|
|
def test_simple_func(self):
|
|
self.assertEqualHelper(cyth.simple_func, 6, 1, 2, 3)
|
|
self.assertEqualHelper(cyth.fib, 55, 10)
|
|
self.assertEqualHelper(cyth.fib_int, 55, 10)
|
|
self.assertEqualHelper(cyth.fib_cpdef, 55, 10)
|
|
self.assertEqualHelper(cyth.fib_cdef, 55, 10)
|
|
|
|
def test_simple_class(self):
|
|
cls = ray.remote(cyth.simple_class)
|
|
a1 = cls.remote()
|
|
a2 = cls.remote()
|
|
|
|
result1 = ray.get(a1.increment.remote())
|
|
result2 = ray.get(a2.increment.remote())
|
|
result3 = ray.get(a2.increment.remote())
|
|
|
|
self.assertEqual(result1, 1)
|
|
self.assertEqual(result2, 1)
|
|
self.assertEqual(result3, 2)
|
|
|
|
def test_numpy(self):
|
|
array = np.array([-1.0, 0.0, 1.0, 2.0])
|
|
|
|
answer = [float("-inf") if x <= 0 else math.log(x) for x in array]
|
|
result = get_ray_result(cyth.masked_log, array)
|
|
|
|
np.testing.assert_array_equal(answer, result)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main(verbosity=2)
|