From 696a229ece4dfa7b10634873d485267b6749d1df Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 30 Jul 2018 11:04:06 -0700 Subject: [PATCH] Fix text verbosity in python 2.7 by running tests with pytest (#2470) --- .travis.yml | 84 +++++---- python/ray/test/test_functions.py | 108 ------------ test/actor_test.py | 202 +++++++++++----------- test/failure_test.py | 25 ++- test/microbenchmarks.py | 21 +-- test/recursion_test.py | 13 +- test/runtest.py | 272 +++++++++++++++++------------- 7 files changed, 344 insertions(+), 381 deletions(-) delete mode 100644 python/ray/test/test_functions.py diff --git a/.travis.yml b/.travis.yml index fafa2fbea..b583b9d75 100644 --- a/.travis.yml +++ b/.travis.yml @@ -119,30 +119,38 @@ matrix: - ./.travis/install-cython-examples.sh script: - export PATH="$HOME/miniconda/bin:$PATH" + # The following is needed so cloudpickle can find some of the + # class definitions: The main module of tests that are run + # with pytest have the same name as the test file -- and this + # module is only found if the test directory is in the PYTHONPATH. + - export PYTHONPATH="$PYTHONPATH:./test/" - - python python/ray/common/test/test.py - - python python/ray/common/redis_module/runtest.py - - python python/ray/plasma/test/test.py - # - python python/ray/local_scheduler/test/test.py - # - python python/ray/global_scheduler/test/test.py + - python -m pytest python/ray/common/test/test.py + - python -m pytest python/ray/common/redis_module/runtest.py + - python -m pytest python/ray/plasma/test/test.py + # - python -m pytest python/ray/local_scheduler/test/test.py + # - python -m pytest python/ray/global_scheduler/test/test.py - python -m pytest python/ray/test/test_queue.py - python -m pytest test/xray_test.py - - python test/runtest.py - - python test/array_test.py - - python test/actor_test.py - - python test/autoscaler_test.py - - python test/tensorflow_test.py - - python test/failure_test.py - - python test/microbenchmarks.py + # The --assert=plain here is because pytest's assertion + # rewriting mechanism seems to mess up on this file, + # see https://github.com/ray-project/ray/issues/2514 + - python -m pytest -v --assert=plain test/runtest.py + - python -m pytest test/array_test.py + - python -m pytest test/actor_test.py + - python -m pytest test/autoscaler_test.py + - python -m pytest test/tensorflow_test.py + - python -m pytest test/failure_test.py + - python -m pytest test/microbenchmarks.py - python -m pytest test/stress_tests.py - # - python test/component_failures_test.py + # - pytest test/component_failures_test.py - python test/multi_node_test.py - - python test/recursion_test.py - # - python test/monitor_test.py - - python test/cython_test.py - - python test/credis_test.py + - python -m pytest test/recursion_test.py + # - pytest test/monitor_test.py + - python -m pytest test/cython_test.py + - python -m pytest test/credis_test.py # ray tune tests - python python/ray/tune/test/dependency_test.py @@ -179,30 +187,38 @@ install: script: - export PATH="$HOME/miniconda/bin:$PATH" + # The following is needed so cloudpickle can find some of the + # class definitions: The main module of tests that are run + # with pytest have the same name as the test file -- and this + # module is only found if the test directory is in the PYTHONPATH. + - export PYTHONPATH="$PYTHONPATH:./test/" - - python python/ray/common/test/test.py - - python python/ray/common/redis_module/runtest.py - - python python/ray/plasma/test/test.py - - python python/ray/local_scheduler/test/test.py - - python python/ray/global_scheduler/test/test.py + - python -m pytest python/ray/common/test/test.py + - python -m pytest python/ray/common/redis_module/runtest.py + - python -m pytest python/ray/plasma/test/test.py + - python -m pytest python/ray/local_scheduler/test/test.py + - python -m pytest python/ray/global_scheduler/test/test.py - python -m pytest python/ray/test/test_queue.py - python -m pytest test/xray_test.py - - python test/runtest.py - - python test/array_test.py - - python test/actor_test.py - - python test/autoscaler_test.py - - python test/tensorflow_test.py - - python test/failure_test.py - - python test/microbenchmarks.py + # The --assert=plain here is because pytest's assertion + # rewriting mechanism seems to mess up on this file, + # see https://github.com/ray-project/ray/issues/2514 + - python -m pytest --assert=plain -v test/runtest.py + - python -m pytest test/array_test.py + - python -m pytest test/actor_test.py + - python -m pytest test/autoscaler_test.py + - python -m pytest test/tensorflow_test.py + - python -m pytest test/failure_test.py + - python -m pytest test/microbenchmarks.py - python -m pytest test/stress_tests.py - - python test/component_failures_test.py + - python -m pytest test/component_failures_test.py - python test/multi_node_test.py - - python test/recursion_test.py - - python test/monitor_test.py - - python test/cython_test.py - - python test/credis_test.py + - python -m pytest test/recursion_test.py + - python -m pytest test/monitor_test.py + - python -m pytest test/cython_test.py + - python -m pytest test/credis_test.py # ray tune tests - python python/ray/tune/test/dependency_test.py diff --git a/python/ray/test/test_functions.py b/python/ray/test/test_functions.py deleted file mode 100644 index ade1c8183..000000000 --- a/python/ray/test/test_functions.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray - -import numpy as np - -# Test simple functionality - - -@ray.remote(num_return_vals=2) -def handle_int(a, b): - return a + 1, b + 1 - - -# Test timing - - -@ray.remote -def empty_function(): - pass - - -@ray.remote -def trivial_function(): - return 1 - - -# Test keyword arguments - - -@ray.remote -def keyword_fct1(a, b="hello"): - return "{} {}".format(a, b) - - -@ray.remote -def keyword_fct2(a="hello", b="world"): - return "{} {}".format(a, b) - - -@ray.remote -def keyword_fct3(a, b, c="hello", d="world"): - return "{} {} {} {}".format(a, b, c, d) - - -# Test variable numbers of arguments - - -@ray.remote -def varargs_fct1(*a): - return " ".join(map(str, a)) - - -@ray.remote -def varargs_fct2(a, *b): - return " ".join(map(str, b)) - - -try: - - @ray.remote - def kwargs_throw_exception(**c): - return () - - kwargs_exception_thrown = False -except Exception: - kwargs_exception_thrown = True - -# test throwing an exception - - -@ray.remote -def throw_exception_fct1(): - raise Exception("Test function 1 intentionally failed.") - - -@ray.remote -def throw_exception_fct2(): - raise Exception("Test function 2 intentionally failed.") - - -@ray.remote(num_return_vals=3) -def throw_exception_fct3(x): - raise Exception("Test function 3 intentionally failed.") - - -# test Python mode - - -@ray.remote -def local_mode_f(): - return np.array([0, 0]) - - -@ray.remote -def local_mode_g(x): - x[0] = 1 - return x - - -# test no return values - - -@ray.remote -def no_op(): - pass diff --git a/test/actor_test.py b/test/actor_test.py index 3d4bf1372..a0fe6d410 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -6,6 +6,7 @@ import collections import random import numpy as np import os +import pytest import sys import time import unittest @@ -275,7 +276,7 @@ class ActorAPI(unittest.TestCase): self.assertEqual(len(actor_keys), 1) actor_class_info = r.hgetall(actor_keys[0]) self.assertEqual(actor_class_info[b"class_name"], b"Foo") - self.assertEqual(actor_class_info[b"module"], b"__main__") + self.assertEqual(actor_class_info[b"module"], b"actor_test") def testMultipleReturnValues(self): ray.init(num_workers=0) @@ -1965,120 +1966,127 @@ class DistributedActorHandles(unittest.TestCase): self.assertEqual(ray.get(f2.method.remote()), 4) -class ActorPlacementAndResources(unittest.TestCase): - def tearDown(self): - ray.shutdown() +@pytest.fixture +def ray_stop(): + # The initialization code depends on the test that is run. + yield None + # The code after the yield will run as teardown code. + ray.shutdown() - @unittest.skipIf( - os.environ.get("RAY_USE_XRAY") == "1", - "This test does not work with xray yet.") - def testLifetimeAndTransientResources(self): - ray.init(num_cpus=1) - # This actor acquires resources only when running methods. - @ray.remote - class Actor1(object): - def method(self): - pass +@unittest.skipIf( + os.environ.get("RAY_USE_XRAY") == "1" or sys.version_info < (3, 0), + "This test does not work with xray yet" + " and is currently failing on Python 2.7.") +def testLifetimeAndTransientResources(ray_stop): + ray.init(num_cpus=1) - # This actor acquires resources for its lifetime. - @ray.remote(num_cpus=1) - class Actor2(object): - def method(self): - pass + # This actor acquires resources only when running methods. + @ray.remote + class Actor1(object): + def method(self): + pass - actor1s = [Actor1.remote() for _ in range(10)] - ray.get([a.method.remote() for a in actor1s]) + # This actor acquires resources for its lifetime. + @ray.remote(num_cpus=1) + class Actor2(object): + def method(self): + pass - actor2s = [Actor2.remote() for _ in range(2)] - results = [a.method.remote() for a in actor2s] - ready_ids, remaining_ids = ray.wait( - results, num_returns=len(results), timeout=1000) - self.assertEqual(len(ready_ids), 1) + actor1s = [Actor1.remote() for _ in range(10)] + ray.get([a.method.remote() for a in actor1s]) - def testCustomLabelPlacement(self): - ray.worker._init( - start_ray_local=True, - num_local_schedulers=2, - num_workers=0, - resources=[{ - "CustomResource1": 2 - }, { - "CustomResource2": 2 - }]) + actor2s = [Actor2.remote() for _ in range(2)] + results = [a.method.remote() for a in actor2s] + ready_ids, remaining_ids = ray.wait( + results, num_returns=len(results), timeout=1000) + assert len(ready_ids) == 1 - @ray.remote(resources={"CustomResource1": 1}) - class ResourceActor1(object): - def get_location(self): - return ray.worker.global_worker.plasma_client.store_socket_name - @ray.remote(resources={"CustomResource2": 1}) - class ResourceActor2(object): - def get_location(self): - return ray.worker.global_worker.plasma_client.store_socket_name +def testCustomLabelPlacement(ray_stop): + ray.worker._init( + start_ray_local=True, + num_local_schedulers=2, + num_workers=0, + resources=[{ + "CustomResource1": 2 + }, { + "CustomResource2": 2 + }]) - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name + @ray.remote(resources={"CustomResource1": 1}) + class ResourceActor1(object): + def get_location(self): + return ray.worker.global_worker.plasma_client.store_socket_name - # Create some actors. - actors1 = [ResourceActor1.remote() for _ in range(2)] - actors2 = [ResourceActor2.remote() for _ in range(2)] - locations1 = ray.get([a.get_location.remote() for a in actors1]) - locations2 = ray.get([a.get_location.remote() for a in actors2]) - for location in locations1: - self.assertEqual(location, local_plasma) - for location in locations2: - self.assertNotEqual(location, local_plasma) + @ray.remote(resources={"CustomResource2": 1}) + class ResourceActor2(object): + def get_location(self): + return ray.worker.global_worker.plasma_client.store_socket_name - def testCreatingMoreActorsThanResources(self): - ray.init( - num_workers=0, - num_cpus=10, - num_gpus=2, - resources={"CustomResource1": 1}) + local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - @ray.remote(num_gpus=1) - class ResourceActor1(object): - def method(self): - return ray.get_gpu_ids()[0] + # Create some actors. + actors1 = [ResourceActor1.remote() for _ in range(2)] + actors2 = [ResourceActor2.remote() for _ in range(2)] + locations1 = ray.get([a.get_location.remote() for a in actors1]) + locations2 = ray.get([a.get_location.remote() for a in actors2]) + for location in locations1: + assert location == local_plasma + for location in locations2: + assert location != local_plasma - @ray.remote(resources={"CustomResource1": 1}) - class ResourceActor2(object): - def method(self): - pass - # Make sure the first two actors get created and the third one does - # not. - actor1 = ResourceActor1.remote() - result1 = actor1.method.remote() - ray.wait([result1]) - actor2 = ResourceActor1.remote() - result2 = actor2.method.remote() - ray.wait([result2]) - actor3 = ResourceActor1.remote() - result3 = actor3.method.remote() - ready_ids, _ = ray.wait([result3], timeout=200) - self.assertEqual(len(ready_ids), 0) +def testCreatingMoreActorsThanResources(ray_stop): + ray.init( + num_workers=0, + num_cpus=10, + num_gpus=2, + resources={"CustomResource1": 1}) - # By deleting actor1, we free up resources to create actor3. - del actor1 + @ray.remote(num_gpus=1) + class ResourceActor1(object): + def method(self): + return ray.get_gpu_ids()[0] - results = ray.get([result1, result2, result3]) - self.assertEqual(results[0], results[2]) - self.assertEqual(set(results), {0, 1}) + @ray.remote(resources={"CustomResource1": 1}) + class ResourceActor2(object): + def method(self): + pass - # Make sure that when one actor goes out of scope a new actor is - # created because some resources have been freed up. - results = [] - for _ in range(3): - actor = ResourceActor2.remote() - object_id = actor.method.remote() - results.append(object_id) - # Wait for the task to execute. We do this because otherwise it may - # be possible for the __ray_terminate__ task to execute before the - # method. - ray.wait([object_id]) + # Make sure the first two actors get created and the third one does + # not. + actor1 = ResourceActor1.remote() + result1 = actor1.method.remote() + ray.wait([result1]) + actor2 = ResourceActor1.remote() + result2 = actor2.method.remote() + ray.wait([result2]) + actor3 = ResourceActor1.remote() + result3 = actor3.method.remote() + ready_ids, _ = ray.wait([result3], timeout=200) + assert len(ready_ids) == 0 - ray.get(results) + # By deleting actor1, we free up resources to create actor3. + del actor1 + + results = ray.get([result1, result2, result3]) + assert results[0] == results[2] + assert set(results) == {0, 1} + + # Make sure that when one actor goes out of scope a new actor is + # created because some resources have been freed up. + results = [] + for _ in range(3): + actor = ResourceActor2.remote() + object_id = actor.method.remote() + results.append(object_id) + # Wait for the task to execute. We do this because otherwise it may + # be possible for the __ray_terminate__ task to execute before the + # method. + ray.wait([object_id]) + + ray.get(results) if __name__ == "__main__": diff --git a/test/failure_test.py b/test/failure_test.py index ee1069401..00a1edd4d 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -11,10 +11,6 @@ import time import unittest import ray.ray_constants as ray_constants -import ray.test.test_functions as test_functions - -if sys.version_info >= (3, 0): - from importlib import reload def relevant_errors(error_type): @@ -35,11 +31,22 @@ class TaskStatusTest(unittest.TestCase): ray.shutdown() def testFailedTask(self): - reload(test_functions) + @ray.remote + def throw_exception_fct1(): + raise Exception("Test function 1 intentionally failed.") + + @ray.remote + def throw_exception_fct2(): + raise Exception("Test function 2 intentionally failed.") + + @ray.remote(num_return_vals=3) + def throw_exception_fct3(x): + raise Exception("Test function 3 intentionally failed.") + ray.init(num_workers=3, driver_mode=ray.SILENT_MODE) - test_functions.throw_exception_fct1.remote() - test_functions.throw_exception_fct1.remote() + throw_exception_fct1.remote() + throw_exception_fct1.remote() wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2) self.assertEqual( len(relevant_errors(ray_constants.TASK_PUSH_ERROR)), 2) @@ -47,7 +54,7 @@ class TaskStatusTest(unittest.TestCase): self.assertIn("Test function 1 intentionally failed.", task.get("message")) - x = test_functions.throw_exception_fct2.remote() + x = throw_exception_fct2.remote() try: ray.get(x) except Exception as e: @@ -56,7 +63,7 @@ class TaskStatusTest(unittest.TestCase): # ray.get should throw an exception. self.assertTrue(False) - x, y, z = test_functions.throw_exception_fct3.remote(1.0) + x, y, z = throw_exception_fct3.remote(1.0) for ref in [x, y, z]: try: ray.get(ref) diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index a3b8aee59..5cfc5ce40 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -5,29 +5,30 @@ from __future__ import print_function import os import unittest import ray -import sys import time import numpy as np -import ray.test.test_functions as test_functions - -if sys.version_info >= (3, 0): - from importlib import reload - class MicroBenchmarkTest(unittest.TestCase): def tearDown(self): ray.shutdown() def testTiming(self): - reload(test_functions) + @ray.remote + def empty_function(): + pass + + @ray.remote + def trivial_function(): + return 1 + ray.init(num_workers=3) # Measure the time required to submit a remote task to the scheduler. elapsed_times = [] for _ in range(1000): start_time = time.time() - test_functions.empty_function.remote() + empty_function.remote() end_time = time.time() elapsed_times.append(end_time - start_time) elapsed_times = np.sort(elapsed_times) @@ -44,7 +45,7 @@ class MicroBenchmarkTest(unittest.TestCase): elapsed_times = [] for _ in range(1000): start_time = time.time() - test_functions.trivial_function.remote() + trivial_function.remote() end_time = time.time() elapsed_times.append(end_time - start_time) elapsed_times = np.sort(elapsed_times) @@ -61,7 +62,7 @@ class MicroBenchmarkTest(unittest.TestCase): elapsed_times = [] for _ in range(1000): start_time = time.time() - x = test_functions.trivial_function.remote() + x = trivial_function.remote() ray.get(x) end_time = time.time() elapsed_times.append(end_time - start_time) diff --git a/test/recursion_test.py b/test/recursion_test.py index 38134c108..c3f7db59a 100644 --- a/test/recursion_test.py +++ b/test/recursion_test.py @@ -18,9 +18,10 @@ def factorial(n): return n * ray.get(factorial.remote(n - 1)) -assert ray.get(factorial.remote(0)) == 1 -assert ray.get(factorial.remote(1)) == 1 -assert ray.get(factorial.remote(2)) == 2 -assert ray.get(factorial.remote(3)) == 6 -assert ray.get(factorial.remote(4)) == 24 -assert ray.get(factorial.remote(5)) == 120 +def test_recursion(): + assert ray.get(factorial.remote(0)) == 1 + assert ray.get(factorial.remote(1)) == 1 + assert ray.get(factorial.remote(2)) == 2 + assert ray.get(factorial.remote(3)) == 6 + assert ray.get(factorial.remote(4)) == 24 + assert ray.get(factorial.remote(5)) == 120 diff --git a/test/runtest.py b/test/runtest.py index d55d9c91d..e8b099894 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import os +import pytest import re import string import sys @@ -12,12 +13,8 @@ from collections import defaultdict, namedtuple, OrderedDict import numpy as np import ray -import ray.test.test_functions as test_functions import ray.test.test_utils -if sys.version_info >= (3, 0): - from importlib import reload - def assert_equal(obj1, obj2): module_numpy = (type(obj1).__module__ == np.__name__ @@ -194,98 +191,100 @@ DICT_OBJECTS = ( RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS -class SerializationTest(unittest.TestCase): - def tearDown(self): - ray.shutdown() +@pytest.fixture +def ray_start(): + # Start the Ray processes. + ray.init(num_cpus=1) + yield None + # The code after the yield will run as teardown code. + ray.shutdown() - def testRecursiveObjects(self): - ray.init(num_workers=0) - class ClassA(object): +def test_passing_arguments_by_value(ray_start): + @ray.remote + def f(x): + return x + + # Check that we can pass arguments by value to remote functions and + # that they are uncorrupted. + for obj in RAY_TEST_OBJECTS: + assert_equal(obj, ray.get(f.remote(obj))) + + +def test_ray_recursive_objects(ray_start): + class ClassA(object): + pass + + # Make a list that contains itself. + lst = [] + lst.append(lst) + # Make an object that contains itself as a field. + a1 = ClassA() + a1.field = a1 + # Make two objects that contain each other as fields. + a2 = ClassA() + a3 = ClassA() + a2.field = a3 + a3.field = a2 + # Make a dictionary that contains itself. + d1 = {} + d1["key"] = d1 + # Create a list of recursive objects. + recursive_objects = [lst, a1, a2, a3, d1] + + # Check that exceptions are thrown when we serialize the recursive + # objects. + for obj in recursive_objects: + with pytest.raises(Exception): + ray.put(obj) + + +def test_passing_arguments_by_value_out_of_the_box(ray_start): + @ray.remote + def f(x): + return x + + # Test passing lambdas. + + def temp(): + return 1 + + assert ray.get(f.remote(temp))() == 1 + assert ray.get(f.remote(lambda x: x + 1))(3) == 4 + + # Test sets. + assert ray.get(f.remote(set())) == set() + s = {1, (1, 2, "hi")} + assert ray.get(f.remote(s)) == s + + # Test types. + assert ray.get(f.remote(int)) == int + assert ray.get(f.remote(float)) == float + assert ray.get(f.remote(str)) == str + + class Foo(object): + def __init__(self): pass - # Make a list that contains itself. - lst = [] - lst.append(lst) - # Make an object that contains itself as a field. - a1 = ClassA() - a1.field = a1 - # Make two objects that contain each other as fields. - a2 = ClassA() - a3 = ClassA() - a2.field = a3 - a3.field = a2 - # Make a dictionary that contains itself. - d1 = {} - d1["key"] = d1 - # Create a list of recursive objects. - recursive_objects = [lst, a1, a2, a3, d1] + # Make sure that we can put and get a custom type. Note that the result + # won't be "equal" to Foo. + ray.get(ray.put(Foo)) - # Check that exceptions are thrown when we serialize the recursive - # objects. - for obj in recursive_objects: - self.assertRaises(Exception, lambda: ray.put(obj)) - def testPassingArgumentsByValue(self): - ray.init(num_workers=1) +def test_putting_object_that_closes_over_object_id(ray_start): + # This test is here to prevent a regression of + # https://github.com/ray-project/ray/issues/1317. - @ray.remote - def f(x): - return x + class Foo(object): + def __init__(self): + self.val = ray.put(0) - # Check that we can pass arguments by value to remote functions and - # that they are uncorrupted. - for obj in RAY_TEST_OBJECTS: - assert_equal(obj, ray.get(f.remote(obj))) + def method(self): + f - def testPassingArgumentsByValueOutOfTheBox(self): - ray.init(num_workers=1) - - @ray.remote - def f(x): - return x - - # Test passing lambdas. - - def temp(): - return 1 - - self.assertEqual(ray.get(f.remote(temp))(), 1) - self.assertEqual(ray.get(f.remote(lambda x: x + 1))(3), 4) - - # Test sets. - self.assertEqual(ray.get(f.remote(set())), set()) - s = {1, (1, 2, "hi")} - self.assertEqual(ray.get(f.remote(s)), s) - - # Test types. - self.assertEqual(ray.get(f.remote(int)), int) - self.assertEqual(ray.get(f.remote(float)), float) - self.assertEqual(ray.get(f.remote(str)), str) - - class Foo(object): - def __init__(self): - pass - - # Make sure that we can put and get a custom type. Note that the result - # won't be "equal" to Foo. - ray.get(ray.put(Foo)) - - def testPuttingObjectThatClosesOverObjectID(self): - # This test is here to prevent a regression of - # https://github.com/ray-project/ray/issues/1317. - ray.init(num_workers=0) - - class Foo(object): - def __init__(self): - self.val = ray.put(0) - - def method(self): - f - - f = Foo() - with self.assertRaises(ray.local_scheduler.common_error): - ray.put(f) + f = Foo() + with pytest.raises(ray.local_scheduler.common_error): + ray.put(f) class WorkerTest(unittest.TestCase): @@ -522,46 +521,57 @@ class APITest(unittest.TestCase): self.assertFalse(hasattr(c2, "method1")) def testKeywordArgs(self): - reload(test_functions) + @ray.remote + def keyword_fct1(a, b="hello"): + return "{} {}".format(a, b) + + @ray.remote + def keyword_fct2(a="hello", b="world"): + return "{} {}".format(a, b) + + @ray.remote + def keyword_fct3(a, b, c="hello", d="world"): + return "{} {} {} {}".format(a, b, c, d) + self.init_ray() - x = test_functions.keyword_fct1.remote(1) + x = keyword_fct1.remote(1) self.assertEqual(ray.get(x), "1 hello") - x = test_functions.keyword_fct1.remote(1, "hi") + x = keyword_fct1.remote(1, "hi") self.assertEqual(ray.get(x), "1 hi") - x = test_functions.keyword_fct1.remote(1, b="world") + x = keyword_fct1.remote(1, b="world") self.assertEqual(ray.get(x), "1 world") - x = test_functions.keyword_fct1.remote(a=1, b="world") + x = keyword_fct1.remote(a=1, b="world") self.assertEqual(ray.get(x), "1 world") - x = test_functions.keyword_fct2.remote(a="w", b="hi") + x = keyword_fct2.remote(a="w", b="hi") self.assertEqual(ray.get(x), "w hi") - x = test_functions.keyword_fct2.remote(b="hi", a="w") + x = keyword_fct2.remote(b="hi", a="w") self.assertEqual(ray.get(x), "w hi") - x = test_functions.keyword_fct2.remote(a="w") + x = keyword_fct2.remote(a="w") self.assertEqual(ray.get(x), "w world") - x = test_functions.keyword_fct2.remote(b="hi") + x = keyword_fct2.remote(b="hi") self.assertEqual(ray.get(x), "hello hi") - x = test_functions.keyword_fct2.remote("w") + x = keyword_fct2.remote("w") self.assertEqual(ray.get(x), "w world") - x = test_functions.keyword_fct2.remote("w", "hi") + x = keyword_fct2.remote("w", "hi") self.assertEqual(ray.get(x), "w hi") - x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi") + x = keyword_fct3.remote(0, 1, c="w", d="hi") self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(0, b=1, c="w", d="hi") + x = keyword_fct3.remote(0, b=1, c="w", d="hi") self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(a=0, b=1, c="w", d="hi") + x = keyword_fct3.remote(a=0, b=1, c="w", d="hi") self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w") + x = keyword_fct3.remote(0, 1, d="hi", c="w") self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(0, 1, c="w") + x = keyword_fct3.remote(0, 1, c="w") self.assertEqual(ray.get(x), "0 1 w world") - x = test_functions.keyword_fct3.remote(0, 1, d="hi") + x = keyword_fct3.remote(0, 1, d="hi") self.assertEqual(ray.get(x), "0 1 hello hi") - x = test_functions.keyword_fct3.remote(0, 1) + x = keyword_fct3.remote(0, 1) self.assertEqual(ray.get(x), "0 1 hello world") - x = test_functions.keyword_fct3.remote(a=0, b=1) + x = keyword_fct3.remote(a=0, b=1) self.assertEqual(ray.get(x), "0 1 hello world") # Check that we cannot pass invalid keyword arguments to functions. @@ -597,15 +607,32 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(f3.remote(4)), 4) def testVariableNumberOfArgs(self): - reload(test_functions) + @ray.remote + def varargs_fct1(*a): + return " ".join(map(str, a)) + + @ray.remote + def varargs_fct2(a, *b): + return " ".join(map(str, b)) + + try: + + @ray.remote + def kwargs_throw_exception(**c): + return () + + kwargs_exception_thrown = False + except Exception: + kwargs_exception_thrown = True + self.init_ray() - x = test_functions.varargs_fct1.remote(0, 1, 2) + x = varargs_fct1.remote(0, 1, 2) self.assertEqual(ray.get(x), "0 1 2") - x = test_functions.varargs_fct2.remote(0, 1, 2) + x = varargs_fct2.remote(0, 1, 2) self.assertEqual(ray.get(x), "1 2") - self.assertTrue(test_functions.kwargs_exception_thrown) + self.assertTrue(kwargs_exception_thrown) @ray.remote def f1(*args): @@ -627,10 +654,13 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) def testNoArgs(self): - reload(test_functions) + @ray.remote + def no_op(): + pass + self.init_ray() - ray.get(test_functions.no_op.remote()) + ray.get(no_op.remote()) def testDefiningRemoteFunctions(self): self.init_ray(num_cpus=3) @@ -1200,7 +1230,15 @@ class LocalModeTest(unittest.TestCase): ray.shutdown() def testLocalMode(self): - reload(test_functions) + @ray.remote + def local_mode_f(): + return np.array([0, 0]) + + @ray.remote + def local_mode_g(x): + x[0] = 1 + return x + ray.init(driver_mode=ray.LOCAL_MODE) @ray.remote @@ -1218,9 +1256,9 @@ class LocalModeTest(unittest.TestCase): # Make sure objects are immutable, this example is why we need to copy # arguments before passing them into remote functions in python mode - aref = test_functions.local_mode_f.remote() + aref = local_mode_f.remote() assert_equal(aref, np.array([0, 0])) - bref = test_functions.local_mode_g.remote(aref) + bref = local_mode_g.remote(aref) # Make sure local_mode_g does not mutate aref. assert_equal(aref, np.array([0, 0])) assert_equal(bref, np.array([1, 0])) @@ -2189,9 +2227,9 @@ class GlobalStateAPI(unittest.TestCase): self.assertEqual(task_spec["DriverID"], driver_id) self.assertEqual(task_spec["ReturnObjectIDs"], [result_id]) function_table_entry = function_table[task_spec["FunctionID"]] - self.assertEqual(function_table_entry["Name"], "__main__.f") + self.assertEqual(function_table_entry["Name"], "runtest.f") self.assertEqual(function_table_entry["DriverID"], driver_id) - self.assertEqual(function_table_entry["Module"], "__main__") + self.assertEqual(function_table_entry["Module"], "runtest") self.assertEqual(task_table[task_id], ray.global_state.task_table(task_id))