implement keyword arguments

This commit is contained in:
Philipp Moritz 2016-06-03 00:10:17 -07:00
parent 5ca5f7502b
commit 429d9025eb
4 changed files with 59 additions and 3 deletions

View file

@ -1,5 +1,6 @@
from types import ModuleType from types import ModuleType
import typing import typing
import funcsigs
import numpy as np import numpy as np
import pynumbuf import pynumbuf
@ -105,16 +106,19 @@ def distributed(arg_types, return_types, worker=global_worker):
check_return_values(func_call, result) # throws an exception if result is invalid check_return_values(func_call, result) # throws an exception if result is invalid
print "Finished executing function {}".format(func.__name__) print "Finished executing function {}".format(func.__name__)
return result return result
def func_call(*args): def func_call(*args, **kwargs):
"""This is what gets run immediately when a worker calls a distributed function.""" """This is what gets run immediately when a worker calls a distributed function."""
check_arguments(func_call, list(args)) # throws an exception if args are invalid args = list(args)
objrefs = worker.submit_task(func_call.func_name, list(args)) args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments
check_arguments(func_call, args) # throws an exception if args are invalid
objrefs = worker.submit_task(func_call.func_name, args)
return objrefs[0] if len(objrefs) == 1 else objrefs return objrefs[0] if len(objrefs) == 1 else objrefs
func_call.func_name = "{}.{}".format(func.__module__, func.__name__) func_call.func_name = "{}.{}".format(func.__module__, func.__name__)
func_call.executor = func_executor func_call.executor = func_executor
func_call.arg_types = arg_types func_call.arg_types = arg_types
func_call.return_types = return_types func_call.return_types = return_types
func_call.is_distributed = True func_call.is_distributed = True
func_call.keyword_defaults = [(k, v.default) for k, v in funcsigs.signature(func).parameters.iteritems()]
return func_call return func_call
return distributed_decorator return distributed_decorator

View file

@ -1,4 +1,5 @@
six >= 1.10 six >= 1.10
typing typing
funcsigs
subprocess32 subprocess32
grpcio grpcio

View file

@ -181,6 +181,43 @@ class APITest(unittest.TestCase):
services.cleanup() services.cleanup()
def testKeywordArgs(self):
test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
services.start_singlenode_cluster(return_drivers=False, num_workers_per_objstore=3, worker_path=test_path)
x = test_functions.keyword_fct1(1)
self.assertEqual(orchpy.pull(x), "1 hello")
x = test_functions.keyword_fct1(1, "hi")
self.assertEqual(orchpy.pull(x), "1 hi")
x = test_functions.keyword_fct1(1, b="world")
self.assertEqual(orchpy.pull(x), "1 world")
x = test_functions.keyword_fct2(a="w", b="hi")
self.assertEqual(orchpy.pull(x), "w hi")
x = test_functions.keyword_fct2(b="hi", a="w")
self.assertEqual(orchpy.pull(x), "w hi")
x = test_functions.keyword_fct2(a="w")
self.assertEqual(orchpy.pull(x), "w world")
x = test_functions.keyword_fct2(b="hi")
self.assertEqual(orchpy.pull(x), "hello hi")
x = test_functions.keyword_fct2("w")
self.assertEqual(orchpy.pull(x), "w world")
x = test_functions.keyword_fct2("w", "hi")
self.assertEqual(orchpy.pull(x), "w hi")
x = test_functions.keyword_fct3(0, 1, c="w", d="hi")
self.assertEqual(orchpy.pull(x), "0 1 w hi")
x = test_functions.keyword_fct3(0, 1, d="hi", c="w")
self.assertEqual(orchpy.pull(x), "0 1 w hi")
x = test_functions.keyword_fct3(0, 1, c="w")
self.assertEqual(orchpy.pull(x), "0 1 w world")
x = test_functions.keyword_fct3(0, 1, d="hi")
self.assertEqual(orchpy.pull(x), "0 1 hello hi")
x = test_functions.keyword_fct3(0, 1)
self.assertEqual(orchpy.pull(x), "0 1 hello world")
services.cleanup()
class ReferenceCountingTest(unittest.TestCase): class ReferenceCountingTest(unittest.TestCase):
def testDeallocation(self): def testDeallocation(self):

View file

@ -38,3 +38,17 @@ def empty_function():
@orchpy.distributed([], [int]) @orchpy.distributed([], [int])
def trivial_function(): def trivial_function():
return 1 return 1
# Test keyword arguments
@orchpy.distributed([int, str], [str])
def keyword_fct1(a, b="hello"):
return "{} {}".format(a, b)
@orchpy.distributed([str, str], [str])
def keyword_fct2(a="hello", b="world"):
return "{} {}".format(a, b)
@orchpy.distributed([int, int, str, str], [str])
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)