mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
implement keyword arguments
This commit is contained in:
parent
5ca5f7502b
commit
429d9025eb
4 changed files with 59 additions and 3 deletions
|
@ -1,5 +1,6 @@
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
import typing
|
import typing
|
||||||
|
import funcsigs
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pynumbuf
|
import pynumbuf
|
||||||
|
|
||||||
|
@ -105,16 +106,19 @@ def distributed(arg_types, return_types, worker=global_worker):
|
||||||
check_return_values(func_call, result) # throws an exception if result is invalid
|
check_return_values(func_call, result) # throws an exception if result is invalid
|
||||||
print "Finished executing function {}".format(func.__name__)
|
print "Finished executing function {}".format(func.__name__)
|
||||||
return result
|
return result
|
||||||
def func_call(*args):
|
def func_call(*args, **kwargs):
|
||||||
"""This is what gets run immediately when a worker calls a distributed function."""
|
"""This is what gets run immediately when a worker calls a distributed function."""
|
||||||
check_arguments(func_call, list(args)) # throws an exception if args are invalid
|
args = list(args)
|
||||||
objrefs = worker.submit_task(func_call.func_name, list(args))
|
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||||
|
check_arguments(func_call, args) # throws an exception if args are invalid
|
||||||
|
objrefs = worker.submit_task(func_call.func_name, args)
|
||||||
return objrefs[0] if len(objrefs) == 1 else objrefs
|
return objrefs[0] if len(objrefs) == 1 else objrefs
|
||||||
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
|
func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||||
func_call.executor = func_executor
|
func_call.executor = func_executor
|
||||||
func_call.arg_types = arg_types
|
func_call.arg_types = arg_types
|
||||||
func_call.return_types = return_types
|
func_call.return_types = return_types
|
||||||
func_call.is_distributed = True
|
func_call.is_distributed = True
|
||||||
|
func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||||
return func_call
|
return func_call
|
||||||
return distributed_decorator
|
return distributed_decorator
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
six >= 1.10
|
six >= 1.10
|
||||||
typing
|
typing
|
||||||
|
funcsigs
|
||||||
subprocess32
|
subprocess32
|
||||||
grpcio
|
grpcio
|
||||||
|
|
|
@ -181,6 +181,43 @@ class APITest(unittest.TestCase):
|
||||||
|
|
||||||
services.cleanup()
|
services.cleanup()
|
||||||
|
|
||||||
|
def testKeywordArgs(self):
|
||||||
|
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
test_path = os.path.join(test_dir, "testrecv.py")
|
||||||
|
services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=test_path)
|
||||||
|
x = test_functions.keyword_fct1(1)
|
||||||
|
self.assertEqual(orchpy.pull(x), "1 hello")
|
||||||
|
x = test_functions.keyword_fct1(1, "hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "1 hi")
|
||||||
|
x = test_functions.keyword_fct1(1, b="world")
|
||||||
|
self.assertEqual(orchpy.pull(x), "1 world")
|
||||||
|
|
||||||
|
x = test_functions.keyword_fct2(a="w", b="hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "w hi")
|
||||||
|
x = test_functions.keyword_fct2(b="hi", a="w")
|
||||||
|
self.assertEqual(orchpy.pull(x), "w hi")
|
||||||
|
x = test_functions.keyword_fct2(a="w")
|
||||||
|
self.assertEqual(orchpy.pull(x), "w world")
|
||||||
|
x = test_functions.keyword_fct2(b="hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "hello hi")
|
||||||
|
x = test_functions.keyword_fct2("w")
|
||||||
|
self.assertEqual(orchpy.pull(x), "w world")
|
||||||
|
x = test_functions.keyword_fct2("w", "hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "w hi")
|
||||||
|
|
||||||
|
x = test_functions.keyword_fct3(0, 1, c="w", d="hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "0 1 w hi")
|
||||||
|
x = test_functions.keyword_fct3(0, 1, d="hi", c="w")
|
||||||
|
self.assertEqual(orchpy.pull(x), "0 1 w hi")
|
||||||
|
x = test_functions.keyword_fct3(0, 1, c="w")
|
||||||
|
self.assertEqual(orchpy.pull(x), "0 1 w world")
|
||||||
|
x = test_functions.keyword_fct3(0, 1, d="hi")
|
||||||
|
self.assertEqual(orchpy.pull(x), "0 1 hello hi")
|
||||||
|
x = test_functions.keyword_fct3(0, 1)
|
||||||
|
self.assertEqual(orchpy.pull(x), "0 1 hello world")
|
||||||
|
|
||||||
|
services.cleanup()
|
||||||
|
|
||||||
class ReferenceCountingTest(unittest.TestCase):
|
class ReferenceCountingTest(unittest.TestCase):
|
||||||
|
|
||||||
def testDeallocation(self):
|
def testDeallocation(self):
|
||||||
|
|
|
@ -38,3 +38,17 @@ def empty_function():
|
||||||
@orchpy.distributed([], [int])
|
@orchpy.distributed([], [int])
|
||||||
def trivial_function():
|
def trivial_function():
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
# Test keyword arguments
|
||||||
|
|
||||||
|
@orchpy.distributed([int, str], [str])
|
||||||
|
def keyword_fct1(a, b="hello"):
|
||||||
|
return "{} {}".format(a, b)
|
||||||
|
|
||||||
|
@orchpy.distributed([str, str], [str])
|
||||||
|
def keyword_fct2(a="hello", b="world"):
|
||||||
|
return "{} {}".format(a, b)
|
||||||
|
|
||||||
|
@orchpy.distributed([int, int, str, str], [str])
|
||||||
|
def keyword_fct3(a, b, c="hello", d="world"):
|
||||||
|
return "{} {} {} {}".format(a, b, c, d)
|
||||||
|
|
Loading…
Add table
Reference in a new issue