[core] Simplify options handling [Part 1] (#23127)

* handle options

* update doc

* fix serve
This commit is contained in:
Siyuan (Ryans) Zhuang 2022-04-11 20:49:58 -07:00 committed by GitHub
parent 77b0015ea0
commit d7ef546352
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 528 additions and 806 deletions

View file

@ -0,0 +1,212 @@
"""Manage, parse and validate options for Ray tasks, actors and actor methods."""
from typing import Dict, Any, Callable, Tuple, Union, Optional
from dataclasses import dataclass
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@dataclass
class Option:
# Type constraint of an option.
type_constraint: Optional[Union[type, Tuple[type]]] = None
# Value constraint of an option.
value_constraint: Optional[Callable[[Any], bool]] = None
# Error message for value constraint.
error_message_for_value_constraint: Optional[str] = None
# Default value.
default_value: Any = None
def validate(self, keyword: str, value: Any):
"""Validate the option."""
if self.type_constraint is not None:
if not isinstance(value, self.type_constraint):
raise TypeError(
f"The type of keyword '{keyword}' must be {self.type_constraint}, "
f"but received type {type(value)}"
)
if self.value_constraint is not None:
if not self.value_constraint(value):
raise ValueError(self.error_message_for_value_constraint)
def _counting_option(name: str, infinite: bool = True, default_value: Any = None):
"""This is used for positive and discrete options.
Args:
name: The name of the option keyword.
infinite: If True, user could use -1 to represent infinity.
default_value: The default value for this option.
"""
if infinite:
return Option(
(int, type(None)),
lambda x: x is None or x >= -1,
f"The keyword '{name}' only accepts None, 0, -1 or a positive integer, "
"where -1 represents infinity.",
default_value=default_value,
)
return Option(
(int, type(None)),
lambda x: x is None or x >= 0,
f"The keyword '{name}' only accepts None, 0 or a positive integer.",
default_value=default_value,
)
def _resource_option(name: str, default_value: Any = None):
"""This is used for non-negative options, typically for defining resources."""
return Option(
(float, int, type(None)),
lambda x: x is None or x >= 0,
f"The keyword '{name}' only accepts None, 0 or a positive number",
default_value=default_value,
)
_common_options = {
"accelerator_type": Option((str, type(None))),
"memory": _resource_option("memory"),
"name": Option((str, type(None))),
"num_cpus": _resource_option("num_cpus"),
"num_gpus": _resource_option("num_gpus"),
"object_store_memory": _counting_option("object_store_memory", False),
# TODO(suquark): "placement_group", "placement_group_bundle_index"
# and "placement_group_capture_child_tasks" are deprecated,
# use "scheduling_strategy" instead.
"placement_group": Option(
(type(None), str, PlacementGroup), default_value="default"
),
"placement_group_bundle_index": Option(int, default_value=-1),
"placement_group_capture_child_tasks": Option((bool, type(None))),
"resources": Option(
(dict, type(None)),
lambda x: x is None or ("CPU" not in x and "GPU" not in x),
"Use the 'num_cpus' and 'num_gpus' keyword instead of 'CPU' and 'GPU' "
"in 'resources' keyword",
),
"runtime_env": Option((dict, type(None))),
"scheduling_strategy": Option((type(None), str, PlacementGroupSchedulingStrategy)),
}
_task_only_options = {
"max_calls": _counting_option("max_calls", False, default_value=0),
# Normal tasks may be retried on failure this many times.
# TODO(swang): Allow this to be set globally for an application.
"max_retries": _counting_option("max_retries", default_value=3),
# override "_common_options"
"num_cpus": _resource_option("num_cpus", default_value=1),
"num_returns": _counting_option("num_returns", False, default_value=1),
"object_store_memory": Option( # override "_common_options"
(int, type(None)),
lambda x: x is None,
"Setting 'object_store_memory' is not implemented for tasks",
),
"retry_exceptions": Option(bool, default_value=False),
}
_actor_only_options = {
"concurrency_groups": Option((list, dict, type(None))),
"lifetime": Option(
(str, type(None)),
lambda x: x in (None, "detached", "non_detached"),
"actor `lifetime` argument must be one of 'detached', "
"'non_detached' and 'None'.",
),
"max_concurrency": _counting_option("max_concurrency", False),
"max_restarts": _counting_option("max_restarts", default_value=0),
"max_task_retries": _counting_option("max_task_retries", default_value=0),
"max_pending_calls": _counting_option("max_pending_calls", default_value=-1),
"namespace": Option((str, type(None))),
"get_if_exists": Option(bool, default_value=False),
}
# Priority is important here because during dictionary update, same key with higher
# priority overrides the same key with lower priority. We make use of priority
# to set the correct default value for tasks / actors.
# priority: _common_options > _actor_only_options > _task_only_options
valid_options: Dict[str, Option] = {
**_task_only_options,
**_actor_only_options,
**_common_options,
}
# priority: _task_only_options > _common_options
task_options: Dict[str, Option] = {**_common_options, **_task_only_options}
# priority: _actor_only_options > _common_options
actor_options: Dict[str, Option] = {**_common_options, **_actor_only_options}
remote_args_error_string = (
"The @ray.remote decorator must be applied either with no arguments and no "
"parentheses, for example '@ray.remote', or it must be applied using some of "
f"the arguments in the list {list(valid_options.keys())}, for example "
"'@ray.remote(num_returns=2, resources={\"CustomResource\": 1})'."
)
def _check_deprecate_placement_group(options: Dict[str, Any]):
"""Check if deprecated placement group option exists."""
placement_group = options.get("placement_group", "default")
scheduling_strategy = options.get("scheduling_strategy")
# TODO(suquark): @ray.remote(placement_group=None) is used in
# "python/ray/data/impl/remote_fn.py" and many other places,
# while "ray.data.read_api.read_datasource" set "scheduling_strategy=SPREAD".
# This might be a bug, but it is also ok to allow them co-exist.
if (placement_group not in ("default", None)) and (scheduling_strategy is not None):
raise ValueError(
"Placement groups should be specified via the "
"scheduling_strategy option. "
"The placement_group option is deprecated."
)
def validate_task_options(options: Dict[str, Any], in_options: bool):
"""Options check for Ray tasks.
Args:
options: Options for Ray tasks.
in_options: If True, we are checking the options under the context of
".options()".
"""
for k, v in options.items():
if k not in task_options:
raise ValueError(
f"Invalid option keyword {k} for remote functions. "
f"Valid ones are {list(task_options)}."
)
task_options[k].validate(k, v)
if in_options and "max_calls" in options:
raise ValueError("Setting 'max_calls' is not supported in '.options()'.")
_check_deprecate_placement_group(options)
def validate_actor_options(options: Dict[str, Any], in_options: bool):
"""Options check for Ray actors.
Args:
options: Options for Ray actors.
in_options: If True, we are checking the options under the context of
".options()".
"""
for k, v in options.items():
if k not in actor_options:
raise ValueError(
f"Invalid option keyword {k} for actors. "
f"Valid ones are {list(actor_options)}."
)
actor_options[k].validate(k, v)
if in_options and "concurrency_groups" in options:
raise ValueError(
"Setting 'concurrency_groups' is not supported in '.options()'."
)
if options.get("max_restarts", 0) == 0 and options.get("max_task_retries", 0) != 0:
raise ValueError(
"'max_task_retries' cannot be set if 'max_restarts' "
"is 0 or if 'max_restarts' is not set."
)
if options.get("get_if_exists") and not options.get("name"):
raise ValueError("The actor name must be specified to use `get_if_exists`.")
_check_deprecate_placement_group(options)

View file

@ -31,6 +31,7 @@ from ray.util.tracing.tracing_helper import (
_tracing_actor_method_invocation,
_inject_tracing_into_class,
)
from ray._private import ray_option_utils
logger = logging.getLogger(__name__)
@ -356,6 +357,16 @@ class ActorClassInheritanceException(TypeError):
pass
def _process_option_dict(actor_options):
_filled_options = {}
arg_names = set(inspect.getfullargspec(ActorClassMetadata.__init__)[0])
for k, v in ray_option_utils.actor_options.items():
if k in arg_names:
_filled_options[k] = actor_options.get(k, v.default_value)
_filled_options["runtime_env"] = parse_runtime_env(_filled_options["runtime_env"])
return _filled_options
class ActorClass:
"""An actor class.
@ -419,17 +430,7 @@ class ActorClass:
cls,
modified_class,
class_id,
max_restarts,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
concurrency_groups,
scheduling_strategy: SchedulingStrategyT,
actor_options,
):
for attribute in [
"remote",
@ -473,25 +474,16 @@ class ActorClass:
modified_class.__ray_actor_class__
)
new_runtime_env = parse_runtime_env(runtime_env)
self.__ray_metadata__ = ActorClassMetadata(
Language.PYTHON,
modified_class,
actor_creation_function_descriptor,
class_id,
max_restarts,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
new_runtime_env,
concurrency_groups,
scheduling_strategy,
**_process_option_dict(actor_options),
)
self._default_options = actor_options
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env
return self
@ -500,38 +492,19 @@ class ActorClass:
cls,
language,
actor_creation_function_descriptor,
max_restarts,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
actor_options,
):
self = ActorClass.__new__(ActorClass)
new_runtime_env = parse_runtime_env(runtime_env)
self.__ray_metadata__ = ActorClassMetadata(
language,
None,
actor_creation_function_descriptor,
None,
max_restarts,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
new_runtime_env,
[],
None,
**_process_option_dict(actor_options),
)
self._default_options = actor_options
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self.__ray_metadata__.runtime_env
return self
def remote(self, *args, **kwargs):
@ -546,32 +519,9 @@ class ActorClass:
Returns:
A handle to the newly created actor.
"""
return self._remote(args=args, kwargs=kwargs)
return self._remote(args=args, kwargs=kwargs, **self._default_options)
def options(
self,
args=None,
kwargs=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_concurrency=None,
max_restarts=None,
max_task_retries=None,
name=None,
namespace=None,
get_if_exists=False,
lifetime=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
max_pending_calls=-1,
scheduling_strategy: SchedulingStrategyT = None,
):
def options(self, args=None, kwargs=None, **actor_options):
"""Configures and overrides the actor instantiation parameters.
The arguments are the same as those that can be passed
@ -592,45 +542,28 @@ class ActorClass:
actor_cls = self
new_runtime_env = parse_runtime_env(runtime_env)
# override original options
default_options = self._default_options.copy()
# "concurrency_groups" could not be used in ".options()",
# we should remove it before merging options from '@ray.remote'.
default_options.pop("concurrency_groups", None)
updated_options = {**default_options, **actor_options}
ray_option_utils.validate_actor_options(updated_options, in_options=True)
cls_options = dict(
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
resources=resources,
accelerator_type=accelerator_type,
max_concurrency=max_concurrency,
max_restarts=max_restarts,
max_task_retries=max_task_retries,
name=name,
namespace=namespace,
lifetime=lifetime,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(placement_group_capture_child_tasks),
runtime_env=new_runtime_env,
max_pending_calls=max_pending_calls,
scheduling_strategy=scheduling_strategy,
)
# only update runtime_env when ".options()" specifies new runtime_env
if "runtime_env" in actor_options:
updated_options["runtime_env"] = parse_runtime_env(
updated_options["runtime_env"]
)
class ActorOptionWrapper:
def remote(self, *args, **kwargs):
# Handle the get-or-create case.
if get_if_exists:
if not cls_options.get("name"):
raise ValueError(
"The actor name must be specified to use `get_if_exists`."
)
if updated_options.get("get_if_exists"):
return self._get_or_create_impl(args, kwargs)
# Normal create case.
return actor_cls._remote(
args=args,
kwargs=kwargs,
**cls_options,
)
return actor_cls._remote(args=args, kwargs=kwargs, **updated_options)
def bind(self, *args, **kwargs):
"""
@ -645,52 +578,34 @@ class ActorClass:
actor_cls.__ray_metadata__.modified_class,
args,
kwargs,
cls_options,
updated_options,
)
def _get_or_create_impl(self, args, kwargs):
name = cls_options["name"]
name = updated_options["name"]
try:
return ray.get_actor(name, namespace=cls_options.get("namespace"))
return ray.get_actor(
name, namespace=updated_options.get("namespace")
)
except ValueError:
# Attempt to create it (may race with other attempts).
try:
return actor_cls._remote(
args=args,
kwargs=kwargs,
**cls_options,
**updated_options,
)
except ValueError:
# We lost the creation race, ignore.
pass
return ray.get_actor(name, namespace=cls_options.get("namespace"))
return ray.get_actor(
name, namespace=updated_options.get("namespace")
)
return ActorOptionWrapper()
@_tracing_actor_creation
def _remote(
self,
args=None,
kwargs=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_concurrency=None,
max_restarts=None,
max_task_retries=None,
name=None,
namespace=None,
lifetime=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
max_pending_calls=-1,
scheduling_strategy: SchedulingStrategyT = None,
):
def _remote(self, args=None, kwargs=None, **actor_options):
"""Create an actor.
This method allows more flexibility than the remote method because
@ -752,6 +667,10 @@ class ActorClass:
Returns:
A handle to the newly created actor.
"""
# We pop the "concurrency_groups" coming from "@ray.remote" here. We no longer
# need it in "_remote()".
actor_options.pop("concurrency_groups", None)
if args is None:
args = []
if kwargs is None:
@ -767,41 +686,40 @@ class ActorClass:
)
is_asyncio = actor_has_async_methods
if max_concurrency is None:
if is_asyncio:
max_concurrency = 1000
else:
max_concurrency = 1
if max_concurrency < 1:
raise ValueError("max_concurrency must be >= 1")
if actor_options.get("max_concurrency") is None:
actor_options["max_concurrency"] = 1000 if is_asyncio else 1
if client_mode_should_convert(auto_init=True):
return client_mode_convert_actor(
self,
args,
kwargs,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
resources=resources,
accelerator_type=accelerator_type,
max_concurrency=max_concurrency,
max_restarts=max_restarts,
max_task_retries=max_task_retries,
name=name,
namespace=namespace,
lifetime=lifetime,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(
placement_group_capture_child_tasks
),
runtime_env=runtime_env,
max_pending_calls=max_pending_calls,
scheduling_strategy=scheduling_strategy,
)
return client_mode_convert_actor(self, args, kwargs, **actor_options)
# fill actor required options
for k, v in ray_option_utils.actor_options.items():
actor_options[k] = actor_options.get(k, v.default_value)
# "concurrency_groups" already takes effects and should not apply again.
# Remove the default value here.
actor_options.pop("concurrency_groups", None)
# TODO(suquark): cleanup these fields
max_concurrency = actor_options["max_concurrency"]
name = actor_options["name"]
namespace = actor_options["namespace"]
lifetime = actor_options["lifetime"]
num_cpus = actor_options["num_cpus"]
num_gpus = actor_options["num_gpus"]
accelerator_type = actor_options["accelerator_type"]
resources = actor_options["resources"]
memory = actor_options["memory"]
object_store_memory = actor_options["object_store_memory"]
runtime_env = actor_options["runtime_env"]
placement_group = actor_options["placement_group"]
placement_group_bundle_index = actor_options["placement_group_bundle_index"]
placement_group_capture_child_tasks = actor_options[
"placement_group_capture_child_tasks"
]
scheduling_strategy = actor_options["scheduling_strategy"]
max_restarts = actor_options["max_restarts"]
max_task_retries = actor_options["max_task_retries"]
max_pending_calls = actor_options["max_pending_calls"]
worker = ray.worker.global_worker
worker.check_connected()
@ -849,11 +767,7 @@ class ActorClass:
# decorator. Last three conditions are to check that no resources were
# specified when _remote() was called.
if (
meta.num_cpus is None
and meta.num_gpus is None
and meta.resources is None
and meta.accelerator_type is None
and num_cpus is None
num_cpus is None
and num_gpus is None
and resources is None
and accelerator_type is None
@ -868,8 +782,8 @@ class ActorClass:
# resources are associated with methods.
cpus_to_use = (
ray_constants.DEFAULT_ACTOR_CREATION_CPU_SPECIFIED
if meta.num_cpus is None
else meta.num_cpus
if num_cpus is None
else num_cpus
)
actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED
@ -897,13 +811,14 @@ class ActorClass:
meta.method_meta.methods.keys(),
)
# TODO(suquark): cleanup "resources_from_resource_arguments" later.
resources = ray._private.utils.resources_from_resource_arguments(
cpus_to_use,
meta.num_gpus,
meta.memory,
meta.object_store_memory,
meta.resources,
meta.accelerator_type,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_cpus,
num_gpus,
memory,
@ -926,14 +841,6 @@ class ActorClass:
function_signature = meta.method_meta.signatures["__init__"]
creation_args = signature.flatten_args(function_signature, args, kwargs)
scheduling_strategy = scheduling_strategy or meta.scheduling_strategy
if (placement_group != "default") and (scheduling_strategy is not None):
raise ValueError(
"Placement groups should be specified via the "
"scheduling_strategy option. "
"The placement_group option is deprecated."
)
if scheduling_strategy is None or isinstance(
scheduling_strategy, PlacementGroupSchedulingStrategy
):
@ -971,19 +878,17 @@ class ActorClass:
else:
scheduling_strategy = "DEFAULT"
if runtime_env:
new_runtime_env = parse_runtime_env(runtime_env)
else:
new_runtime_env = meta.runtime_env
serialized_runtime_env_info = None
if new_runtime_env is not None:
if runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info(
new_runtime_env,
runtime_env,
is_job_runtime_env=False,
serialize=True,
)
concurrency_groups_dict = {}
if meta.concurrency_groups is None:
meta.concurrency_groups = []
for cg_name in meta.concurrency_groups:
concurrency_groups_dict[cg_name] = {
"name": cg_name,
@ -1020,8 +925,8 @@ class ActorClass:
meta.language,
meta.actor_creation_function_descriptor,
creation_args,
max_restarts or meta.max_restarts,
max_task_retries or meta.max_task_retries,
max_restarts,
max_task_retries,
resources,
actor_placement_resources,
max_concurrency,
@ -1375,59 +1280,22 @@ def modify_class(cls):
return Class
def make_actor(
cls,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
max_restarts,
max_task_retries,
runtime_env,
concurrency_groups,
scheduling_strategy: SchedulingStrategyT,
):
def make_actor(cls, actor_options):
Class = modify_class(cls)
_inject_tracing_into_class(Class)
if max_restarts is None:
max_restarts = 0
if max_task_retries is None:
max_task_retries = 0
if concurrency_groups is None:
concurrency_groups = []
infinite_restart = max_restarts == -1
if not infinite_restart:
if max_restarts < 0:
raise ValueError(
"max_restarts must be an integer >= -1 "
"-1 indicates infinite restarts"
)
else:
if "max_restarts" in actor_options:
if actor_options["max_restarts"] != -1: # -1 represents infinite restart
# Make sure we don't pass too big of an int to C++, causing
# an overflow.
max_restarts = min(max_restarts, ray_constants.MAX_INT64_VALUE)
if max_restarts == 0 and max_task_retries != 0:
raise ValueError("max_task_retries cannot be set if max_restarts is 0.")
actor_options["max_restarts"] = min(
actor_options["max_restarts"], ray_constants.MAX_INT64_VALUE
)
return ActorClass._ray_from_modified_class(
Class,
ActorClassID.from_random(),
max_restarts,
max_task_retries,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
runtime_env,
concurrency_groups,
scheduling_strategy,
actor_options,
)

View file

@ -82,20 +82,8 @@ def java_function(class_name, function_name):
Language.JAVA,
lambda *args, **kwargs: None,
JavaFunctionDescriptor(class_name, function_name, ""),
None, # num_cpus,
None, # num_gpus,
None, # memory,
None, # object_store_memory,
None, # resources,
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None, # max_retries,
None, # retry_exceptions,
None, # runtime_env,
None, # placement_group,
None,
) # scheduling_strategy,
{},
)
@PublicAPI(stability="beta")
@ -111,20 +99,8 @@ def cpp_function(function_name):
Language.CPP,
lambda *args, **kwargs: None,
CppFunctionDescriptor(function_name, "PYTHON"),
None, # num_cpus,
None, # num_gpus,
None, # memory,
None, # object_store_memory,
None, # resources,
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None, # max_retries,
None, # retry_exceptions,
None, # runtime_env,
None, # placement_group,
None,
) # scheduling_strategy,
{},
)
@PublicAPI(stability="beta")
@ -139,15 +115,7 @@ def java_actor_class(class_name):
return ActorClass._ray_from_function_descriptor(
Language.JAVA,
JavaFunctionDescriptor(class_name, "<init>", ""),
max_restarts=0,
max_task_retries=0,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
runtime_env=None,
{},
)
@ -165,13 +133,5 @@ def cpp_actor_class(create_function_name, class_name):
return ActorClass._ray_from_function_descriptor(
Language.CPP,
CppFunctionDescriptor(create_function_name, "PYTHON", class_name),
max_restarts=0,
max_task_retries=0,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
runtime_env=None,
{},
)

View file

@ -4,10 +4,7 @@ import logging
import uuid
from ray import cloudpickle as pickle
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy,
SchedulingStrategyT,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray._raylet import PythonFunctionDescriptor
from ray import cross_language, Language
from ray._private.client_mode_hook import client_mode_convert_function
@ -19,16 +16,7 @@ from ray.util.tracing.tracing_helper import (
_tracing_task_invocation,
_inject_tracing_into_function,
)
# Default parameters for remote functions.
DEFAULT_REMOTE_FUNCTION_CPUS = 1
DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS = 1
DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
# Normal tasks may be retried on failure this many times.
# TODO(swang): Allow this to be set globally for an application.
DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3
DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False
from ray._private import ray_option_utils
logger = logging.getLogger(__name__)
@ -85,19 +73,7 @@ class RemoteFunction:
language,
function,
function_descriptor,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_returns,
max_calls,
max_retries,
retry_exceptions,
runtime_env,
placement_group,
scheduling_strategy: SchedulingStrategyT,
task_options,
):
if inspect.iscoroutinefunction(function):
raise ValueError(
@ -105,56 +81,34 @@ class RemoteFunction:
"async function with `asyncio.get_event_loop.run_until(f())`. "
"See more at https://docs.ray.io/en/latest/ray-core/async_api.html#asyncio-for-remote-tasks" # noqa
)
self._default_options = task_options
# TODO(suquark): This is a workaround for class attributes of options.
# They are being used in some other places, mostly tests. Need cleanup later.
# E.g., actors uses "__ray_metadata__" to collect options, we can so something
# similar for remote functions.
for k, v in ray_option_utils.task_options.items():
setattr(self, "_" + k, task_options.get(k, v.default_value))
self._runtime_env = parse_runtime_env(self._runtime_env)
if "runtime_env" in self._default_options:
self._default_options["runtime_env"] = self._runtime_env
self._language = language
self._function = _inject_tracing_into_function(function)
self._function_name = function.__module__ + "." + function.__name__
self._function_descriptor = function_descriptor
self._is_cross_language = language != Language.PYTHON
self._num_cpus = DEFAULT_REMOTE_FUNCTION_CPUS if num_cpus is None else num_cpus
self._num_gpus = num_gpus
self._memory = memory
if object_store_memory is not None:
raise NotImplementedError(
"setting object_store_memory is not implemented for tasks"
)
self._object_store_memory = None
self._resources = resources
self._accelerator_type = accelerator_type
self._num_returns = (
DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS
if num_returns is None
else num_returns
)
self._max_calls = (
DEFAULT_REMOTE_FUNCTION_MAX_CALLS if max_calls is None else max_calls
)
self._max_retries = (
DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
if max_retries is None
else max_retries
)
self._retry_exceptions = (
DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
if retry_exceptions is None
else retry_exceptions
)
self._runtime_env = parse_runtime_env(runtime_env)
self._placement_group = placement_group
self._decorator = getattr(function, "__ray_invocation_decorator__", None)
self._function_signature = ray._private.signature.extract_signature(
self._function
)
self._scheduling_strategy = scheduling_strategy
self._last_export_session_and_job = None
self._uuid = uuid.uuid4()
# Override task.remote's signature and docstring
@wraps(function)
def _remote_proxy(*args, **kwargs):
return self._remote(args=args, kwargs=kwargs)
return self._remote(args=args, kwargs=kwargs, **self._default_options)
self.remote = _remote_proxy
@ -169,21 +123,7 @@ class RemoteFunction:
self,
args=None,
kwargs=None,
num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
name="",
scheduling_strategy: SchedulingStrategyT = None,
**task_options,
):
"""Configures and overrides the task invocation parameters.
@ -202,29 +142,24 @@ class RemoteFunction:
"""
func_cls = self
new_runtime_env = parse_runtime_env(runtime_env)
options = dict(
num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
accelerator_type=accelerator_type,
resources=resources,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(placement_group_capture_child_tasks),
runtime_env=new_runtime_env,
name=name,
scheduling_strategy=scheduling_strategy,
)
# override original options
default_options = self._default_options.copy()
# max_calls could not be used in ".options()", we should remove it before
# merging options from '@ray.remote'.
default_options.pop("max_calls", None)
updated_options = {**default_options, **task_options}
ray_option_utils.validate_task_options(updated_options, in_options=True)
# only update runtime_env when ".options()" specifies new runtime_env
if "runtime_env" in task_options:
updated_options["runtime_env"] = parse_runtime_env(
updated_options["runtime_env"]
)
class FuncWrapper:
def remote(self, *args, **kwargs):
return func_cls._remote(args=args, kwargs=kwargs, **options)
return func_cls._remote(args=args, kwargs=kwargs, **updated_options)
def bind(self, *args, **kwargs):
"""
@ -234,61 +169,18 @@ class RemoteFunction:
"""
from ray.experimental.dag.function_node import FunctionNode
return FunctionNode(
func_cls._function,
args,
kwargs,
options,
)
return FunctionNode(func_cls._function, args, kwargs, updated_options)
return FuncWrapper()
@_tracing_task_invocation
def _remote(
self,
args=None,
kwargs=None,
num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
runtime_env=None,
name="",
scheduling_strategy: SchedulingStrategyT = None,
):
def _remote(self, args=None, kwargs=None, **task_options):
"""Submit the remote function for execution."""
# We pop the "max_calls" coming from "@ray.remote" here. We no longer need
# it in "_remote()".
task_options.pop("max_calls", None)
if client_mode_should_convert(auto_init=True):
return client_mode_convert_function(
self,
args,
kwargs,
num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
accelerator_type=accelerator_type,
resources=resources,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(
placement_group_capture_child_tasks
),
runtime_env=runtime_env,
name=name,
scheduling_strategy=scheduling_strategy,
)
return client_mode_convert_function(self, args, kwargs, **task_options)
worker = ray.worker.global_worker
worker.check_connected()
@ -328,22 +220,40 @@ class RemoteFunction:
kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args
if num_returns is None:
num_returns = self._num_returns
if max_retries is None:
max_retries = self._max_retries
if retry_exceptions is None:
retry_exceptions = self._retry_exceptions
if scheduling_strategy is None:
scheduling_strategy = self._scheduling_strategy
# fill task required options
for k, v in ray_option_utils.task_options.items():
task_options[k] = task_options.get(k, v.default_value)
# "max_calls" already takes effects and should not apply again.
# Remove the default value here.
task_options.pop("max_calls", None)
# TODO(suquark): cleanup these fields
name = task_options["name"]
num_cpus = task_options["num_cpus"]
num_gpus = task_options["num_gpus"]
accelerator_type = task_options["accelerator_type"]
resources = task_options["resources"]
memory = task_options["memory"]
object_store_memory = task_options["object_store_memory"]
runtime_env = parse_runtime_env(task_options["runtime_env"])
placement_group = task_options["placement_group"]
placement_group_bundle_index = task_options["placement_group_bundle_index"]
placement_group_capture_child_tasks = task_options[
"placement_group_capture_child_tasks"
]
scheduling_strategy = task_options["scheduling_strategy"]
num_returns = task_options["num_returns"]
max_retries = task_options["max_retries"]
retry_exceptions = task_options["retry_exceptions"]
# TODO(suquark): cleanup "resources_from_resource_arguments" later.
resources = ray._private.utils.resources_from_resource_arguments(
self._num_cpus,
self._num_gpus,
self._memory,
self._object_store_memory,
self._resources,
self._accelerator_type,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_cpus,
num_gpus,
memory,
@ -352,13 +262,6 @@ class RemoteFunction:
accelerator_type,
)
if (placement_group != "default") and (scheduling_strategy is not None):
raise ValueError(
"Placement groups should be specified via the "
"scheduling_strategy option. "
"The placement_group option is deprecated."
)
if scheduling_strategy is None or isinstance(
scheduling_strategy, PlacementGroupSchedulingStrategy
):
@ -375,8 +278,6 @@ class RemoteFunction:
placement_group_capture_child_tasks = (
worker.should_capture_child_tasks_in_placement_group
)
if placement_group == "default":
placement_group = self._placement_group
placement_group = configure_placement_group_based_on_context(
placement_group_capture_child_tasks,
placement_group_bundle_index,
@ -394,8 +295,6 @@ class RemoteFunction:
else:
scheduling_strategy = "DEFAULT"
if not runtime_env or runtime_env == "{}":
runtime_env = self._runtime_env
serialized_runtime_env_info = None
if runtime_env is not None:
serialized_runtime_env_info = get_runtime_env_info(
@ -422,7 +321,7 @@ class RemoteFunction:
self._language,
self._function_descriptor,
list_args,
name,
name if name is not None else "",
num_returns,
resources,
max_retries,

View file

@ -257,6 +257,7 @@ class ReplicaConfig:
f"Specifying {option} in ray_actor_options is not allowed."
)
# TODO(suquark): reuse options validation of remote function/actor.
# Ray defaults to zero CPUs for placement, we default to one here.
if self.ray_actor_options.get("num_cpus", None) is None:
self.ray_actor_options["num_cpus"] = 1
@ -286,14 +287,13 @@ class ReplicaConfig:
raise ValueError("memory in ray_actor_options must be > 0.")
self.resource_dict["memory"] = memory
if self.ray_actor_options.get("object_store_memory", None) is None:
self.ray_actor_options["object_store_memory"] = 0
object_store_memory = self.ray_actor_options["object_store_memory"]
if not isinstance(object_store_memory, (int, float)):
object_store_memory = self.ray_actor_options.get("object_store_memory")
if not isinstance(object_store_memory, (int, float, type(None))):
raise TypeError(
"object_store_memory in ray_actor_options must be an int or a float."
"object_store_memory in ray_actor_options must be an int, float "
"or None."
)
elif object_store_memory < 0:
elif object_store_memory is not None and object_store_memory < 0:
raise ValueError("object_store_memory in ray_actor_options must be >= 0.")
self.resource_dict["object_store_memory"] = object_store_memory

View file

@ -178,60 +178,103 @@ def test_submit_api(shutdown_only):
def test_invalid_arguments(shutdown_only):
import re
ray.init(num_cpus=2)
for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'num_returns' only accepts 0 or a positive integer",
):
def f():
return 1
@ray.remote(num_returns=opt)
def g1():
return 1
class A:
x = 1
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_retries' only accepts 0, -1 or a"
" positive integer",
):
template1 = (
"The type of keyword '{}' "
+ f"must be {(int, type(None))}, but received type {float}"
)
@ray.remote(max_retries=opt)
def g2():
return 1
# Type check
for keyword in ("num_returns", "max_retries", "max_calls"):
with pytest.raises(TypeError, match=re.escape(template1.format(keyword))):
ray.remote(**{keyword: np.random.uniform(0, 1)})(f)
for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_calls' only accepts 0 or a positive integer",
):
for keyword in ("max_restarts", "max_task_retries"):
with pytest.raises(TypeError, match=re.escape(template1.format(keyword))):
ray.remote(**{keyword: np.random.uniform(0, 1)})(A)
@ray.remote(max_calls=opt)
def g3():
return 1
# Value check for non-negative finite values
for keyword in ("num_returns", "max_calls"):
for v in (np.random.randint(-100, -2), -1):
with pytest.raises(
ValueError,
match=f"The keyword '{keyword}' only accepts None, "
f"0 or a positive integer",
):
ray.remote(**{keyword: v})(f)
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_restarts' only accepts -1, 0 or a"
" positive integer",
):
# Value check for non-negative and infinite values
template2 = (
"The keyword '{}' only accepts None, 0, -1 or a positive integer, "
"where -1 represents infinity."
)
@ray.remote(max_restarts=opt)
class A1:
x = 1
with pytest.raises(ValueError, match=template2.format("max_retries")):
ray.remote(max_retries=np.random.randint(-100, -2))(f)
for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
with pytest.raises(
ValueError,
match="The keyword 'max_task_retries' only accepts -1, 0 or a"
" positive integer",
):
for keyword in ("max_restarts", "max_task_retries"):
with pytest.raises(ValueError, match=template2.format(keyword)):
ray.remote(**{keyword: np.random.randint(-100, -2)})(A)
@ray.remote(max_task_retries=opt)
class A2:
x = 1
def test_options():
"""General test of option keywords in Ray."""
import re
from ray._private import ray_option_utils
def f():
return 1
class A:
x = 1
task_defaults = {
k: v.default_value for k, v in ray_option_utils.task_options.items()
}
task_defaults_for_options = task_defaults.copy()
task_defaults_for_options.pop("max_calls")
ray.remote(f).options(**task_defaults_for_options)
ray.remote(**task_defaults)(f).options(**task_defaults_for_options)
with pytest.raises(
ValueError,
match=re.escape("Setting 'max_calls' is not supported in '.options()'."),
):
ray.remote(f).options(max_calls=1)
actor_defaults = {
k: v.default_value for k, v in ray_option_utils.actor_options.items()
}
actor_defaults_for_options = actor_defaults.copy()
actor_defaults_for_options.pop("concurrency_groups")
ray.remote(A).options(**actor_defaults_for_options)
ray.remote(**actor_defaults)(A).options(**actor_defaults_for_options)
with pytest.raises(
ValueError,
match=re.escape(
"Setting 'concurrency_groups' is not supported in '.options()'."
),
):
ray.remote(A).options(concurrency_groups=[])
unique_object = type("###", (), {})()
for k, v in ray_option_utils.task_options.items():
v.validate(k, v.default_value)
with pytest.raises(TypeError):
v.validate(k, unique_object)
for k, v in ray_option_utils.actor_options.items():
v.validate(k, v.default_value)
with pytest.raises(TypeError):
v.validate(k, unique_object)
# https://github.com/ray-project/ray/issues/17842

View file

@ -479,16 +479,11 @@ def test_serializing_exceptions(ray_start_regular_shared):
def test_invalid_task(ray_start_regular_shared):
with ray_start_client_server() as ray:
@ray.remote(runtime_env="invalid value")
def f():
return 1
with pytest.raises(TypeError):
# No exception on making the remote call.
ref = f.remote()
# Exception during scheduling will be raised on ray.get()
with pytest.raises(Exception):
ray.get(ref)
@ray.remote(runtime_env="invalid value")
def f():
return 1
def test_create_remote_before_start(ray_start_regular_shared):

View file

@ -6,6 +6,7 @@ import json
import logging
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
from ray._private import ray_option_utils
from typing import Any, Callable, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
@ -77,17 +78,9 @@ class ClientAPI:
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return remote_decorator(options=None)(args[0])
error_string = (
"The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
"'@ray.remote(num_returns=2, "
'resources={"CustomResource": 1})\'.'
)
assert len(args) == 0 and len(kwargs) > 0, error_string
assert (
len(args) == 0 and len(kwargs) > 0
), ray_option_utils.remote_args_error_string
return remote_decorator(options=kwargs)
# TODO(mwtian): consider adding _internal_ prefix to call_remote /

View file

@ -2,54 +2,9 @@ from typing import Any
from typing import Dict
from typing import Optional
from ray._private import ray_option_utils
from ray.util.placement_group import PlacementGroup, check_placement_group_index
options = {
"num_returns": (
int,
lambda x: x >= 0,
"The keyword 'num_returns' only accepts 0 or a positive integer",
),
"num_cpus": (),
"num_gpus": (),
"resources": (),
"accelerator_type": (),
"max_calls": (
int,
lambda x: x >= 0,
"The keyword 'max_calls' only accepts 0 or a positive integer",
),
"max_restarts": (
int,
lambda x: x >= -1,
"The keyword 'max_restarts' only accepts -1, 0 or a positive integer",
),
"max_task_retries": (
int,
lambda x: x >= -1,
"The keyword 'max_task_retries' only accepts -1, 0 or a positive integer",
),
"max_retries": (
int,
lambda x: x >= -1,
"The keyword 'max_retries' only accepts 0, -1 or a positive integer",
),
"retry_exceptions": (),
"max_concurrency": (),
"name": (),
"namespace": (),
"lifetime": (),
"memory": (),
"object_store_memory": (),
"placement_group": (),
"placement_group_bundle_index": (),
"placement_group_capture_child_tasks": (),
"runtime_env": (),
"max_pending_calls": (),
"concurrency_groups": (),
"scheduling_strategy": (),
}
def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if kwargs_dict is None:
@ -59,15 +14,12 @@ def validate_options(kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str
out = {}
for k, v in kwargs_dict.items():
if k not in options.keys():
raise TypeError(f"Invalid option passed to remote(): {k}")
validator = options[k]
if len(validator) != 0:
if v is not None:
if not isinstance(v, validator[0]):
raise ValueError(validator[2])
if not validator[1](v):
raise ValueError(validator[2])
if k not in ray_option_utils.valid_options:
raise ValueError(
f"Invalid option keyword: '{k}'. "
f"{ray_option_utils.remote_args_error_string}"
)
ray_option_utils.valid_options[k].validate(k, v)
out[k] = v
# Validate placement setting similar to the logic in ray/actor.py and

View file

@ -810,44 +810,15 @@ class Worker:
def _convert_actor(self, actor: "ActorClass") -> str:
"""Register a ClientActorClass for the ActorClass and return a UUID"""
key = uuid.uuid4().hex
md = actor.__ray_metadata__
cls = md.modified_class
self._converted[key] = ClientActorClass(
cls,
options={
"max_restarts": md.max_restarts,
"max_task_retries": md.max_task_retries,
"num_cpus": md.num_cpus,
"num_gpus": md.num_gpus,
"memory": md.memory,
"object_store_memory": md.object_store_memory,
"resources": md.resources,
"accelerator_type": md.accelerator_type,
"runtime_env": md.runtime_env,
"concurrency_groups": md.concurrency_groups,
"scheduling_strategy": md.scheduling_strategy,
},
)
cls = actor.__ray_metadata__.modified_class
self._converted[key] = ClientActorClass(cls, options=actor._default_options)
return key
def _convert_function(self, func: "RemoteFunction") -> str:
"""Register a ClientRemoteFunc for the ActorClass and return a UUID"""
key = uuid.uuid4().hex
f = func._function
self._converted[key] = ClientRemoteFunc(
f,
options={
"num_cpus": func._num_cpus,
"num_gpus": func._num_gpus,
"max_calls": func._max_calls,
"max_retries": func._max_retries,
"resources": func._resources,
"accelerator_type": func._accelerator_type,
"num_returns": func._num_returns,
"memory": func._memory,
"runtime_env": func._runtime_env,
"scheduling_strategy": func._scheduling_strategy,
},
func._function, options=func._default_options
)
return key

View file

@ -1,6 +1,7 @@
from contextlib import contextmanager
import atexit
import faulthandler
import functools
import hashlib
import inspect
import io
@ -29,7 +30,6 @@ import ray.remote_function
import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils
import ray._private.services as services
from ray.util.scheduling_strategies import SchedulingStrategyT
from ray._private.gcs_pubsub import (
GcsPublisher,
GcsErrorSubscriber,
@ -43,6 +43,7 @@ import ray._private.import_thread as import_thread
from ray.util.tracing.tracing_helper import import_from_string
from ray.util.annotations import PublicAPI, DeveloperAPI, Deprecated
from ray.util.debug import log_once
from ray._private import ray_option_utils
import ray
import colorama
import setproctitle
@ -2110,121 +2111,24 @@ def _mode(worker=global_worker):
return worker.mode
def make_decorator(
num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
object_store_memory=None,
resources=None,
accelerator_type=None,
max_calls=None,
max_retries=None,
max_restarts=None,
max_task_retries=None,
runtime_env=None,
placement_group="default",
worker=None,
retry_exceptions=None,
concurrency_groups=None,
scheduling_strategy: SchedulingStrategyT = None,
):
def decorator(function_or_class):
if inspect.isfunction(function_or_class) or is_cython(function_or_class):
# Set the remote function default resources.
if max_restarts is not None:
raise ValueError(
"The keyword 'max_restarts' is not allowed for remote functions."
)
if max_task_retries is not None:
raise ValueError(
"The keyword 'max_task_retries' is not "
"allowed for remote functions."
)
if num_returns is not None and (
not isinstance(num_returns, int) or num_returns < 0
):
raise ValueError(
"The keyword 'num_returns' only accepts 0 or a positive integer"
)
if max_retries is not None and (
not isinstance(max_retries, int) or max_retries < -1
):
raise ValueError(
"The keyword 'max_retries' only accepts 0, -1 or a"
" positive integer"
)
if max_calls is not None and (
not isinstance(max_calls, int) or max_calls < 0
):
raise ValueError(
"The keyword 'max_calls' only accepts 0 or a positive integer"
)
return ray.remote_function.RemoteFunction(
Language.PYTHON,
function_or_class,
None,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
num_returns,
max_calls,
max_retries,
retry_exceptions,
runtime_env,
placement_group,
scheduling_strategy,
)
if inspect.isclass(function_or_class):
if num_returns is not None:
raise TypeError("The keyword 'num_returns' is not allowed for actors.")
if max_retries is not None:
raise TypeError("The keyword 'max_retries' is not allowed for actors.")
if retry_exceptions is not None:
raise TypeError(
"The keyword 'retry_exceptions' is not allowed for actors."
)
if max_calls is not None:
raise TypeError("The keyword 'max_calls' is not allowed for actors.")
if max_restarts is not None and (
not isinstance(max_restarts, int) or max_restarts < -1
):
raise ValueError(
"The keyword 'max_restarts' only accepts -1, 0 or a"
" positive integer"
)
if max_task_retries is not None and (
not isinstance(max_task_retries, int) or max_task_retries < -1
):
raise ValueError(
"The keyword 'max_task_retries' only accepts -1, 0 or a"
" positive integer"
)
return ray.actor.make_actor(
function_or_class,
num_cpus,
num_gpus,
memory,
object_store_memory,
resources,
accelerator_type,
max_restarts,
max_task_retries,
runtime_env,
concurrency_groups,
scheduling_strategy,
)
raise TypeError(
"The @ray.remote decorator must be applied to "
"either a function or to a class."
def _make_remote(function_or_class, options):
# filter out placeholders in options
if inspect.isfunction(function_or_class) or is_cython(function_or_class):
ray_option_utils.validate_task_options(options, in_options=False)
return ray.remote_function.RemoteFunction(
Language.PYTHON,
function_or_class,
None,
options,
)
return decorator
if inspect.isclass(function_or_class):
ray_option_utils.validate_actor_options(options, in_options=False)
return ray.actor.make_actor(function_or_class, options)
raise TypeError(
"The @ray.remote decorator must be applied to either a function or a class."
)
@PublicAPI
@ -2294,6 +2198,8 @@ def remote(*args, **kwargs):
accelerator_type: If specified, requires that the task or actor run
on a node with the specified type of accelerator.
See `ray.accelerators` for accelerator types.
memory (float): The heap memory request for this task/actor.
object_store_memory (int): The object store memory request for this task/actor.
max_calls (int): Only for *remote functions*. This specifies the
maximum number of times that a given worker can execute
the given remote function before it must exit
@ -2341,87 +2247,10 @@ def remote(*args, **kwargs):
`PlacementGroupSchedulingStrategy`:
placement group based scheduling.
"""
worker = global_worker
# "callable" returns true for both function and class.
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return make_decorator(worker=worker)(args[0])
# Parse the keyword arguments from the decorator.
valid_kwargs = [
"num_returns",
"num_cpus",
"num_gpus",
"memory",
"object_store_memory",
"resources",
"accelerator_type",
"max_calls",
"max_restarts",
"max_task_retries",
"max_retries",
"runtime_env",
"retry_exceptions",
"placement_group",
"concurrency_groups",
"scheduling_strategy",
]
error_string = (
"The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
f"the arguments in the list {valid_kwargs}, for example "
"'@ray.remote(num_returns=2, "
'resources={"CustomResource": 1})\'.'
)
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in valid_kwargs, error_string
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
resources = kwargs.get("resources")
if not isinstance(resources, dict) and resources is not None:
raise TypeError(
"The 'resources' keyword argument must be a "
f"dictionary, but received type {type(resources)}."
)
if resources is not None:
assert "CPU" not in resources, "Use the 'num_cpus' argument."
assert "GPU" not in resources, "Use the 'num_gpus' argument."
accelerator_type = kwargs.get("accelerator_type")
# Handle other arguments.
num_returns = kwargs.get("num_returns")
max_calls = kwargs.get("max_calls")
max_restarts = kwargs.get("max_restarts")
max_task_retries = kwargs.get("max_task_retries")
memory = kwargs.get("memory")
object_store_memory = kwargs.get("object_store_memory")
max_retries = kwargs.get("max_retries")
runtime_env = kwargs.get("runtime_env")
placement_group = kwargs.get("placement_group", "default")
retry_exceptions = kwargs.get("retry_exceptions")
concurrency_groups = kwargs.get("concurrency_groups")
scheduling_strategy = kwargs.get("scheduling_strategy")
return make_decorator(
num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
object_store_memory=object_store_memory,
resources=resources,
accelerator_type=accelerator_type,
max_calls=max_calls,
max_restarts=max_restarts,
max_task_retries=max_task_retries,
max_retries=max_retries,
runtime_env=runtime_env,
placement_group=placement_group,
worker=worker,
retry_exceptions=retry_exceptions,
concurrency_groups=concurrency_groups or [],
scheduling_strategy=scheduling_strategy,
)
# "args[0]" is the class or function under the decorator.
return _make_remote(args[0], {})
assert len(args) == 0 and len(kwargs) > 0, ray_option_utils.remote_args_error_string
return functools.partial(_make_remote, options=kwargs)