[core] Simplify options handling [Part 1] (#23127)

* handle options

* update doc

* fix serve
This commit is contained in:
Siyuan (Ryans) Zhuang 2022-04-11 20:49:58 -07:00 committed by GitHub
parent 77b0015ea0
commit d7ef546352
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 528 additions and 806 deletions

View file

@ -0,0 +1,212 @@
"""Manage, parse and validate options for Ray tasks, actors and actor methods."""
from typing import Dict, Any, Callable, Tuple, Union, Optional
from dataclasses import dataclass
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@dataclass
class Option:
# Type constraint of an option.
type_constraint: Optional[Union[type, Tuple[type]]] = None
# Value constraint of an option.
value_constraint: Optional[Callable[[Any], bool]] = None
# Error message for value constraint.
error_message_for_value_constraint: Optional[str] = None
# Default value.
default_value: Any = None
def validate(self, keyword: str, value: Any):
"""Validate the option."""
if self.type_constraint is not None:
if not isinstance(value, self.type_constraint):
raise TypeError(
f"The type of keyword '{keyword}' must be {self.type_constraint}, "
f"but received type {type(value)}"
)
if self.value_constraint is not None:
if not self.value_constraint(value):
raise ValueError(self.error_message_for_value_constraint)
def _counting_option(name: str, infinite: bool = True, default_value: Any = None):
"""This is used for positive and discrete options.
Args:
name: The name of the option keyword.
infinite: If True, user could use -1 to represent infinity.
default_value: The default value for this option.
"""
if infinite:
return Option(
(int, type(None)),
lambda x: x is None or x >= -1,
f"The keyword '{name}' only accepts None, 0, -1 or a positive integer, "
"where -1 represents infinity.",
default_value=default_value,
)
return Option(
(int, type(None)),
lambda x: x is None or x >= 0,
f"The keyword '{name}' only accepts None, 0 or a positive integer.",
default_value=default_value,
)
def _resource_option(name: str, default_value: Any = None):
"""This is used for non-negative options, typically for defining resources."""
return Option(
(float, int, type(None)),
lambda x: x is None or x >= 0,
f"The keyword '{name}' only accepts None, 0 or a positive number",
default_value=default_value,
)
_common_options = {
"accelerator_type": Option((str, type(None))),
"memory": _resource_option("memory"),
"name": Option((str, type(None))),
"num_cpus": _resource_option("num_cpus"),
"num_gpus": _resource_option("num_gpus"),
"object_store_memory": _counting_option("object_store_memory", False),
# TODO(suquark): "placement_group", "placement_group_bundle_index"
# and "placement_group_capture_child_tasks" are deprecated,
# use "scheduling_strategy" instead.
"placement_group": Option(
(type(None), str, PlacementGroup), default_value="default"
),
"placement_group_bundle_index": Option(int, default_value=-1),
"placement_group_capture_child_tasks": Option((bool, type(None))),
"resources": Option(
(dict, type(None)),
lambda x: x is None or ("CPU" not in x and "GPU" not in x),
"Use the 'num_cpus' and 'num_gpus' keyword instead of 'CPU' and 'GPU' "
"in 'resources' keyword",
),
"runtime_env": Option((dict, type(None))),
"scheduling_strategy": Option((type(None), str, PlacementGroupSchedulingStrategy)),
}
_task_only_options = {
"max_calls": _counting_option("max_calls", False, default_value=0),
# Normal tasks may be retried on failure this many times.
# TODO(swang): Allow this to be set globally for an application.
"max_retries": _counting_option("max_retries", default_value=3),
# override "_common_options"
"num_cpus": _resource_option("num_cpus", default_value=1),
"num_returns": _counting_option("num_returns", False, default_value=1),
"object_store_memory": Option( # override "_common_options"
(int, type(None)),
lambda x: x is None,
"Setting 'object_store_memory' is not implemented for tasks",
),
"retry_exceptions": Option(bool, default_value=False),
}
_actor_only_options = {
"concurrency_groups": Option((list, dict, type(None))),
"lifetime": Option(
(str, type(None)),
lambda x: x in (None, "detached", "non_detached"),
"actor `lifetime` argument must be one of 'detached', "
"'non_detached' and 'None'.",
),
"max_concurrency": _counting_option("max_concurrency", False),
"max_restarts": _counting_option("max_restarts", default_value=0),
"max_task_retries": _counting_option("max_task_retries", default_value=0),
"max_pending_calls": _counting_option("max_pending_calls", default_value=-1),
"namespace": Option((str, type(None))),
"get_if_exists": Option(bool, default_value=False),
}
# Priority is important here because during dictionary update, same key with higher
# priority overrides the same key with lower priority. We make use of priority
# to set the correct default value for tasks / actors.
# priority: _common_options > _actor_only_options > _task_only_options
valid_options: Dict[str, Option] = {
**_task_only_options,
**_actor_only_options,
**_common_options,
}
# priority: _task_only_options > _common_options
task_options: Dict[str, Option] = {**_common_options, **_task_only_options}
# priority: _actor_only_options > _common_options
actor_options: Dict[str, Option] = {**_common_options, **_actor_only_options}
remote_args_error_string = (
"The @ray.remote decorator must be applied either with no arguments and no "
"parentheses, for example '@ray.remote', or it must be applied using some of "
f"the arguments in the list {list(valid_options.keys())}, for example "
"'@ray.remote(num_returns=2, resources={\"CustomResource\": 1})'."
)
def _check_deprecate_placement_group(options: Dict[str, Any]):
"""Check if deprecated placement group option exists."""
placement_group = options.get("placement_group", "default")
scheduling_strategy = options.get("scheduling_strategy")
# TODO(suquark): @ray.remote(placement_group=None) is used in
# "python/ray/data/impl/remote_fn.py" and many other places,
# while "ray.data.read_api.read_datasource" set "scheduling_strategy=SPREAD".
# This might be a bug, but it is also ok to allow them co-exist.
if (placement_group not in ("default", None)) and (scheduling_strategy is not None):
raise ValueError(
"Placement groups should be specified via the "
"scheduling_strategy option. "
"The placement_group option is deprecated."
)
def validate_task_options(options: Dict[str, Any], in_options: bool):
"""Options check for Ray tasks.
Args:
options: Options for Ray tasks.
in_options: If True, we are checking the options under the context of
".options()".
"""
for k, v in options.items():
if k not in task_options:
raise ValueError(
f"Invalid option keyword {k} for remote functions. "
f"Valid ones are {list(task_options)}."
)
task_options[k].validate(k, v)
if in_options and "max_calls" in options:
raise ValueError("Setting 'max_calls' is not supported in '.options()'.")
_check_deprecate_placement_group(options)
def validate_actor_options(options: Dict[str, Any], in_options: bool):
"""Options check for Ray actors.
Args:
options: Options for Ray actors.
in_options: If True, we are checking the options under the context of
".options()".
"""
for k, v in options.items():
if k not in actor_options:
raise ValueError(
f"Invalid option keyword {k} for actors. "
f"Valid ones are {list(actor_options)}."
)
actor_options[k].validate(k, v)
if in_options and "concurrency_groups" in options:
raise ValueError(
"Setting 'concurrency_groups' is not supported in '.options()'."
)
if options.get("max_restarts", 0) == 0 and options.get("max_task_retries", 0) != 0:
raise ValueError(
"'max_task_retries' cannot be set if 'max_restarts' "
"is 0 or if 'max_restarts' is not set."
)
if options.get("get_if_exists") and not options.get("name"):
raise ValueError("The actor name must be specified to use `get_if_exists`.")
_check_deprecate_placement_group(options)

View file

@ -31,6 +31,7 @@ from ray.util.tracing.tracing_helper import (
_tracing_actor_method_invocation, _tracing_actor_method_invocation,
_inject_tracing_into_class, _inject_tracing_into_class,
) )
from ray._private import ray_option_utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -356,6 +357,16 @@ class ActorClassInheritanceException(TypeError):
pass pass
def _process_option_dict(actor_options):
_filled_options = {}
arg_names = set(inspect.getfullargspec(ActorClassMetadata.__init__)[0])
for k, v in ray_option_utils.actor_options.items():
if k in arg_names:
_filled_options[k] = actor_options.get(k, v.default_value)
_filled_options["runtime_env"] = parse_runtime_env(_filled_options["runtime_env"])
return _filled_options
class ActorClass: class ActorClass:
"""An actor class. """An actor class.
@ -419,17 +430,7 @@ class ActorClass:
cls, cls,
modified_class, modified_class,
class_id, class_id,
max_restarts, actor_options,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
concurrency_groups,
scheduling_strategy: SchedulingStrategyT,
): ):
for attribute in [ for attribute in [
"remote", "remote",
@ -473,25 +474,16 @@ class ActorClass:
modified_class.__ray_actor_class__ modified_class.__ray_actor_class__
) )
new_runtime_env = parse_runtime_env(runtime_env)
self.__ray_metadata__ = ActorClassMetadata( self.__ray_metadata__ = ActorClassMetadata(
Language.PYTHON, Language.PYTHON,
modified_class, modified_class,
actor_creation_function_descriptor, actor_creation_function_descriptor,
class_id, class_id,
max_restarts, **_process_option_dict(actor_options),
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
new_runtime_env,
concurrency_groups,
scheduling_strategy,
) )
self._default_options = actor_options
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env
return self return self
@ -500,38 +492,19 @@ class ActorClass:
cls, cls,
language, language,
actor_creation_function_descriptor, actor_creation_function_descriptor,
max_restarts, actor_options,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
): ):
self = ActorClass.__new__(ActorClass) self = ActorClass.__new__(ActorClass)
new_runtime_env = parse_runtime_env(runtime_env)
self.__ray_metadata__ = ActorClassMetadata( self.__ray_metadata__ = ActorClassMetadata(
language, language,
None, None,
actor_creation_function_descriptor, actor_creation_function_descriptor,
None, None,
max_restarts, **_process_option_dict(actor_options),
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
new_runtime_env,
[],
None,
) )
self._default_options = actor_options
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env
return self return self
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
@ -546,32 +519,9 @@ class ActorClass:
Returns: Returns:
A handle to the newly created actor. A handle to the newly created actor.
""" """
return self._remote(args=args, kwargs=kwargs) return self._remote(args=args, kwargs=kwargs, **self._default_options)
def options( def options(self, args=None, kwargs=None, **actor_options):
self,
args=None,
kwargs=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_concurrency=None,
max_restarts=None,
max_task_retries=None,
name=None,
namespace=None,
get_if_exists=False,
lifetime=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
max_pending_calls=-1,
scheduling_strategy: SchedulingStrategyT = None,
):
"""Configures and overrides the actor instantiation parameters. """Configures and overrides the actor instantiation parameters.
The arguments are the same as those that can be passed The arguments are the same as those that can be passed
@ -592,45 +542,28 @@ class ActorClass:
actor_cls = self actor_cls = self
new_runtime_env = parse_runtime_env(runtime_env) # override original options
default_options = self._default_options.copy()
# "concurrency_groups" could not be used in ".options()",
# we should remove it before merging options from '@ray.remote'.
default_options.pop("concurrency_groups", None)
updated_options = {**default_options, **actor_options}
ray_option_utils.validate_actor_options(updated_options, in_options=True)
cls_options = dict( # only update runtime_env when ".options()" specifies new runtime_env
num_cpus=num_cpus, if "runtime_env" in actor_options:
num_gpus=num_gpus, updated_options["runtime_env"] = parse_runtime_env(
memory=memory, updated_options["runtime_env"]
object_store_memory=object_store_memory,
resources=resources,
accelerator_type=accelerator_type,
max_concurrency=max_concurrency,
max_restarts=max_restarts,
max_task_retries=max_task_retries,
name=name,
namespace=namespace,
lifetime=lifetime,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(placement_group_capture_child_tasks),
runtime_env=new_runtime_env,
max_pending_calls=max_pending_calls,
scheduling_strategy=scheduling_strategy,
) )
class ActorOptionWrapper: class ActorOptionWrapper:
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
# Handle the get-or-create case. # Handle the get-or-create case.
if get_if_exists: if updated_options.get("get_if_exists"):
if not cls_options.get("name"):
raise ValueError(
"The actor name must be specified to use `get_if_exists`."
)
return self._get_or_create_impl(args, kwargs) return self._get_or_create_impl(args, kwargs)
# Normal create case. # Normal create case.
return actor_cls._remote( return actor_cls._remote(args=args, kwargs=kwargs, **updated_options)
args=args,
kwargs=kwargs,
**cls_options,
)
def bind(self, *args, **kwargs): def bind(self, *args, **kwargs):
""" """
@ -645,52 +578,34 @@ class ActorClass:
actor_cls.__ray_metadata__.modified_class, actor_cls.__ray_metadata__.modified_class,
args, args,
kwargs, kwargs,
cls_options, updated_options,
) )
def _get_or_create_impl(self, args, kwargs): def _get_or_create_impl(self, args, kwargs):
name = cls_options["name"] name = updated_options["name"]
try: try:
return ray.get_actor(name, namespace=cls_options.get("namespace")) return ray.get_actor(
name, namespace=updated_options.get("namespace")
)
except ValueError: except ValueError:
# Attempt to create it (may race with other attempts). # Attempt to create it (may race with other attempts).
try: try:
return actor_cls._remote( return actor_cls._remote(
args=args, args=args,
kwargs=kwargs, kwargs=kwargs,
**cls_options, **updated_options,
) )
except ValueError: except ValueError:
# We lost the creation race, ignore. # We lost the creation race, ignore.
pass pass
return ray.get_actor(name, namespace=cls_options.get("namespace")) return ray.get_actor(
name, namespace=updated_options.get("namespace")
)
return ActorOptionWrapper() return ActorOptionWrapper()
@_tracing_actor_creation @_tracing_actor_creation
def _remote( def _remote(self, args=None, kwargs=None, **actor_options):
self,
args=None,
kwargs=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_concurrency=None,
max_restarts=None,
max_task_retries=None,
name=None,
namespace=None,
lifetime=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
max_pending_calls=-1,
scheduling_strategy: SchedulingStrategyT = None,
):
"""Create an actor. """Create an actor.
This method allows more flexibility than the remote method because This method allows more flexibility than the remote method because
@ -752,6 +667,10 @@ class ActorClass:
Returns: Returns:
A handle to the newly created actor. A handle to the newly created actor.
""" """
# We pop the "concurrency_groups" coming from "@ray.remote" here. We no longer
# need it in "_remote()".
actor_options.pop("concurrency_groups", None)
if args is None: if args is None:
args = [] args = []
if kwargs is None: if kwargs is None:
@ -767,41 +686,40 @@ class ActorClass:
) )
is_asyncio = actor_has_async_methods is_asyncio = actor_has_async_methods
if max_concurrency is None: if actor_options.get("max_concurrency") is None:
if is_asyncio: actor_options["max_concurrency"] = 1000 if is_asyncio else 1
max_concurrency = 1000
else:
max_concurrency = 1
if max_concurrency < 1:
raise ValueError("max_concurrency must be >= 1")
if client_mode_should_convert(auto_init=True): if client_mode_should_convert(auto_init=True):
return client_mode_convert_actor( return client_mode_convert_actor(self, args, kwargs, **actor_options)
self,
args, # fill actor required options
kwargs, for k, v in ray_option_utils.actor_options.items():
num_cpus=num_cpus, actor_options[k] = actor_options.get(k, v.default_value)
num_gpus=num_gpus, # "concurrency_groups" already takes effects and should not apply again.
memory=memory, # Remove the default value here.
object_store_memory=object_store_memory, actor_options.pop("concurrency_groups", None)
resources=resources,
accelerator_type=accelerator_type, # TODO(suquark): cleanup these fields
max_concurrency=max_concurrency, max_concurrency = actor_options["max_concurrency"]
max_restarts=max_restarts, name = actor_options["name"]
max_task_retries=max_task_retries, namespace = actor_options["namespace"]
name=name, lifetime = actor_options["lifetime"]
namespace=namespace, num_cpus = actor_options["num_cpus"]
lifetime=lifetime, num_gpus = actor_options["num_gpus"]
placement_group=placement_group, accelerator_type = actor_options["accelerator_type"]
placement_group_bundle_index=placement_group_bundle_index, resources = actor_options["resources"]
placement_group_capture_child_tasks=( memory = actor_options["memory"]
placement_group_capture_child_tasks object_store_memory = actor_options["object_store_memory"]
), runtime_env = actor_options["runtime_env"]
runtime_env=runtime_env, placement_group = actor_options["placement_group"]
max_pending_calls=max_pending_calls, placement_group_bundle_index = actor_options["placement_group_bundle_index"]
scheduling_strategy=scheduling_strategy, placement_group_capture_child_tasks = actor_options[
) "placement_group_capture_child_tasks"
]
scheduling_strategy = actor_options["scheduling_strategy"]
max_restarts = actor_options["max_restarts"]
max_task_retries = actor_options["max_task_retries"]
max_pending_calls = actor_options["max_pending_calls"]
worker = ray.worker.global_worker worker = ray.worker.global_worker
worker.check_connected() worker.check_connected()
@ -849,11 +767,7 @@ class ActorClass:
# decorator. Last three conditions are to check that no resources were # decorator. Last three conditions are to check that no resources were
# specified when _remote() was called. # specified when _remote() was called.
if ( if (
meta.num_cpus is None num_cpus is None
and meta.num_gpus is None
and meta.resources is None
and meta.accelerator_type is None
and num_cpus is None
and num_gpus is None and num_gpus is None
and resources is None and resources is None
and accelerator_type is None and accelerator_type is None
@ -868,8 +782,8 @@ class ActorClass:
# resources are associated with methods. # resources are associated with methods.
cpus_to_use = ( cpus_to_use = (
ray_constants.DEFAULT_ACTOR_CREATION_CPU_SPECIFIED ray_constants.DEFAULT_ACTOR_CREATION_CPU_SPECIFIED
if meta.num_cpus is None if num_cpus is None
else meta.num_cpus else num_cpus
) )
actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED
@ -897,13 +811,14 @@ class ActorClass:
meta.method_meta.methods.keys(), meta.method_meta.methods.keys(),
) )
# TODO(suquark): cleanup "resources_from_resource_arguments" later.
resources = ray._private.utils.resources_from_resource_arguments( resources = ray._private.utils.resources_from_resource_arguments(
cpus_to_use, cpus_to_use,
meta.num_gpus, num_gpus,
meta.memory, memory,
meta.object_store_memory, object_store_memory,
meta.resources, resources,
meta.accelerator_type, accelerator_type,
num_cpus, num_cpus,
num_gpus, num_gpus,
memory, memory,
@ -926,14 +841,6 @@ class ActorClass:
function_signature = meta.method_meta.signatures["__init__"] function_signature = meta.method_meta.signatures["__init__"]
creation_args = signature.flatten_args(function_signature, args, kwargs) creation_args = signature.flatten_args(function_signature, args, kwargs)
scheduling_strategy = scheduling_strategy or meta.scheduling_strategy
if (placement_group != "default") and (scheduling_strategy is not None):
raise ValueError(
"Placement groups should be specified via the "
"scheduling_strategy option. "
"The placement_group option is deprecated."
)
if scheduling_strategy is None or isinstance( if scheduling_strategy is None or isinstance(
scheduling_strategy, PlacementGroupSchedulingStrategy scheduling_strategy, PlacementGroupSchedulingStrategy
): ):
@ -971,19 +878,17 @@ class ActorClass:
else: else:
scheduling_strategy = "DEFAULT" scheduling_strategy = "DEFAULT"
if runtime_env:
new_runtime_env = parse_runtime_env(runtime_env)
else:
new_runtime_env = meta.runtime_env
serialized_runtime_env_info = None serialized_runtime_env_info = None
if new_runtime_env is not None: if runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info( serialized_runtime_env_info = get_runtime_env_info(
new_runtime_env, runtime_env,
is_job_runtime_env=False, is_job_runtime_env=False,
serialize=True, serialize=True,
) )
concurrency_groups_dict = {} concurrency_groups_dict = {}
if meta.concurrency_groups is None:
meta.concurrency_groups = []
for cg_name in meta.concurrency_groups: for cg_name in meta.concurrency_groups:
concurrency_groups_dict[cg_name] = { concurrency_groups_dict[cg_name] = {
"name": cg_name, "name": cg_name,
@ -1020,8 +925,8 @@ class ActorClass:
meta.language, meta.language,
meta.actor_creation_function_descriptor, meta.actor_creation_function_descriptor,
creation_args, creation_args,
max_restarts or meta.max_restarts, max_restarts,
max_task_retries or meta.max_task_retries, max_task_retries,
resources, resources,
actor_placement_resources, actor_placement_resources,
max_concurrency, max_concurrency,
@ -1375,59 +1280,22 @@ def modify_class(cls):
return Class return Class
def make_actor( def make_actor(cls, actor_options):
cls,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
max_restarts,
max_task_retries,
runtime_env,
concurrency_groups,
scheduling_strategy: SchedulingStrategyT,
):
Class = modify_class(cls) Class = modify_class(cls)
_inject_tracing_into_class(Class) _inject_tracing_into_class(Class)
if max_restarts is None: if "max_restarts" in actor_options:
max_restarts = 0 if actor_options["max_restarts"] != -1: # -1 represents infinite restart
if max_task_retries is None:
max_task_retries = 0
if concurrency_groups is None:
concurrency_groups = []
infinite_restart = max_restarts == -1
if not infinite_restart:
if max_restarts < 0:
raise ValueError(
"max_restarts must be an integer >= -1 "
"-1 indicates infinite restarts"
)
else:
# Make sure we don't pass too big of an int to C++, causing # Make sure we don't pass too big of an int to C++, causing
# an overflow. # an overflow.
max_restarts = min(max_restarts, ray_constants.MAX_INT64_VALUE) actor_options["max_restarts"] = min(
actor_options["max_restarts"], ray_constants.MAX_INT64_VALUE
if max_restarts == 0 and max_task_retries != 0: )
raise ValueError("max_task_retries cannot be set if max_restarts is 0.")
return ActorClass._ray_from_modified_class( return ActorClass._ray_from_modified_class(
Class, Class,
ActorClassID.from_random(), ActorClassID.from_random(),
max_restarts, actor_options,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
concurrency_groups,
scheduling_strategy,
) )

View file

@ -82,20 +82,8 @@ def java_function(class_name, function_name):
Language.JAVA, Language.JAVA,
lambda *args, **kwargs: None, lambda *args, **kwargs: None,
JavaFunctionDescriptor(class_name, function_name, ""), JavaFunctionDescriptor(class_name, function_name, ""),
None, # num_cpus, {},
None, # num_gpus, )
None, # memory,
None, # object_store_memory,
None, # resources,
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None, # max_retries,
None, # retry_exceptions,
None, # runtime_env,
None, # placement_group,
None,
) # scheduling_strategy,
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
@ -111,20 +99,8 @@ def cpp_function(function_name):
Language.CPP, Language.CPP,
lambda *args, **kwargs: None, lambda *args, **kwargs: None,
CppFunctionDescriptor(function_name, "PYTHON"), CppFunctionDescriptor(function_name, "PYTHON"),
None, # num_cpus, {},
None, # num_gpus, )
None, # memory,
None, # object_store_memory,
None, # resources,
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None, # max_retries,
None, # retry_exceptions,
None, # runtime_env,
None, # placement_group,
None,
) # scheduling_strategy,
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
@ -139,15 +115,7 @@ def java_actor_class(class_name):
return ActorClass._ray_from_function_descriptor( return ActorClass._ray_from_function_descriptor(
Language.JAVA, Language.JAVA,
JavaFunctionDescriptor(class_name, "<init>", ""), JavaFunctionDescriptor(class_name, "<init>", ""),
max_restarts=0, {},
max_task_retries=0,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
runtime_env=None,
) )
@ -165,13 +133,5 @@ def cpp_actor_class(create_function_name, class_name):
return ActorClass._ray_from_function_descriptor( return ActorClass._ray_from_function_descriptor(
Language.CPP, Language.CPP,
CppFunctionDescriptor(create_function_name, "PYTHON", class_name), CppFunctionDescriptor(create_function_name, "PYTHON", class_name),
max_restarts=0, {},
max_task_retries=0,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
runtime_env=None,
) )

View file

@ -4,10 +4,7 @@ import logging
import uuid import uuid
from ray import cloudpickle as pickle from ray import cloudpickle as pickle
from ray.util.scheduling_strategies import ( from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
PlacementGroupSchedulingStrategy,
SchedulingStrategyT,
)
from ray._raylet import PythonFunctionDescriptor from ray._raylet import PythonFunctionDescriptor
from ray import cross_language, Language from ray import cross_language, Language
from ray._private.client_mode_hook import client_mode_convert_function from ray._private.client_mode_hook import client_mode_convert_function
@ -19,16 +16,7 @@ from ray.util.tracing.tracing_helper import (
_tracing_task_invocation, _tracing_task_invocation,
_inject_tracing_into_function, _inject_tracing_into_function,
) )
from ray._private import ray_option_utils
# Default parameters for remote functions.
DEFAULT_REMOTE_FUNCTION_CPUS = 1
DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS = 1
DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
# Normal tasks may be retried on failure this many times.
# TODO(swang): Allow this to be set globally for an application.
DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3
DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,19 +73,7 @@ class RemoteFunction:
language, language,
function, function,
function_descriptor, function_descriptor,
num_cpus, task_options,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_returns,
max_calls,
max_retries,
retry_exceptions,
runtime_env,
placement_group,
scheduling_strategy: SchedulingStrategyT,
): ):
if inspect.iscoroutinefunction(function): if inspect.iscoroutinefunction(function):
raise ValueError( raise ValueError(
@ -105,56 +81,34 @@ class RemoteFunction:
"async function with `asyncio.get_event_loop.run_until(f())`. " "async function with `asyncio.get_event_loop.run_until(f())`. "
"See more at https://docs.ray.io/en/latest/ray-core/async_api.html#asyncio-for-remote-tasks" # noqa "See more at https://docs.ray.io/en/latest/ray-core/async_api.html#asyncio-for-remote-tasks" # noqa
) )
self._default_options = task_options
# TODO(suquark): This is a workaround for class attributes of options.
# They are being used in some other places, mostly tests. Need cleanup later.
# E.g., actors uses "__ray_metadata__" to collect options, we can so something
# similar for remote functions.
for k, v in ray_option_utils.task_options.items():
setattr(self, "_" + k, task_options.get(k, v.default_value))
self._runtime_env = parse_runtime_env(self._runtime_env)
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self._runtime_env
self._language = language self._language = language
self._function = _inject_tracing_into_function(function) self._function = _inject_tracing_into_function(function)
self._function_name = function.__module__ + "." + function.__name__ self._function_name = function.__module__ + "." + function.__name__
self._function_descriptor = function_descriptor self._function_descriptor = function_descriptor
self._is_cross_language = language != Language.PYTHON self._is_cross_language = language != Language.PYTHON
self._num_cpus = DEFAULT_REMOTE_FUNCTION_CPUS if num_cpus is None else num_cpus
self._num_gpus = num_gpus
self._memory = memory
if object_store_memory is not None:
raise NotImplementedError(
"setting object_store_memory is not implemented for tasks"
)
self._object_store_memory = None
self._resources = resources
self._accelerator_type = accelerator_type
self._num_returns = (
DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS
if num_returns is None
else num_returns
)
self._max_calls = (
DEFAULT_REMOTE_FUNCTION_MAX_CALLS if max_calls is None else max_calls
)
self._max_retries = (
DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
if max_retries is None
else max_retries
)
self._retry_exceptions = (
DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
if retry_exceptions is None
else retry_exceptions
)
self._runtime_env = parse_runtime_env(runtime_env)
self._placement_group = placement_group
self._decorator = getattr(function, "__ray_invocation_decorator__", None) self._decorator = getattr(function, "__ray_invocation_decorator__", None)
self._function_signature = ray._private.signature.extract_signature( self._function_signature = ray._private.signature.extract_signature(
self._function self._function
) )
self._scheduling_strategy = scheduling_strategy
self._last_export_session_and_job = None self._last_export_session_and_job = None
self._uuid = uuid.uuid4() self._uuid = uuid.uuid4()
# Override task.remote's signature and docstring # Override task.remote's signature and docstring
@wraps(function) @wraps(function)
def _remote_proxy(*args, **kwargs): def _remote_proxy(*args, **kwargs):
return self._remote(args=args, kwargs=kwargs) return self._remote(args=args, kwargs=kwargs, **self._default_options)
self.remote = _remote_proxy self.remote = _remote_proxy
@ -169,21 +123,7 @@ class RemoteFunction:
self, self,
args=None, args=None,
kwargs=None, kwargs=None,
num_returns=None, **task_options,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
name="",
scheduling_strategy: SchedulingStrategyT = None,
): ):
"""Configures and overrides the task invocation parameters. """Configures and overrides the task invocation parameters.
@ -202,29 +142,24 @@ class RemoteFunction:
""" """
func_cls = self func_cls = self
new_runtime_env = parse_runtime_env(runtime_env)
options = dict( # override original options
num_returns=num_returns, default_options = self._default_options.copy()
num_cpus=num_cpus, # max_calls could not be used in ".options()", we should remove it before
num_gpus=num_gpus, # merging options from '@ray.remote'.
memory=memory, default_options.pop("max_calls", None)
object_store_memory=object_store_memory, updated_options = {**default_options, **task_options}
accelerator_type=accelerator_type, ray_option_utils.validate_task_options(updated_options, in_options=True)
resources=resources,
max_retries=max_retries, # only update runtime_env when ".options()" specifies new runtime_env
retry_exceptions=retry_exceptions, if "runtime_env" in task_options:
placement_group=placement_group, updated_options["runtime_env"] = parse_runtime_env(
placement_group_bundle_index=placement_group_bundle_index, updated_options["runtime_env"]
placement_group_capture_child_tasks=(placement_group_capture_child_tasks),
runtime_env=new_runtime_env,
name=name,
scheduling_strategy=scheduling_strategy,
) )
class FuncWrapper: class FuncWrapper:
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return func_cls._remote(args=args, kwargs=kwargs, **options) return func_cls._remote(args=args, kwargs=kwargs, **updated_options)
def bind(self, *args, **kwargs): def bind(self, *args, **kwargs):
""" """
@ -234,61 +169,18 @@ class RemoteFunction:
""" """
from ray.experimental.dag.function_node import FunctionNode from ray.experimental.dag.function_node import FunctionNode
return FunctionNode( return FunctionNode(func_cls._function, args, kwargs, updated_options)
func_cls._function,
args,
kwargs,
options,
)
return FuncWrapper() return FuncWrapper()
@_tracing_task_invocation @_tracing_task_invocation
def _remote( def _remote(self, args=None, kwargs=None, **task_options):
self,
args=None,
kwargs=None,
num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
name="",
scheduling_strategy: SchedulingStrategyT = None,
):
"""Submit the remote function for execution.""" """Submit the remote function for execution."""
# We pop the "max_calls" coming from "@ray.remote" here. We no longer need
# it in "_remote()".
task_options.pop("max_calls", None)
if client_mode_should_convert(auto_init=True): if client_mode_should_convert(auto_init=True):
return client_mode_convert_function( return client_mode_convert_function(self, args, kwargs, **task_options)
self,
args,
kwargs,
num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
accelerator_type=accelerator_type,
resources=resources,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(
placement_group_capture_child_tasks
),
runtime_env=runtime_env,
name=name,
scheduling_strategy=scheduling_strategy,
)
worker = ray.worker.global_worker worker = ray.worker.global_worker
worker.check_connected() worker.check_connected()
@ -328,35 +220,46 @@ class RemoteFunction:
kwargs = {} if kwargs is None else kwargs kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args args = [] if args is None else args
if num_returns is None: # fill task required options
num_returns = self._num_returns for k, v in ray_option_utils.task_options.items():
if max_retries is None: task_options[k] = task_options.get(k, v.default_value)
max_retries = self._max_retries # "max_calls" already takes effects and should not apply again.
if retry_exceptions is None: # Remove the default value here.
retry_exceptions = self._retry_exceptions task_options.pop("max_calls", None)
if scheduling_strategy is None:
scheduling_strategy = self._scheduling_strategy
# TODO(suquark): cleanup these fields
name = task_options["name"]
num_cpus = task_options["num_cpus"]
num_gpus = task_options["num_gpus"]
accelerator_type = task_options["accelerator_type"]
resources = task_options["resources"]
memory = task_options["memory"]
object_store_memory = task_options["object_store_memory"]
runtime_env = parse_runtime_env(task_options["runtime_env"])
placement_group = task_options["placement_group"]
placement_group_bundle_index = task_options["placement_group_bundle_index"]
placement_group_capture_child_tasks = task_options[
"placement_group_capture_child_tasks"
]
scheduling_strategy = task_options["scheduling_strategy"]
num_returns = task_options["num_returns"]
max_retries = task_options["max_retries"]
retry_exceptions = task_options["retry_exceptions"]
# TODO(suquark): cleanup "resources_from_resource_arguments" later.
resources = ray._private.utils.resources_from_resource_arguments( resources = ray._private.utils.resources_from_resource_arguments(
self._num_cpus,
self._num_gpus,
self._memory,
self._object_store_memory,
self._resources,
self._accelerator_type,
num_cpus, num_cpus,
num_gpus, num_gpus,
memory, memory,
object_store_memory, object_store_memory,
resources, resources,
accelerator_type, accelerator_type,
) num_cpus,
num_gpus,
if (placement_group != "default") and (scheduling_strategy is not None): memory,
raise ValueError( object_store_memory,
"Placement groups should be specified via the " resources,
"scheduling_strategy option. " accelerator_type,
"The placement_group option is deprecated."
) )
if scheduling_strategy is None or isinstance( if scheduling_strategy is None or isinstance(
@ -375,8 +278,6 @@ class RemoteFunction:
placement_group_capture_child_tasks = ( placement_group_capture_child_tasks = (
worker.should_capture_child_tasks_in_placement_group worker.should_capture_child_tasks_in_placement_group
) )
if placement_group == "default":
placement_group = self._placement_group
placement_group = configure_placement_group_based_on_context( placement_group = configure_placement_group_based_on_context(
placement_group_capture_child_tasks, placement_group_capture_child_tasks,
placement_group_bundle_index, placement_group_bundle_index,
@ -394,8 +295,6 @@ class RemoteFunction:
else: else:
scheduling_strategy = "DEFAULT" scheduling_strategy = "DEFAULT"
if not runtime_env or runtime_env == "{}":
runtime_env = self._runtime_env
serialized_runtime_env_info = None serialized_runtime_env_info = None
if runtime_env is not None: if runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info( serialized_runtime_env_info = get_runtime_env_info(
@ -422,7 +321,7 @@ class RemoteFunction:
self._language, self._language,
self._function_descriptor, self._function_descriptor,
list_args, list_args,
name, name if name is not None else "",
num_returns, num_returns,
resources, resources,
max_retries, max_retries,

View file

@ -257,6 +257,7 @@ class ReplicaConfig:
f"Specifying {option} in ray_actor_options is not allowed." f"Specifying {option} in ray_actor_options is not allowed."
) )
# TODO(suquark): reuse options validation of remote function/actor.
# Ray defaults to zero CPUs for placement, we default to one here. # Ray defaults to zero CPUs for placement, we default to one here.
if self.ray_actor_options.get("num_cpus", None) is None: if self.ray_actor_options.get("num_cpus", None) is None:
self.ray_actor_options["num_cpus"] = 1 self.ray_actor_options["num_cpus"] = 1
@ -286,14 +287,13 @@ class ReplicaConfig:
raise ValueError("memory in ray_actor_options must be > 0.") raise ValueError("memory in ray_actor_options must be > 0.")
self.resource_dict["memory"] = memory self.resource_dict["memory"] = memory
if self.ray_actor_options.get("object_store_memory", None) is None: object_store_memory = self.ray_actor_options.get("object_store_memory")
self.ray_actor_options["object_store_memory"] = 0 if not isinstance(object_store_memory, (int, float, type(None))):
object_store_memory = self.ray_actor_options["object_store_memory"]
if not isinstance(object_store_memory, (int, float)):
raise TypeError( raise TypeError(
"object_store_memory in ray_actor_options must be an int or a float." "object_store_memory in ray_actor_options must be an int, float "
"or None."
) )
elif object_store_memory < 0: elif object_store_memory is not None and object_store_memory < 0:
raise ValueError("object_store_memory in ray_actor_options must be >= 0.") raise ValueError("object_store_memory in ray_actor_options must be >= 0.")
self.resource_dict["object_store_memory"] = object_store_memory self.resource_dict["object_store_memory"] = object_store_memory

View file

@ -178,61 +178,104 @@ def test_submit_api(shutdown_only):
def test_invalid_arguments(shutdown_only): def test_invalid_arguments(shutdown_only):
import re
ray.init(num_cpus=2) ray.init(num_cpus=2)
for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]: def f():
with pytest.raises(
ValueError,
match="The keyword 'num_returns' only accepts 0 or a positive integer",
):
@ray.remote(num_returns=opt)
def g1():
return 1 return 1
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]: class A:
with pytest.raises(
ValueError,
match="The keyword 'max_retries' only accepts 0, -1 or a"
" positive integer",
):
@ray.remote(max_retries=opt)
def g2():
return 1
for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_calls' only accepts 0 or a positive integer",
):
@ray.remote(max_calls=opt)
def g3():
return 1
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_restarts' only accepts -1, 0 or a"
" positive integer",
):
@ray.remote(max_restarts=opt)
class A1:
x = 1 x = 1
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]: template1 = (
"The type of keyword '{}' "
+ f"must be {(int, type(None))}, but received type {float}"
)
# Type check
for keyword in ("num_returns", "max_retries", "max_calls"):
with pytest.raises(TypeError, match=re.escape(template1.format(keyword))):
ray.remote(**{keyword: np.random.uniform(0, 1)})(f)
for keyword in ("max_restarts", "max_task_retries"):
with pytest.raises(TypeError, match=re.escape(template1.format(keyword))):
ray.remote(**{keyword: np.random.uniform(0, 1)})(A)
# Value check for non-negative finite values
for keyword in ("num_returns", "max_calls"):
for v in (np.random.randint(-100, -2), -1):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="The keyword 'max_task_retries' only accepts -1, 0 or a" match=f"The keyword '{keyword}' only accepts None, "
" positive integer", f"0 or a positive integer",
): ):
ray.remote(**{keyword: v})(f)
@ray.remote(max_task_retries=opt) # Value check for non-negative and infinite values
class A2: template2 = (
"The keyword '{}' only accepts None, 0, -1 or a positive integer, "
"where -1 represents infinity."
)
with pytest.raises(ValueError, match=template2.format("max_retries")):
ray.remote(max_retries=np.random.randint(-100, -2))(f)
for keyword in ("max_restarts", "max_task_retries"):
with pytest.raises(ValueError, match=template2.format(keyword)):
ray.remote(**{keyword: np.random.randint(-100, -2)})(A)
def test_options():
"""General test of option keywords in Ray."""
import re
from ray._private import ray_option_utils
def f():
return 1
class A:
x = 1 x = 1
task_defaults = {
k: v.default_value for k, v in ray_option_utils.task_options.items()
}
task_defaults_for_options = task_defaults.copy()
task_defaults_for_options.pop("max_calls")
ray.remote(f).options(**task_defaults_for_options)
ray.remote(**task_defaults)(f).options(**task_defaults_for_options)
with pytest.raises(
ValueError,
match=re.escape("Setting 'max_calls' is not supported in '.options()'."),
):
ray.remote(f).options(max_calls=1)
actor_defaults = {
k: v.default_value for k, v in ray_option_utils.actor_options.items()
}
actor_defaults_for_options = actor_defaults.copy()
actor_defaults_for_options.pop("concurrency_groups")
ray.remote(A).options(**actor_defaults_for_options)
ray.remote(**actor_defaults)(A).options(**actor_defaults_for_options)
with pytest.raises(
ValueError,
match=re.escape(
"Setting 'concurrency_groups' is not supported in '.options()'."
),
):
ray.remote(A).options(concurrency_groups=[])
unique_object = type("###", (), {})()
for k, v in ray_option_utils.task_options.items():
v.validate(k, v.default_value)
with pytest.raises(TypeError):
v.validate(k, unique_object)
for k, v in ray_option_utils.actor_options.items():
v.validate(k, v.default_value)
with pytest.raises(TypeError):
v.validate(k, unique_object)
# https://github.com/ray-project/ray/issues/17842 # https://github.com/ray-project/ray/issues/17842
def test_disable_cuda_devices(): def test_disable_cuda_devices():

View file

@ -479,17 +479,12 @@ def test_serializing_exceptions(ray_start_regular_shared):
def test_invalid_task(ray_start_regular_shared): def test_invalid_task(ray_start_regular_shared):
with ray_start_client_server() as ray: with ray_start_client_server() as ray:
with pytest.raises(TypeError):
@ray.remote(runtime_env="invalid value") @ray.remote(runtime_env="invalid value")
def f(): def f():
return 1 return 1
# No exception on making the remote call.
ref = f.remote()
# Exception during scheduling will be raised on ray.get()
with pytest.raises(Exception):
ray.get(ref)
def test_create_remote_before_start(ray_start_regular_shared): def test_create_remote_before_start(ray_start_regular_shared):
"""Creates remote objects (as though in a library) before """Creates remote objects (as though in a library) before

View file

@ -6,6 +6,7 @@ import json
import logging import logging
from ray.util.client.runtime_context import ClientWorkerPropertyAPI from ray.util.client.runtime_context import ClientWorkerPropertyAPI
from ray._private import ray_option_utils
from typing import Any, Callable, List, Optional, TYPE_CHECKING from typing import Any, Callable, List, Optional, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
@ -77,17 +78,9 @@ class ClientAPI:
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote. # This is the case where the decorator is just @ray.remote.
return remote_decorator(options=None)(args[0]) return remote_decorator(options=None)(args[0])
error_string = ( assert (
"The @ray.remote decorator must be applied either " len(args) == 0 and len(kwargs) > 0
"with no arguments and no parentheses, for example " ), ray_option_utils.remote_args_error_string
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
"'@ray.remote(num_returns=2, "
'resources={"CustomResource": 1})\'.'
)
assert len(args) == 0 and len(kwargs) > 0, error_string
return remote_decorator(options=kwargs) return remote_decorator(options=kwargs)
# TODO(mwtian): consider adding _internal_ prefix to call_remote / # TODO(mwtian): consider adding _internal_ prefix to call_remote /

View file

@ -2,54 +2,9 @@ from typing import Any
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from ray._private import ray_option_utils
from ray.util.placement_group import PlacementGroup, check_placement_group_index from ray.util.placement_group import PlacementGroup, check_placement_group_index
options = {
"num_returns": (
int,
lambda x: x >= 0,
"The keyword 'num_returns' only accepts 0 or a positive integer",
),
"num_cpus": (),
"num_gpus": (),
"resources": (),
"accelerator_type": (),
"max_calls": (
int,
lambda x: x >= 0,
"The keyword 'max_calls' only accepts 0 or a positive integer",
),
"max_restarts": (
int,
lambda x: x >= -1,
"The keyword 'max_restarts' only accepts -1, 0 or a positive integer",
),
"max_task_retries": (
int,
lambda x: x >= -1,
"The keyword 'max_task_retries' only accepts -1, 0 or a positive integer",
),
"max_retries": (
int,
lambda x: x >= -1,
"The keyword 'max_retries' only accepts 0, -1 or a positive integer",
),
"retry_exceptions": (),
"max_concurrency": (),
"name": (),
"namespace": (),
"lifetime": (),
"memory": (),
"object_store_memory": (),
"placement_group": (),
"placement_group_bundle_index": (),
"placement_group_capture_child_tasks": (),
"runtime_env": (),
"max_pending_calls": (),
"concurrency_groups": (),
"scheduling_strategy": (),
}
def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if kwargs_dict is None: if kwargs_dict is None:
@ -59,15 +14,12 @@ def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str
out = {} out = {}
for k, v in kwargs_dict.items(): for k, v in kwargs_dict.items():
if k not in options.keys(): if k not in ray_option_utils.valid_options:
raise TypeError(f"Invalid option passed to remote(): {k}") raise ValueError(
validator = options[k] f"Invalid option keyword: '{k}'. "
if len(validator) != 0: f"{ray_option_utils.remote_args_error_string}"
if v is not None: )
if not isinstance(v, validator[0]): ray_option_utils.valid_options[k].validate(k, v)
raise ValueError(validator[2])
if not validator[1](v):
raise ValueError(validator[2])
out[k] = v out[k] = v
# Validate placement setting similar to the logic in ray/actor.py and # Validate placement setting similar to the logic in ray/actor.py and

View file

@ -810,44 +810,15 @@ class Worker:
def _convert_actor(self, actor: "ActorClass") -> str: def _convert_actor(self, actor: "ActorClass") -> str:
"""Register a ClientActorClass for the ActorClass and return a UUID""" """Register a ClientActorClass for the ActorClass and return a UUID"""
key = uuid.uuid4().hex key = uuid.uuid4().hex
md = actor.__ray_metadata__ cls = actor.__ray_metadata__.modified_class
cls = md.modified_class self._converted[key] = ClientActorClass(cls, options=actor._default_options)
self._converted[key] = ClientActorClass(
cls,
options={
"max_restarts": md.max_restarts,
"max_task_retries": md.max_task_retries,
"num_cpus": md.num_cpus,
"num_gpus": md.num_gpus,
"memory": md.memory,
"object_store_memory": md.object_store_memory,
"resources": md.resources,
"accelerator_type": md.accelerator_type,
"runtime_env": md.runtime_env,
"concurrency_groups": md.concurrency_groups,
"scheduling_strategy": md.scheduling_strategy,
},
)
return key return key
def _convert_function(self, func: "RemoteFunction") -> str: def _convert_function(self, func: "RemoteFunction") -> str:
"""Register a ClientRemoteFunc for the ActorClass and return a UUID""" """Register a ClientRemoteFunc for the ActorClass and return a UUID"""
key = uuid.uuid4().hex key = uuid.uuid4().hex
f = func._function
self._converted[key] = ClientRemoteFunc( self._converted[key] = ClientRemoteFunc(
f, func._function, options=func._default_options
options={
"num_cpus": func._num_cpus,
"num_gpus": func._num_gpus,
"max_calls": func._max_calls,
"max_retries": func._max_retries,
"resources": func._resources,
"accelerator_type": func._accelerator_type,
"num_returns": func._num_returns,
"memory": func._memory,
"runtime_env": func._runtime_env,
"scheduling_strategy": func._scheduling_strategy,
},
) )
return key return key

View file

@ -1,6 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import atexit import atexit
import faulthandler import faulthandler
import functools
import hashlib import hashlib
import inspect import inspect
import io import io
@ -29,7 +30,6 @@ import ray.remote_function
import ray.serialization as serialization import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray._private.services as services import ray._private.services as services
from ray.util.scheduling_strategies import SchedulingStrategyT
from ray._private.gcs_pubsub import ( from ray._private.gcs_pubsub import (
GcsPublisher, GcsPublisher,
GcsErrorSubscriber, GcsErrorSubscriber,
@ -43,6 +43,7 @@ import ray._private.import_thread as import_thread
from ray.util.tracing.tracing_helper import import_from_string from ray.util.tracing.tracing_helper import import_from_string
from ray.util.annotations import PublicAPI, DeveloperAPI, Deprecated from ray.util.annotations import PublicAPI, DeveloperAPI, Deprecated
from ray.util.debug import log_once from ray.util.debug import log_once
from ray._private import ray_option_utils
import ray import ray
import colorama import colorama
import setproctitle import setproctitle
@ -2110,122 +2111,25 @@ def _mode(worker=global_worker):
return worker.mode return worker.mode
def make_decorator( def _make_remote(function_or_class, options):
num_returns=None, # filter out placeholders in options
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_calls=None,
max_retries=None,
max_restarts=None,
max_task_retries=None,
runtime_env=None,
placement_group="default",
worker=None,
retry_exceptions=None,
concurrency_groups=None,
scheduling_strategy: SchedulingStrategyT = None,
):
def decorator(function_or_class):
if inspect.isfunction(function_or_class) or is_cython(function_or_class): if inspect.isfunction(function_or_class) or is_cython(function_or_class):
# Set the remote function default resources. ray_option_utils.validate_task_options(options, in_options=False)
if max_restarts is not None:
raise ValueError(
"The keyword 'max_restarts' is not allowed for remote functions."
)
if max_task_retries is not None:
raise ValueError(
"The keyword 'max_task_retries' is not "
"allowed for remote functions."
)
if num_returns is not None and (
not isinstance(num_returns, int) or num_returns < 0
):
raise ValueError(
"The keyword 'num_returns' only accepts 0 or a positive integer"
)
if max_retries is not None and (
not isinstance(max_retries, int) or max_retries < -1
):
raise ValueError(
"The keyword 'max_retries' only accepts 0, -1 or a"
" positive integer"
)
if max_calls is not None and (
not isinstance(max_calls, int) or max_calls < 0
):
raise ValueError(
"The keyword 'max_calls' only accepts 0 or a positive integer"
)
return ray.remote_function.RemoteFunction( return ray.remote_function.RemoteFunction(
Language.PYTHON, Language.PYTHON,
function_or_class, function_or_class,
None, None,
num_cpus, options,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_returns,
max_calls,
max_retries,
retry_exceptions,
runtime_env,
placement_group,
scheduling_strategy,
) )
if inspect.isclass(function_or_class): if inspect.isclass(function_or_class):
if num_returns is not None: ray_option_utils.validate_actor_options(options, in_options=False)
raise TypeError("The keyword 'num_returns' is not allowed for actors.") return ray.actor.make_actor(function_or_class, options)
if max_retries is not None:
raise TypeError("The keyword 'max_retries' is not allowed for actors.")
if retry_exceptions is not None:
raise TypeError(
"The keyword 'retry_exceptions' is not allowed for actors."
)
if max_calls is not None:
raise TypeError("The keyword 'max_calls' is not allowed for actors.")
if max_restarts is not None and (
not isinstance(max_restarts, int) or max_restarts < -1
):
raise ValueError(
"The keyword 'max_restarts' only accepts -1, 0 or a"
" positive integer"
)
if max_task_retries is not None and (
not isinstance(max_task_retries, int) or max_task_retries < -1
):
raise ValueError(
"The keyword 'max_task_retries' only accepts -1, 0 or a"
" positive integer"
)
return ray.actor.make_actor(
function_or_class,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
max_restarts,
max_task_retries,
runtime_env,
concurrency_groups,
scheduling_strategy,
)
raise TypeError( raise TypeError(
"The @ray.remote decorator must be applied to " "The @ray.remote decorator must be applied to either a function or a class."
"either a function or to a class."
) )
return decorator
@PublicAPI @PublicAPI
def remote(*args, **kwargs): def remote(*args, **kwargs):
@ -2294,6 +2198,8 @@ def remote(*args, **kwargs):
accelerator_type: If specified, requires that the task or actor run accelerator_type: If specified, requires that the task or actor run
on a node with the specified type of accelerator. on a node with the specified type of accelerator.
See `ray.accelerators` for accelerator types. See `ray.accelerators` for accelerator types.
memory (float): The heap memory request for this task/actor.
object_store_memory (int): The object store memory request for this task/actor.
max_calls (int): Only for *remote functions*. This specifies the max_calls (int): Only for *remote functions*. This specifies the
maximum number of times that a given worker can execute maximum number of times that a given worker can execute
the given remote function before it must exit the given remote function before it must exit
@ -2341,87 +2247,10 @@ def remote(*args, **kwargs):
`PlacementGroupSchedulingStrategy`: `PlacementGroupSchedulingStrategy`:
placement group based scheduling. placement group based scheduling.
""" """
worker = global_worker # "callable" returns true for both function and class.
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote. # This is the case where the decorator is just @ray.remote.
return make_decorator(worker=worker)(args[0]) # "args[0]" is the class or function under the decorator.
return _make_remote(args[0], {})
# Parse the keyword arguments from the decorator. assert len(args) == 0 and len(kwargs) > 0, ray_option_utils.remote_args_error_string
valid_kwargs = [ return functools.partial(_make_remote, options=kwargs)
"num_returns",
"num_cpus",
"num_gpus",
"memory",
"object_store_memory",
"resources",
"accelerator_type",
"max_calls",
"max_restarts",
"max_task_retries",
"max_retries",
"runtime_env",
"retry_exceptions",
"placement_group",
"concurrency_groups",
"scheduling_strategy",
]
error_string = (
"The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
f"the arguments in the list {valid_kwargs}, for example "
"'@ray.remote(num_returns=2, "
'resources={"CustomResource": 1})\'.'
)
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in valid_kwargs, error_string
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
resources = kwargs.get("resources")
if not isinstance(resources, dict) and resources is not None:
raise TypeError(
"The 'resources' keyword argument must be a "
f"dictionary, but received type {type(resources)}."
)
if resources is not None:
assert "CPU" not in resources, "Use the 'num_cpus' argument."
assert "GPU" not in resources, "Use the 'num_gpus' argument."
accelerator_type = kwargs.get("accelerator_type")
# Handle other arguments.
num_returns = kwargs.get("num_returns")
max_calls = kwargs.get("max_calls")
max_restarts = kwargs.get("max_restarts")
max_task_retries = kwargs.get("max_task_retries")
memory = kwargs.get("memory")
object_store_memory = kwargs.get("object_store_memory")
max_retries = kwargs.get("max_retries")
runtime_env = kwargs.get("runtime_env")
placement_group = kwargs.get("placement_group", "default")
retry_exceptions = kwargs.get("retry_exceptions")
concurrency_groups = kwargs.get("concurrency_groups")
scheduling_strategy = kwargs.get("scheduling_strategy")
return make_decorator(
num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
resources=resources,
accelerator_type=accelerator_type,
max_calls=max_calls,
max_restarts=max_restarts,
max_task_retries=max_task_retries,
max_retries=max_retries,
runtime_env=runtime_env,
placement_group=placement_group,
worker=worker,
retry_exceptions=retry_exceptions,
concurrency_groups=concurrency_groups or [],
scheduling_strategy=scheduling_strategy,
)