mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
afbb260ca4
commit
470887c2ad
4 changed files with 47 additions and 40 deletions
|
@ -4,6 +4,7 @@ from __future__ import print_function
|
|||
|
||||
from collections import namedtuple
|
||||
import funcsigs
|
||||
from funcsigs import Parameter
|
||||
|
||||
from ray.utils import is_cython
|
||||
|
||||
|
@ -14,15 +15,16 @@ FunctionSignature = namedtuple("FunctionSignature", [
|
|||
"""This class is used to represent a function signature.
|
||||
|
||||
Attributes:
|
||||
keyword_names: The names of the functions keyword arguments. This is used
|
||||
to test if an incorrect keyword argument has been passed to the
|
||||
function.
|
||||
arg_names: A list containing the name of all arguments.
|
||||
arg_defaults: A dictionary mapping from argument name to argument default
|
||||
value. If the argument is not a keyword argument, the default value
|
||||
will be funcsigs._empty.
|
||||
arg_is_positionals: A dictionary mapping from argument name to a bool. The
|
||||
bool will be true if the argument is a *args argument. Otherwise it
|
||||
will be false.
|
||||
keyword_names: A set containing the names of the keyword arguments.
|
||||
Note most arguments in Python can be called as positional or keyword
|
||||
arguments, so this overlaps (sometimes completely) with arg_names.
|
||||
function_name: The name of the function whose signature is being
|
||||
inspected. This is used for printing better error messages.
|
||||
"""
|
||||
|
@ -85,16 +87,13 @@ def check_signature_supported(func, warn=False):
|
|||
function_name = func.__name__
|
||||
sig_params = get_signature_params(func)
|
||||
|
||||
has_vararg_param = False
|
||||
has_kwargs_param = False
|
||||
has_keyword_arg = False
|
||||
has_kwonly_param = False
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.kind == parameter.VAR_KEYWORD:
|
||||
if parameter.kind == Parameter.VAR_KEYWORD:
|
||||
has_kwargs_param = True
|
||||
if parameter.kind == parameter.VAR_POSITIONAL:
|
||||
has_vararg_param = True
|
||||
if parameter.default != funcsigs._empty:
|
||||
has_keyword_arg = True
|
||||
if parameter.kind == Parameter.KEYWORD_ONLY:
|
||||
has_kwonly_param = True
|
||||
|
||||
if has_kwargs_param:
|
||||
message = ("The function {} has a **kwargs argument, which is "
|
||||
|
@ -103,12 +102,11 @@ def check_signature_supported(func, warn=False):
|
|||
print(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
# Check if the user specified a variable number of arguments and any
|
||||
# keyword arguments.
|
||||
if has_vararg_param and has_keyword_arg:
|
||||
message = ("Function {} has a *args argument as well as a keyword "
|
||||
"argument, which is currently not supported."
|
||||
.format(function_name))
|
||||
|
||||
if has_kwonly_param:
|
||||
message = ("The function {} has a keyword only argument "
|
||||
"(defined after * or *args), which is currently "
|
||||
"not supported.".format(function_name))
|
||||
if warn:
|
||||
print(message)
|
||||
else:
|
||||
|
@ -136,20 +134,18 @@ def extract_signature(func, ignore_first=False):
|
|||
func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
|
||||
# Extract the names of the keyword arguments.
|
||||
keyword_names = set()
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.default != funcsigs._empty:
|
||||
keyword_names.add(keyword_name)
|
||||
|
||||
# Construct the argument default values and other argument information.
|
||||
arg_names = []
|
||||
arg_defaults = []
|
||||
arg_is_positionals = []
|
||||
for keyword_name, parameter in sig_params:
|
||||
arg_names.append(keyword_name)
|
||||
keyword_names = set()
|
||||
for arg_name, parameter in sig_params:
|
||||
arg_names.append(arg_name)
|
||||
arg_defaults.append(parameter.default)
|
||||
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
|
||||
if parameter.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
# Note KEYWORD_ONLY arguments currently unsupported.
|
||||
keyword_names.add(arg_name)
|
||||
|
||||
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
|
||||
keyword_names, func.__name__)
|
||||
|
@ -189,8 +185,14 @@ def extend_args(function_signature, args, kwargs):
|
|||
keyword_name, function_name))
|
||||
|
||||
# Fill in the remaining arguments.
|
||||
zipped_info = list(zip(arg_names, arg_defaults,
|
||||
arg_is_positionals))[len(args):]
|
||||
for skipped_name in arg_names[0:len(args)]:
|
||||
if skipped_name in kwargs:
|
||||
raise Exception("Positional and keyword value provided for the "
|
||||
"argument '{}' for the function '{}'".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
zipped_info = zip(arg_names, arg_defaults, arg_is_positionals)
|
||||
zipped_info = list(zipped_info)[len(args):]
|
||||
for keyword_name, default_value, is_positional in zipped_info:
|
||||
if keyword_name in kwargs:
|
||||
args.append(kwargs[keyword_name])
|
||||
|
@ -206,9 +208,8 @@ def extend_args(function_signature, args, kwargs):
|
|||
"'{}' for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
too_many_arguments = (len(args) > len(arg_names)
|
||||
and (len(arg_is_positionals) == 0
|
||||
or not arg_is_positionals[-1]))
|
||||
no_positionals = len(arg_is_positionals) == 0 or not arg_is_positionals[-1]
|
||||
too_many_arguments = len(args) > len(arg_names) and no_positionals
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
|
|
|
@ -68,16 +68,6 @@ try:
|
|||
except Exception:
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
varargs_and_kwargs_exception_thrown = True
|
||||
|
||||
# test throwing an exception
|
||||
|
||||
|
||||
|
|
|
@ -57,6 +57,9 @@ class ActorAPI(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(0, arg2="d", arg1=0)),
|
||||
(1, 2, "cd"))
|
||||
self.assertEqual(
|
||||
ray.get(actor.get_values.remote(arg2="d", arg1=0, arg0=2)),
|
||||
(3, 2, "cd"))
|
||||
|
||||
# Make sure we get an exception if the constructor is called
|
||||
# incorrectly.
|
||||
|
@ -66,6 +69,9 @@ class ActorAPI(unittest.TestCase):
|
|||
with self.assertRaises(Exception):
|
||||
actor = Actor.remote(0, 1, 2, arg3=3)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
actor = Actor.remote(0, arg0=1)
|
||||
|
||||
# Make sure we get an exception if the method is called incorrectly.
|
||||
actor = Actor.remote(1)
|
||||
with self.assertRaises(Exception):
|
||||
|
|
|
@ -529,6 +529,8 @@ class APITest(unittest.TestCase):
|
|||
self.assertEqual(ray.get(x), "1 hi")
|
||||
x = test_functions.keyword_fct1.remote(1, b="world")
|
||||
self.assertEqual(ray.get(x), "1 world")
|
||||
x = test_functions.keyword_fct1.remote(a=1, b="world")
|
||||
self.assertEqual(ray.get(x), "1 world")
|
||||
|
||||
x = test_functions.keyword_fct2.remote(a="w", b="hi")
|
||||
self.assertEqual(ray.get(x), "w hi")
|
||||
|
@ -545,6 +547,10 @@ class APITest(unittest.TestCase):
|
|||
|
||||
x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi")
|
||||
self.assertEqual(ray.get(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3.remote(0, b=1, c="w", d="hi")
|
||||
self.assertEqual(ray.get(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3.remote(a=0, b=1, c="w", d="hi")
|
||||
self.assertEqual(ray.get(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w")
|
||||
self.assertEqual(ray.get(x), "0 1 w hi")
|
||||
x = test_functions.keyword_fct3.remote(0, 1, c="w")
|
||||
|
@ -553,6 +559,8 @@ class APITest(unittest.TestCase):
|
|||
self.assertEqual(ray.get(x), "0 1 hello hi")
|
||||
x = test_functions.keyword_fct3.remote(0, 1)
|
||||
self.assertEqual(ray.get(x), "0 1 hello world")
|
||||
x = test_functions.keyword_fct3.remote(a=0, b=1)
|
||||
self.assertEqual(ray.get(x), "0 1 hello world")
|
||||
|
||||
# Check that we cannot pass invalid keyword arguments to functions.
|
||||
@ray.remote
|
||||
|
@ -573,6 +581,9 @@ class APITest(unittest.TestCase):
|
|||
with self.assertRaises(Exception):
|
||||
f2.remote(0, w=0)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
f2.remote(3, x=3)
|
||||
|
||||
# Make sure we get an exception if too many arguments are passed in.
|
||||
with self.assertRaises(Exception):
|
||||
f2.remote(1, 2, 3, 4)
|
||||
|
@ -593,7 +604,6 @@ class APITest(unittest.TestCase):
|
|||
self.assertEqual(ray.get(x), "1 2")
|
||||
|
||||
self.assertTrue(test_functions.kwargs_exception_thrown)
|
||||
self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown)
|
||||
|
||||
@ray.remote
|
||||
def f1(*args):
|
||||
|
|
Loading…
Add table
Reference in a new issue