mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
implement keyword arguments
This commit is contained in:
parent
5ca5f7502b
commit
429d9025eb
4 changed files with 59 additions and 3 deletions
|
@ -1,5 +1,6 @@
|
|||
from types import ModuleType
|
||||
import typing
|
||||
import funcsigs
|
||||
import numpy as np
|
||||
import pynumbuf
|
||||
|
||||
|
@ -105,16 +106,19 @@ def distributed(arg_types, return_types, worker=global_worker):
|
|||
check_return_values(func_call, result) # throws an exception if result is invalid
|
||||
print "Finished executing function {}".format(func.__name__)
|
||||
return result
|
||||
def func_call(*args):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This is what gets run immediately when a worker calls a distributed function."""
|
||||
check_arguments(func_call, list(args)) # throws an exception if args are invalid
|
||||
objrefs = worker.submit_task(func_call.func_name, list(args))
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
check_arguments(func_call, args) # throws an exception if args are invalid
|
||||
objrefs = worker.submit_task(func_call.func_name, args)
|
||||
return objrefs[0] if len(objrefs) == 1 else objrefs
|
||||
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_call.executor = func_executor
|
||||
func_call.arg_types = arg_types
|
||||
func_call.return_types = return_types
|
||||
func_call.is_distributed = True
|
||||
func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
return func_call
|
||||
return distributed_decorator
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
six >= 1.10
|
||||
typing
|
||||
funcsigs
|
||||
subprocess32
|
||||
grpcio
|
||||
|
|
|
@ -181,6 +181,43 @@ class APITest(unittest.TestCase):
|
|||
|
||||
services.cleanup()
|
||||
|
||||
def testKeywordArgs(self):
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=test_path)
|
||||
x = test_functions.keyword_fct1(1)
|
||||
self.assertEqual(orchpy.pull(x), "1 hello")
|
||||
x = test_functions.keyword_fct1(1, "hi")
|
||||
self.assertEqual(orchpy.pull(x), "1 hi")
|
||||
x = test_functions.keyword_fct1(1, b="world")
|
||||
self.assertEqual(orchpy.pull(x), "1 world")
|
||||
|
||||
x = test_functions.keyword_fct2(a="w", b="hi")
|
||||
self.assertEqual(orchpy.pull(x), "w hi")
|
||||
x = test_functions.keyword_fct2(b="hi", a="w")
|
||||
self.assertEqual(orchpy.pull(x), "w hi")
|
||||
x = test_functions.keyword_fct2(a="w")
|
||||
self.assertEqual(orchpy.pull(x), "w world")
|
||||
x = test_functions.keyword_fct2(b="hi")
|
||||
self.assertEqual(orchpy.pull(x), "hello hi")
|
||||
x = test_functions.keyword_fct2("w")
|
||||
self.assertEqual(orchpy.pull(x), "w world")
|
||||
x = test_functions.keyword_fct2("w", "hi")
|
||||
self.assertEqual(orchpy.pull(x), "w hi")
|
||||
|
||||
x = test_functions.keyword_fct3(0, 1, c="w", d="hi")
|
||||
self.assertEqual(orchpy.pull(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3(0, 1, d="hi", c="w")
|
||||
self.assertEqual(orchpy.pull(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3(0, 1, c="w")
|
||||
self.assertEqual(orchpy.pull(x), "0 1 w world")
|
||||
x = test_functions.keyword_fct3(0, 1, d="hi")
|
||||
self.assertEqual(orchpy.pull(x), "0 1 hello hi")
|
||||
x = test_functions.keyword_fct3(0, 1)
|
||||
self.assertEqual(orchpy.pull(x), "0 1 hello world")
|
||||
|
||||
services.cleanup()
|
||||
|
||||
class ReferenceCountingTest(unittest.TestCase):
|
||||
|
||||
def testDeallocation(self):
|
||||
|
|
|
@ -38,3 +38,17 @@ def empty_function():
|
|||
@orchpy.distributed([], [int])
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
@orchpy.distributed([int, str], [str])
|
||||
def keyword_fct1(a, b="hello"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
@orchpy.distributed([str, str], [str])
|
||||
def keyword_fct2(a="hello", b="world"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
@orchpy.distributed([int, int, str, str], [str])
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
|
Loading…
Add table
Reference in a new issue