Support calling positional arguments by keyword (fix #998) (#2081)

This commit is contained in:
Adam Gleave 2018-05-17 16:10:26 -07:00 committed by Robert Nishihara
parent afbb260ca4
commit 470887c2ad
4 changed files with 47 additions and 40 deletions

View file

@ -4,6 +4,7 @@ from __future__ import print_function
from collections import namedtuple
import funcsigs
from funcsigs import Parameter
from ray.utils import is_cython
@ -14,15 +15,16 @@ FunctionSignature = namedtuple("FunctionSignature", [
"""This class is used to represent a function signature.
Attributes:
keyword_names: The names of the functions keyword arguments. This is used
to test if an incorrect keyword argument has been passed to the
function.
arg_names: A list containing the name of all arguments.
arg_defaults: A dictionary mapping from argument name to argument default
value. If the argument is not a keyword argument, the default value
will be funcsigs._empty.
arg_is_positionals: A dictionary mapping from argument name to a bool. The
bool will be true if the argument is a *args argument. Otherwise it
will be false.
keyword_names: A set containing the names of the keyword arguments.
Note most arguments in Python can be called as positional or keyword
arguments, so this overlaps (sometimes completely) with arg_names.
function_name: The name of the function whose signature is being
inspected. This is used for printing better error messages.
"""
@ -85,16 +87,13 @@ def check_signature_supported(func, warn=False):
function_name = func.__name__
sig_params = get_signature_params(func)
has_vararg_param = False
has_kwargs_param = False
has_keyword_arg = False
has_kwonly_param = False
for keyword_name, parameter in sig_params:
if parameter.kind == parameter.VAR_KEYWORD:
if parameter.kind == Parameter.VAR_KEYWORD:
has_kwargs_param = True
if parameter.kind == parameter.VAR_POSITIONAL:
has_vararg_param = True
if parameter.default != funcsigs._empty:
has_keyword_arg = True
if parameter.kind == Parameter.KEYWORD_ONLY:
has_kwonly_param = True
if has_kwargs_param:
message = ("The function {} has a **kwargs argument, which is "
@ -103,12 +102,11 @@ def check_signature_supported(func, warn=False):
print(message)
else:
raise Exception(message)
# Check if the user specified a variable number of arguments and any
# keyword arguments.
if has_vararg_param and has_keyword_arg:
message = ("Function {} has a *args argument as well as a keyword "
"argument, which is currently not supported."
.format(function_name))
if has_kwonly_param:
message = ("The function {} has a keyword only argument "
"(defined after * or *args), which is currently "
"not supported.".format(function_name))
if warn:
print(message)
else:
@ -136,20 +134,18 @@ def extract_signature(func, ignore_first=False):
func.__name__))
sig_params = sig_params[1:]
# Extract the names of the keyword arguments.
keyword_names = set()
for keyword_name, parameter in sig_params:
if parameter.default != funcsigs._empty:
keyword_names.add(keyword_name)
# Construct the argument default values and other argument information.
arg_names = []
arg_defaults = []
arg_is_positionals = []
for keyword_name, parameter in sig_params:
arg_names.append(keyword_name)
keyword_names = set()
for arg_name, parameter in sig_params:
arg_names.append(arg_name)
arg_defaults.append(parameter.default)
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
if parameter.kind == Parameter.POSITIONAL_OR_KEYWORD:
# Note KEYWORD_ONLY arguments currently unsupported.
keyword_names.add(arg_name)
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
keyword_names, func.__name__)
@ -189,8 +185,14 @@ def extend_args(function_signature, args, kwargs):
keyword_name, function_name))
# Fill in the remaining arguments.
zipped_info = list(zip(arg_names, arg_defaults,
arg_is_positionals))[len(args):]
for skipped_name in arg_names[0:len(args)]:
if skipped_name in kwargs:
raise Exception("Positional and keyword value provided for the "
"argument '{}' for the function '{}'".format(
keyword_name, function_name))
zipped_info = zip(arg_names, arg_defaults, arg_is_positionals)
zipped_info = list(zipped_info)[len(args):]
for keyword_name, default_value, is_positional in zipped_info:
if keyword_name in kwargs:
args.append(kwargs[keyword_name])
@ -206,9 +208,8 @@ def extend_args(function_signature, args, kwargs):
"'{}' for the function '{}'.".format(
keyword_name, function_name))
too_many_arguments = (len(args) > len(arg_names)
and (len(arg_is_positionals) == 0
or not arg_is_positionals[-1]))
no_positionals = len(arg_is_positionals) == 0 or not arg_is_positionals[-1]
too_many_arguments = len(args) > len(arg_names) and no_positionals
if too_many_arguments:
raise Exception("Too many arguments were passed to the function '{}'"
.format(function_name))

View file

@ -68,16 +68,6 @@ try:
except Exception:
kwargs_exception_thrown = True
try:
@ray.remote
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
return "{} {} {}".format(a, b, c)
varargs_and_kwargs_exception_thrown = False
except Exception:
varargs_and_kwargs_exception_thrown = True
# test throwing an exception

View file

@ -57,6 +57,9 @@ class ActorAPI(unittest.TestCase):
self.assertEqual(
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
(1, 2, "cd"))
self.assertEqual(
ray.get(actor.get_values.remote(arg2="d", arg1=0, arg0=2)),
(3, 2, "cd"))
# Make sure we get an exception if the constructor is called
# incorrectly.
@ -66,6 +69,9 @@ class ActorAPI(unittest.TestCase):
with self.assertRaises(Exception):
actor = Actor.remote(0, 1, 2, arg3=3)
with self.assertRaises(Exception):
actor = Actor.remote(0, arg0=1)
# Make sure we get an exception if the method is called incorrectly.
actor = Actor.remote(1)
with self.assertRaises(Exception):

View file

@ -529,6 +529,8 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(x), "1 hi")
x = test_functions.keyword_fct1.remote(1, b="world")
self.assertEqual(ray.get(x), "1 world")
x = test_functions.keyword_fct1.remote(a=1, b="world")
self.assertEqual(ray.get(x), "1 world")
x = test_functions.keyword_fct2.remote(a="w", b="hi")
self.assertEqual(ray.get(x), "w hi")
@ -545,6 +547,10 @@ class APITest(unittest.TestCase):
x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3.remote(0, b=1, c="w", d="hi")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3.remote(a=0, b=1, c="w", d="hi")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3.remote(0, 1, c="w")
@ -553,6 +559,8 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(x), "0 1 hello hi")
x = test_functions.keyword_fct3.remote(0, 1)
self.assertEqual(ray.get(x), "0 1 hello world")
x = test_functions.keyword_fct3.remote(a=0, b=1)
self.assertEqual(ray.get(x), "0 1 hello world")
# Check that we cannot pass invalid keyword arguments to functions.
@ray.remote
@ -573,6 +581,9 @@ class APITest(unittest.TestCase):
with self.assertRaises(Exception):
f2.remote(0, w=0)
with self.assertRaises(Exception):
f2.remote(3, x=3)
# Make sure we get an exception if too many arguments are passed in.
with self.assertRaises(Exception):
f2.remote(1, 2, 3, 4)
@ -593,7 +604,6 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(x), "1 2")
self.assertTrue(test_functions.kwargs_exception_thrown)
self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown)
@ray.remote
def f1(*args):