mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[Core] Port concurrency groups with asyncio (#18567)
## Why are these changes needed?

This PR ports the concurrency groups functionality to asyncio for Python.

### API

```python
@ray.remote(concurrency_groups={"io": 2, "compute": 4})
class AsyncActor:
    def __init__(self):
        pass

    @ray.method(concurrency_group="io")
    async def f1(self):
        pass

    @ray.method(concurrency_group="io")
    def f2(self):
        pass

    @ray.method(concurrency_group="compute")
    def f3(self):
        pass

    @ray.method(concurrency_group="compute")
    def f4(self):
        pass

    def f5(self):
        pass
```

The annotation above the actor class `AsyncActor` declares two concurrency groups for this actor and sets their max concurrencies; the actor also always has a default concurrency group. Every concurrency group has its own async event loop and Python thread on which the methods defined in it are executed: `f1` and `f2` run in the `io` group, and `f3` and `f4` run in the `compute` group. Note that `f5` and `__init__` run in the default concurrency group.

A concurrency group specified dynamically at call time has higher priority than the one declared on the method, so the following call runs `f2` in the `compute` group:

```python
a.f2.options(concurrency_group="compute").remote()
```

### Implementation

The implementation is straightforward:
- Previously an asyncio actor had one event loop bound to one Python thread. Now we create one event loop bound to one Python thread for every concurrency group of the asyncio actor.
- Previously we had one fiber state for every caller of the asyncio actor. Now we create a FiberStateManager for every caller, and the FiberStateManager manages the fiber states of the concurrency groups.

## Related issue number

#16047
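To make the implementation note above concrete, here is a minimal, framework-free Python sketch of the "one event loop bound to one Python thread per concurrency group" scheme; the names (`start_group_loop`, `handler`) are illustrative, not Ray internals:

```python
import asyncio
import threading

def start_group_loop(group_name):
    # One event loop per concurrency group, running forever on its own
    # daemon thread (daemon so it exits with the main thread).
    loop = asyncio.new_event_loop()
    thread = threading.Thread(
        target=loop.run_forever,
        name="AsyncIO Thread: {}".format(group_name),
        daemon=True)
    thread.start()
    return loop, thread

# One (eventloop, thread) pair per declared group, plus the default group.
groups = {name: start_group_loop(name) for name in ("io", "compute", "default")}

async def handler(group_name):
    return "ran {} on {}".format(group_name, threading.current_thread().name)

# Dispatch a coroutine onto the loop of its group from the caller's thread.
for name, (loop, _) in groups.items():
    future = asyncio.run_coroutine_threadsafe(handler(name), loop)
    print(future.result())
```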
This commit is contained in:
parent
a04b02e2e8
commit
048e7f7d5d
18 changed files with 623 additions and 67 deletions
@@ -129,7 +129,9 @@ Status TaskExecutor::ExecuteTask(
     const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
     std::vector<std::shared_ptr<ray::RayObject>> *results,
     std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
-    bool *is_application_level_error) {
+    bool *is_application_level_error,
+    const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
+    const std::string name_of_concurrency_group_to_execute) {
   RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
   RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
   auto function_descriptor = ray_function.GetFunctionDescriptor();
@@ -82,7 +82,9 @@ class TaskExecutor {
      const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
      std::vector<std::shared_ptr<ray::RayObject>> *results,
      std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
-      bool *is_application_level_error);
+      bool *is_application_level_error,
+      const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
+      const std::string name_of_concurrency_group_to_execute);

  virtual ~TaskExecutor(){};
doc/source/concurrency_group_api.rst (new file, +101 lines)
@@ -0,0 +1,101 @@
Limiting Concurrency Per-Method with Concurrency Groups
=======================================================

Besides setting the max concurrency overall for an asyncio actor, Ray allows methods to be separated into *concurrency groups*, each with its own asyncio event loop. This allows you to limit the concurrency per-method, e.g., allow a health-check method to be given its own concurrency quota separate from request serving methods.

.. warning:: Concurrency groups are only supported for asyncio actors, not threaded actors.

.. _defining-concurrency-groups:

Defining Concurrency Groups
---------------------------

You can define concurrency groups for asyncio actors using the ``concurrency_groups`` decorator argument:

.. tabs::
  .. group-tab:: Python

    This defines two concurrency groups, "io" with max_concurrency=2 and
    "compute" with max_concurrency=4. The methods ``f1`` and ``f2`` are
    placed in the "io" group, and the methods ``f3`` and ``f4`` are placed
    into the "compute" group. Note that there is always a default
    concurrency group, which has a default concurrency of 1000.

    .. code-block:: python

        @ray.remote(concurrency_groups={"io": 2, "compute": 4})
        class AsyncIOActor:
            def __init__(self):
                pass

            @ray.method(concurrency_group="io")
            async def f1(self):
                pass

            @ray.method(concurrency_group="io")
            async def f2(self):
                pass

            @ray.method(concurrency_group="compute")
            async def f3(self):
                pass

            @ray.method(concurrency_group="compute")
            async def f4(self):
                pass

            async def f5(self):
                pass

        a = AsyncIOActor.remote()
        a.f1.remote()  # executed in the "io" group.
        a.f2.remote()  # executed in the "io" group.
        a.f3.remote()  # executed in the "compute" group.
        a.f4.remote()  # executed in the "compute" group.
        a.f5.remote()  # executed in the default group.

.. _default-concurrency-group:

Default Concurrency Group
-------------------------

By default, methods are placed in a default concurrency group which has a concurrency limit of 1000.
The concurrency of the default group can be changed by setting the ``max_concurrency`` actor option.

.. tabs::
  .. group-tab:: Python

    The following AsyncIOActor has 2 concurrency groups: "io" and "default".
    The max concurrency of "io" is 2, and the max concurrency of "default" is 10.

    .. code-block:: python

        @ray.remote(concurrency_groups={"io": 2})
        class AsyncIOActor:
            async def f1(self):
                pass

        actor = AsyncIOActor.options(max_concurrency=10).remote()

.. _setting-the-concurrency-group-at-runtime:

Setting the Concurrency Group at Runtime
----------------------------------------

You can also dispatch actor methods into a specific concurrency group at runtime using the ``.options`` method:

.. tabs::
  .. group-tab:: Python

    The following snippet demonstrates setting the concurrency group of the
    ``f2`` method dynamically at runtime.

    .. code-block:: python

        # Executed in the "io" group (as defined in the actor class).
        a.f2.options().remote()

        # Executed in the "compute" group.
        a.f2.options(concurrency_group="compute").remote()
@@ -14,6 +14,7 @@ Finally, we've also included some content on using core Ray APIs with `Tensorflow`
    actors.rst
    namespaces.rst
    async_api.rst
+   concurrency_group_api.rst
    using-ray-with-gpus.rst
    serialization.rst
    memory-management.rst
@@ -17,6 +17,7 @@ from ray.includes.common cimport (
    CBuffer,
    CRayObject,
    CAddress,
+    CConcurrencyGroup,
)
from ray.includes.libcoreworker cimport (
    ActorHandleSharedPtr,
@@ -117,6 +118,11 @@ cdef class CoreWorker:
        object current_runtime_env
        c_bool is_local_mode

+        object cgname_to_eventloop_dict
+        object eventloop_for_default_cg
+        object thread_for_default_cg
+        object fd_to_cgname_dict
+
    cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
                            size_t data_size, ObjectRef object_ref,
                            c_vector[CObjectID] contained_ids,
@@ -130,6 +136,10 @@ cdef class CoreWorker:
                            c_vector[shared_ptr[CRayObject]] *returns)
    cdef yield_current_fiber(self, CFiberEvent &fiber_event)
    cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle)
+    cdef c_function_descriptors_to_python(
+        self, const c_vector[CFunctionDescriptor] &c_function_descriptors)
+    cdef initialize_eventloops_for_actor_concurrency_group(
+        self, const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups)

cdef class FunctionDescriptor:
    cdef:
@@ -60,6 +60,7 @@ from ray.includes.common cimport (
    CRayFunction,
    CWorkerType,
    CJobConfig,
+    CConcurrencyGroup,
    move,
    LANGUAGE_CPP,
    LANGUAGE_JAVA,
@@ -316,6 +317,34 @@ cdef int prepare_resources(
            resource_map[0][key.encode("ascii")] = float(value)
    return 0

+cdef c_vector[CFunctionDescriptor] prepare_function_descriptors(pyfd_list):
+    cdef:
+        c_vector[CFunctionDescriptor] fd_list
+        CRayFunction ray_function
+
+    for pyfd in pyfd_list:
+        fd_list.push_back(CFunctionDescriptorBuilder.BuildPython(
+            pyfd.module_name, pyfd.class_name, pyfd.function_name, b""))
+    return fd_list
+
+
+cdef int prepare_actor_concurrency_groups(
+        dict concurrency_groups_dict,
+        c_vector[CConcurrencyGroup] *concurrency_groups):
+
+    cdef:
+        CConcurrencyGroup cg
+        c_vector[CFunctionDescriptor] c_fd_list
+
+    if concurrency_groups_dict is None:
+        raise ValueError("concurrency_groups_dict cannot be None.")
+
+    for key, value in concurrency_groups_dict.items():
+        c_fd_list = prepare_function_descriptors(value["function_descriptors"])
+        cg = CConcurrencyGroup(
+            key.encode("ascii"), value["max_concurrency"], c_fd_list)
+        concurrency_groups.push_back(cg)
+    return 1
+
cdef prepare_args(
        CoreWorker core_worker,
@@ -411,7 +440,11 @@ cdef execute_task(
        const c_vector[CObjectID] &c_return_ids,
        const c_string debugger_breakpoint,
        c_vector[shared_ptr[CRayObject]] *returns,
-        c_bool *is_application_level_error):
+        c_bool *is_application_level_error,
+        # This parameter is only used for the actor creation task to define
+        # the concurrency groups of this actor.
+        const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups,
+        const c_string c_name_of_concurrency_group_to_execute):

    is_application_level_error[0] = False
@@ -462,6 +495,11 @@ cdef execute_task(
            print(actor_magic_token)
            print(actor_magic_token, file=sys.stderr)

+    # Initialize the event loops for this asyncio actor's concurrency groups.
+    if core_worker.current_actor_is_asyncio():
+        core_worker.initialize_eventloops_for_actor_concurrency_group(
+            c_defined_concurrency_groups)
+
    execution_info = execution_infos.get(function_descriptor)
    if not execution_info:
        execution_info = manager.get_execution_info(
@@ -473,6 +511,8 @@ cdef execute_task(
                b' "task_id": ' + task_id.hex().encode("ascii") + b'}')

    task_name = name.decode("utf-8")
+    name_of_concurrency_group_to_execute = \
+        c_name_of_concurrency_group_to_execute.decode("ascii")
    title = f"ray::{task_name}"

    if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
@@ -520,7 +560,9 @@ cdef execute_task(
                        async_function = sync_to_async(function)

                    return core_worker.run_async_func_in_event_loop(
-                        async_function, actor, *arguments, **kwarguments)
+                        async_function, function_descriptor,
+                        name_of_concurrency_group_to_execute, actor,
+                        *arguments, **kwarguments)

                return function(actor, *arguments, **kwarguments)
@@ -546,7 +588,8 @@ cdef execute_task(
                        .deserialize_objects(
                            metadata_pairs, object_refs))
                args = core_worker.run_async_func_in_event_loop(
-                    deserialize_args)
+                    deserialize_args, function_descriptor,
+                    name_of_concurrency_group_to_execute)
            else:
                args = ray.worker.global_worker.deserialize_objects(
                    metadata_pairs, object_refs)
@@ -692,7 +735,9 @@ cdef CRayStatus task_execution_handler(
        const c_string debugger_breakpoint,
        c_vector[shared_ptr[CRayObject]] *returns,
        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
-        c_bool *is_application_level_error) nogil:
+        c_bool *is_application_level_error,
+        const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
+        const c_string name_of_concurrency_group_to_execute) nogil:
    with gil, disable_client_hook():
        try:
            try:
@@ -701,7 +746,9 @@ cdef CRayStatus task_execution_handler(
                execute_task(task_type, task_name, ray_function, c_resources,
                             c_args, c_arg_refs, c_return_ids,
                             debugger_breakpoint, returns,
-                             is_application_level_error)
+                             is_application_level_error,
+                             defined_concurrency_groups,
+                             name_of_concurrency_group_to_execute)
            except Exception as e:
                sys_exit = SystemExit()
                if isinstance(e, RayActorError) and \
@@ -1020,6 +1067,10 @@ cdef class CoreWorker:
            options.startup_token = startup_token
        CCoreWorkerProcess.Initialize(options)

+        self.cgname_to_eventloop_dict = None
+        self.fd_to_cgname_dict = None
+        self.eventloop_for_default_cg = None
+
    def shutdown(self):
        with nogil:
            # If it's a worker, the core worker process should have been
@@ -1414,6 +1465,7 @@ cdef class CoreWorker:
                     c_string extension_data,
                     c_string serialized_runtime_env,
                     runtime_env_uris,
+                     concurrency_groups_dict,
                     ):
        cdef:
            CRayFunction ray_function
@@ -1425,6 +1477,7 @@ cdef class CoreWorker:
            CPlacementGroupID c_placement_group_id = \
                placement_group_id.native()
            c_vector[c_string] c_runtime_env_uris = runtime_env_uris
+            c_vector[CConcurrencyGroup] c_concurrency_groups

        with self.profile_event(b"submit_task"):
            prepare_resources(resources, &c_resources)
@@ -1432,6 +1485,8 @@ cdef class CoreWorker:
            ray_function = CRayFunction(
                language.lang, function_descriptor.descriptor)
            prepare_args(self, language, args, &args_vector)
+            prepare_actor_concurrency_groups(
+                concurrency_groups_dict, &c_concurrency_groups)

            with nogil:
                check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
@@ -1447,7 +1502,8 @@ cdef class CoreWorker:
                        placement_group_bundle_index),
                    placement_group_capture_child_tasks,
                    serialized_runtime_env,
-                    c_runtime_env_uris),
+                    c_runtime_env_uris,
+                    c_concurrency_groups),
                    extension_data,
                    &c_actor_id))
@@ -1814,32 +1870,96 @@ cdef class CoreWorker:
                CCoreWorkerProcess.GetCoreWorker().SealReturnObject(
                    return_id, returns[0][i]))

-    def create_or_get_event_loop(self):
-        if self.async_event_loop is None:
-            self.async_event_loop = get_new_event_loop()
-            asyncio.set_event_loop(self.async_event_loop)
-
-        if self.async_thread is None:
-            self.async_thread = threading.Thread(
-                target=lambda: self.async_event_loop.run_forever(),
-                name="AsyncIO Thread"
-            )
-            # Making the thread a daemon causes it to exit
-            # when the main thread exits.
-            self.async_thread.daemon = True
-            self.async_thread.start()
-
-        return self.async_event_loop
-
-    def run_async_func_in_event_loop(self, func, *args, **kwargs):
+    cdef c_function_descriptors_to_python(
+            self,
+            const c_vector[CFunctionDescriptor] &c_function_descriptors):
+
+        ret = []
+        for i in range(c_function_descriptors.size()):
+            ret.append(CFunctionDescriptorToPython(c_function_descriptors[i]))
+        return ret
+
+    cdef initialize_eventloops_for_actor_concurrency_group(
+            self,
+            const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups):
+
+        cdef:
+            CConcurrencyGroup c_concurrency_group
+            c_vector[CFunctionDescriptor] c_function_descriptors
+
+        self.cgname_to_eventloop_dict = {}
+        self.fd_to_cgname_dict = {}
+
+        self.eventloop_for_default_cg = get_new_event_loop()
+        self.thread_for_default_cg = threading.Thread(
+            target=lambda: self.eventloop_for_default_cg.run_forever(),
+            name="AsyncIO Thread: default"
+        )
+        # Making the thread a daemon causes it to exit
+        # when the main thread exits.
+        self.thread_for_default_cg.daemon = True
+        self.thread_for_default_cg.start()
+
+        for i in range(c_defined_concurrency_groups.size()):
+            c_concurrency_group = c_defined_concurrency_groups[i]
+            cg_name = c_concurrency_group.GetName().decode("ascii")
+            function_descriptors = self.c_function_descriptors_to_python(
+                c_concurrency_group.GetFunctionDescriptors())
+
+            async_eventloop = get_new_event_loop()
+            async_thread = threading.Thread(
+                target=lambda: async_eventloop.run_forever(),
+                name="AsyncIO Thread: {}".format(cg_name)
+            )
+            async_thread.daemon = True
+            async_thread.start()
+
+            self.cgname_to_eventloop_dict[cg_name] = {
+                "eventloop": async_eventloop,
+                "thread": async_thread,
+            }
+
+            for fd in function_descriptors:
+                self.fd_to_cgname_dict[fd] = cg_name
+
+    def get_event_loop(self, function_descriptor, specified_cgname):
+        # __init__ will be invoked in the default eventloop.
+        if function_descriptor.function_name == "__init__":
+            return self.eventloop_for_default_cg, self.thread_for_default_cg
+
+        if specified_cgname is not None:
+            if specified_cgname in self.cgname_to_eventloop_dict:
+                this_group = self.cgname_to_eventloop_dict[specified_cgname]
+                return (this_group["eventloop"], this_group["thread"])
+
+        if function_descriptor in self.fd_to_cgname_dict:
+            curr_cgname = self.fd_to_cgname_dict[function_descriptor]
+            if curr_cgname in self.cgname_to_eventloop_dict:
+                return (
+                    self.cgname_to_eventloop_dict[curr_cgname]["eventloop"],
+                    self.cgname_to_eventloop_dict[curr_cgname]["thread"])
+            else:
+                raise ValueError(
+                    "The function {} is defined to be executed "
+                    "in the concurrency group {}, but there is no "
+                    "such group.".format(function_descriptor, curr_cgname))

+        return self.eventloop_for_default_cg, self.thread_for_default_cg
+
+    def run_async_func_in_event_loop(
+            self, func, function_descriptor, specified_cgname, *args, **kwargs):
        cdef:
            CFiberEvent event
-        loop = self.create_or_get_event_loop()
+        eventloop, async_thread = self.get_event_loop(
+            function_descriptor, specified_cgname)
        coroutine = func(*args, **kwargs)
-        if threading.get_ident() == self.async_thread.ident:
-            future = asyncio.ensure_future(coroutine, loop)
+        if threading.get_ident() == async_thread.ident:
+            future = asyncio.ensure_future(coroutine, eventloop)
        else:
-            future = asyncio.run_coroutine_threadsafe(coroutine, loop)
+            future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
        future.add_done_callback(lambda _: event.Notify())
        with nogil:
            (CCoreWorkerProcess.GetCoreWorker()
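The lookup order that `get_event_loop` implements above can be summarized in a standalone sketch; `resolve_group` and its plain-dict arguments are illustrative stand-ins for Ray's function descriptors:

```python
def resolve_group(function_name, specified_cgname, fd_to_cgname, known_groups):
    # 1. __init__ always runs in the default group.
    if function_name == "__init__":
        return "default"
    # 2. A group specified at call time (.options(concurrency_group=...)) wins.
    if specified_cgname is not None and specified_cgname in known_groups:
        return specified_cgname
    # 3. Otherwise, the group the method was declared in, if any.
    if function_name in fd_to_cgname:
        return fd_to_cgname[function_name]
    # 4. Otherwise, fall back to the default group.
    return "default"

assert resolve_group("f2", "compute", {"f2": "io"}, {"io", "compute"}) == "compute"
assert resolve_group("f2", None, {"f2": "io"}, {"io", "compute"}) == "io"
assert resolve_group("f5", None, {"f2": "io"}, {"io", "compute"}) == "default"
```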
@@ -54,11 +54,14 @@ def method(*args, **kwargs):
    """
    assert len(args) == 0
    assert len(kwargs) == 1
-    assert "num_returns" in kwargs
-    num_returns = kwargs["num_returns"]
+    assert "num_returns" in kwargs or "concurrency_group" in kwargs

    def annotate_method(method):
-        method.__ray_num_returns__ = num_returns
+        if "num_returns" in kwargs:
+            method.__ray_num_returns__ = kwargs["num_returns"]
+        if "concurrency_group" in kwargs:
+            method.__ray_concurrency_group__ = kwargs["concurrency_group"]
        return method

    return annotate_method
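Seen in isolation, `ray.method` is a decorator factory that records its options as attributes on the function; here is a minimal standalone sketch of that annotation pattern (plain Python, no Ray runtime, with a hypothetical `method` stand-in):

```python
def method(**kwargs):
    # Record the options as attributes so the actor-class machinery can
    # read them back later (mirrors annotate_method above).
    assert "num_returns" in kwargs or "concurrency_group" in kwargs

    def annotate_method(func):
        if "num_returns" in kwargs:
            func.__ray_num_returns__ = kwargs["num_returns"]
        if "concurrency_group" in kwargs:
            func.__ray_concurrency_group__ = kwargs["concurrency_group"]
        return func

    return annotate_method

@method(concurrency_group="io")
def f():
    pass

assert f.__ray_concurrency_group__ == "io"
```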
@@ -138,7 +141,12 @@ class ActorMethod:
        return FuncWrapper()

    @_tracing_actor_method_invocation
-    def _remote(self, args=None, kwargs=None, name="", num_returns=None):
+    def _remote(self,
+                args=None,
+                kwargs=None,
+                name="",
+                num_returns=None,
+                concurrency_group=None):
        if num_returns is None:
            num_returns = self._num_returns
@@ -222,6 +230,8 @@ class ActorClassMethodMetadata(object):
        self.decorators = {}
        self.signatures = {}
        self.num_returns = {}
+        self.concurrency_group_for_methods = {}

        for method_name, method in actor_methods:
            # Whether or not this method requires binding of its first
            # argument. For class and static methods, we do not want to bind
@@ -247,6 +257,10 @@ class ActorClassMethodMetadata(object):
                self.decorators[method_name] = (
                    method.__ray_invocation_decorator__)

+            if hasattr(method, "__ray_concurrency_group__"):
+                self.concurrency_group_for_methods[method_name] = (
+                    method.__ray_concurrency_group__)
+
        # Update cache.
        cls._cache[actor_creation_function_descriptor] = self
        return self
@@ -285,8 +299,8 @@ class ActorClassMetadata:
    def __init__(self, language, modified_class,
                 actor_creation_function_descriptor, class_id, max_restarts,
                 max_task_retries, num_cpus, num_gpus, memory,
-                 object_store_memory, resources, accelerator_type,
-                 runtime_env):
+                 object_store_memory, resources, accelerator_type, runtime_env,
+                 concurrency_groups):
        self.language = language
        self.modified_class = modified_class
        self.actor_creation_function_descriptor = \
@@ -303,6 +317,7 @@ class ActorClassMetadata:
        self.resources = resources
        self.accelerator_type = accelerator_type
        self.runtime_env = runtime_env
+        self.concurrency_groups = concurrency_groups
        self.last_export_session_and_job = None
        self.method_meta = ActorClassMethodMetadata.create(
            modified_class, actor_creation_function_descriptor)
@@ -358,10 +373,10 @@ class ActorClass:
                f"use '{self.__ray_metadata__.class_name}.remote()'.")

    @classmethod
-    def _ray_from_modified_class(cls, modified_class, class_id, max_restarts,
-                                 max_task_retries, num_cpus, num_gpus, memory,
-                                 object_store_memory, resources,
-                                 accelerator_type, runtime_env):
+    def _ray_from_modified_class(
+            cls, modified_class, class_id, max_restarts, max_task_retries,
+            num_cpus, num_gpus, memory, object_store_memory, resources,
+            accelerator_type, runtime_env, concurrency_groups):
        for attribute in [
                "remote",
                "_remote",
@@ -398,7 +413,7 @@ class ActorClass:
            Language.PYTHON, modified_class,
            actor_creation_function_descriptor, class_id, max_restarts,
            max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
-            resources, accelerator_type, new_runtime_env)
+            resources, accelerator_type, new_runtime_env, concurrency_groups)

        return self
@@ -413,10 +428,12 @@ class ActorClass:
        # .remote(), it would get run in the Ray Client server, which runs on
        # a remote node where the files aren't available.
        new_runtime_env = ParsedRuntimeEnv(runtime_env or {})

        self.__ray_metadata__ = ActorClassMetadata(
            language, None, actor_creation_function_descriptor, None,
            max_restarts, max_task_retries, num_cpus, num_gpus, memory,
-            object_store_memory, resources, accelerator_type, new_runtime_env)
+            object_store_memory, resources, accelerator_type, new_runtime_env,
+            [])

        return self
@@ -740,6 +757,25 @@ class ActorClass:
            parsed_runtime_env = override_task_or_actor_runtime_env(
                runtime_env, parent_runtime_env)

+            concurrency_groups_dict = {}
+            for cg_name in meta.concurrency_groups:
+                concurrency_groups_dict[cg_name] = {
+                    "name": cg_name,
+                    "max_concurrency": meta.concurrency_groups[cg_name],
+                    "function_descriptors": [],
+                }
+
+            # Update the methods of each concurrency group.
+            for method_name in meta.method_meta.concurrency_group_for_methods:
+                cg_name = meta.method_meta.concurrency_group_for_methods[
+                    method_name]
+                assert cg_name in concurrency_groups_dict
+
+                module_name = meta.actor_creation_function_descriptor.module_name
+                class_name = meta.actor_creation_function_descriptor.class_name
+                concurrency_groups_dict[cg_name]["function_descriptors"].append(
+                    PythonFunctionDescriptor(module_name, method_name, class_name))
+
            actor_id = worker.core_worker.create_actor(
                meta.language,
                meta.actor_creation_function_descriptor,
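For reference, the dictionary built in the loop above has the following shape before it is handed to `core_worker.create_actor`; the concrete module, class, and method names here are hypothetical:

```python
# Sketch of concurrency_groups_dict for
# @ray.remote(concurrency_groups={"io": 2}) on a class with one "io" method.
# In the real code, "function_descriptors" holds PythonFunctionDescriptor
# objects built as PythonFunctionDescriptor(module_name, method_name,
# class_name); tuples are used here only for illustration.
concurrency_groups_dict = {
    "io": {
        "name": "io",
        "max_concurrency": 2,
        "function_descriptors": [("my_module", "f1", "MyActor")],
    },
}
```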
@@ -759,7 +795,8 @@ class ActorClass:
                # Store actor_method_cpu in actor handle's extension data.
                extension_data=str(actor_method_cpu),
                serialized_runtime_env=parsed_runtime_env.serialize(),
-                runtime_env_uris=parsed_runtime_env.get_uris())
+                runtime_env_uris=parsed_runtime_env.get_uris(),
+                concurrency_groups_dict=concurrency_groups_dict or dict())

        actor_handle = ActorHandle(
            meta.language,
@@ -1060,7 +1097,8 @@ def modify_class(cls):


def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
-               accelerator_type, max_restarts, max_task_retries, runtime_env):
+               accelerator_type, max_restarts, max_task_retries, runtime_env,
+               concurrency_groups):
    Class = modify_class(cls)
    _inject_tracing_into_class(Class)
@@ -1068,6 +1106,8 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
        max_restarts = 0
    if max_task_retries is None:
        max_task_retries = 0
+    if concurrency_groups is None:
+        concurrency_groups = []

    infinite_restart = max_restarts == -1
    if not infinite_restart:
@@ -1086,7 +1126,7 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
    return ActorClass._ray_from_modified_class(
        Class, ActorClassID.from_random(), max_restarts, max_task_retries,
        num_cpus, num_gpus, memory, object_store_memory, resources,
-        accelerator_type, runtime_env)
+        accelerator_type, runtime_env, concurrency_groups)


def exit_actor():
|
@ -2,7 +2,7 @@ from libcpp cimport bool as c_bool
|
|||
from libcpp.memory cimport shared_ptr, unique_ptr
|
||||
from libcpp.string cimport string as c_string
|
||||
|
||||
from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t
|
||||
from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t, uint32_t
|
||||
from libcpp.unordered_map cimport unordered_map
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.pair cimport pair as c_pair
|
||||
|
@@ -255,7 +255,8 @@ cdef extern from "ray/core_worker/common.h" nogil:
            c_pair[CPlacementGroupID, int64_t] placement_options,
            c_bool placement_group_capture_child_tasks,
            c_string serialized_runtime_env,
-            c_vector[c_string] runtime_env_uris)
+            c_vector[c_string] runtime_env_uris,
+            const c_vector[CConcurrencyGroup] &concurrency_groups)

    cdef cppclass CPlacementGroupCreationOptions \
            "ray::core::PlacementGroupCreationOptions":
@@ -283,3 +284,14 @@ cdef extern from "ray/gcs/gcs_client.h" nogil:
cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
    cdef cppclass CJobConfig "ray::rpc::JobConfig":
        const c_string &SerializeAsString()
+
+cdef extern from "ray/common/task/task_spec.h" nogil:
+    cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
+        CConcurrencyGroup(
+            const c_string &name,
+            uint32_t max_concurrency,
+            const c_vector[CFunctionDescriptor] &c_fds)
+        CConcurrencyGroup()
+        c_string GetName() const
+        uint32_t GetMaxConcurrency() const
+        c_vector[CFunctionDescriptor] GetFunctionDescriptors() const
|
@ -43,6 +43,7 @@ from ray.includes.common cimport (
|
|||
CGcsClientOptions,
|
||||
LocalMemoryBuffer,
|
||||
CJobConfig,
|
||||
CConcurrencyGroup,
|
||||
)
|
||||
from ray.includes.function_descriptor cimport (
|
||||
CFunctionDescriptor,
|
||||
|
@@ -284,7 +285,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
                c_vector[shared_ptr[CRayObject]] *returns,
                shared_ptr[LocalMemoryBuffer]
                &creation_task_exception_pb_bytes,
-                c_bool *is_application_level_error) nogil
+                c_bool *is_application_level_error,
+                const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
+                const c_string name_of_concurrency_group_to_execute) nogil
            ) task_execution_callback
            (void(const CWorkerID &) nogil) on_worker_shutdown
            (CRayStatus() nogil) check_signals
python/ray/tests/test_concurrency_group.py (new file, +114 lines)
@@ -0,0 +1,114 @@
# coding: utf-8
import asyncio
import sys
import threading
import pytest
import ray


# This tests that the methods are executed in the correct eventloops.
def test_basic():
    @ray.remote(concurrency_groups={"io": 2, "compute": 4})
    class AsyncActor:
        def __init__(self):
            self.eventloop_f1 = None
            self.eventloop_f2 = None
            self.eventloop_f3 = None
            self.eventloop_f4 = None
            self.default_eventloop = asyncio.get_event_loop()

        @ray.method(concurrency_group="io")
        async def f1(self):
            self.eventloop_f1 = asyncio.get_event_loop()
            return threading.current_thread().ident

        @ray.method(concurrency_group="io")
        def f2(self):
            self.eventloop_f2 = asyncio.get_event_loop()
            return threading.current_thread().ident

        @ray.method(concurrency_group="compute")
        def f3(self):
            self.eventloop_f3 = asyncio.get_event_loop()
            return threading.current_thread().ident

        @ray.method(concurrency_group="compute")
        def f4(self):
            self.eventloop_f4 = asyncio.get_event_loop()
            return threading.current_thread().ident

        def f5(self):
            # Assert that this method is executed in the default eventloop.
            assert asyncio.get_event_loop() == self.default_eventloop
            return threading.current_thread().ident

        @ray.method(concurrency_group="io")
        def do_assert(self):
            if self.eventloop_f1 != self.eventloop_f2:
                return False
            if self.eventloop_f3 != self.eventloop_f4:
                return False
            if self.eventloop_f1 == self.eventloop_f3:
                return False
            if self.eventloop_f1 == self.eventloop_f4:
                return False
            return True

    ###############################################
    a = AsyncActor.remote()
    f1_thread_id = ray.get(a.f1.remote())  # executed in the "io" group.
    f2_thread_id = ray.get(a.f2.remote())  # executed in the "io" group.
    f3_thread_id = ray.get(a.f3.remote())  # executed in the "compute" group.
    f4_thread_id = ray.get(a.f4.remote())  # executed in the "compute" group.

    assert f1_thread_id == f2_thread_id
    assert f3_thread_id == f4_thread_id
    assert f1_thread_id != f3_thread_id

    assert ray.get(a.do_assert.remote())

    assert ray.get(a.f5.remote())  # executed in the default group.

    # The concurrency group can also be specified at runtime.
    # This task will be invoked in the "compute" thread pool.
    a.f2.options(concurrency_group="compute").remote()


# This case tests that asyncio coordination (awaiting a shared event)
# works within one concurrency group.
def test_async_methods_in_concurrency_group():
    @ray.remote(concurrency_groups={"async": 3})
    class AsyncBatcher:
        def __init__(self):
            self.batch = []
            self.event = None

        @ray.method(concurrency_group="async")
        def init_event(self):
            self.event = asyncio.Event()
            return True

        @ray.method(concurrency_group="async")
        async def add(self, x):
            self.batch.append(x)
            if len(self.batch) >= 3:
                self.event.set()
            else:
                await self.event.wait()
            return sorted(self.batch)

    a = AsyncBatcher.remote()
    ray.get(a.init_event.remote())

    x1 = a.add.remote(1)
    x2 = a.add.remote(2)
    x3 = a.add.remote(3)
    r1 = ray.get(x1)
    r2 = ray.get(x2)
    r3 = ray.get(x3)
    assert r1 == [1, 2, 3]
    assert r1 == r2 == r3


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
@@ -1928,7 +1928,8 @@ def make_decorator(num_returns=None,
                   runtime_env=None,
                   placement_group="default",
                   worker=None,
-                   retry_exceptions=None):
+                   retry_exceptions=None,
+                   concurrency_groups=None):
    def decorator(function_or_class):
        if (inspect.isfunction(function_or_class)
                or is_cython(function_or_class)):
@@ -1983,10 +1984,10 @@ def make_decorator(num_returns=None,
            raise ValueError(
                "The keyword 'max_task_retries' only accepts -1, 0 or a"
                " positive integer")
-        return ray.actor.make_actor(function_or_class, num_cpus, num_gpus,
-                                    memory, object_store_memory, resources,
-                                    accelerator_type, max_restarts,
-                                    max_task_retries, runtime_env)
+        return ray.actor.make_actor(
+            function_or_class, num_cpus, num_gpus, memory,
+            object_store_memory, resources, accelerator_type, max_restarts,
+            max_task_retries, runtime_env, concurrency_groups)

    raise TypeError("The @ray.remote decorator must be applied to "
                    "either a function or to a class.")
@@ -2105,10 +2106,21 @@ def remote(*args, **kwargs):

    # Parse the keyword arguments from the decorator.
    valid_kwargs = [
-        "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
-        "resources", "accelerator_type", "max_calls", "max_restarts",
-        "max_task_retries", "max_retries", "runtime_env", "retry_exceptions",
-        "placement_group"
+        "num_returns",
+        "num_cpus",
+        "num_gpus",
+        "memory",
+        "object_store_memory",
+        "resources",
+        "accelerator_type",
+        "max_calls",
+        "max_restarts",
+        "max_task_retries",
+        "max_retries",
+        "runtime_env",
+        "retry_exceptions",
+        "placement_group",
+        "concurrency_groups",
    ]
    error_string = ("The @ray.remote decorator must be applied either "
                    "with no arguments and no parentheses, for example "
@@ -2143,6 +2155,7 @@ def remote(*args, **kwargs):
    runtime_env = kwargs.get("runtime_env")
    placement_group = kwargs.get("placement_group", "default")
    retry_exceptions = kwargs.get("retry_exceptions")
+    concurrency_groups = kwargs.get("concurrency_groups")

    return make_decorator(
        num_returns=num_returns,
@@ -2159,4 +2172,5 @@ def remote(*args, **kwargs):
        runtime_env=runtime_env,
        placement_group=placement_group,
        worker=worker,
-        retry_exceptions=retry_exceptions)
+        retry_exceptions=retry_exceptions,
+        concurrency_groups=concurrency_groups or [])
@@ -43,6 +43,20 @@ struct ConcurrencyGroup {
  uint32_t max_concurrency;
  // Function descriptors of the actor methods in this group.
  std::vector<ray::FunctionDescriptor> function_descriptors;

+  ConcurrencyGroup() = default;
+
+  ConcurrencyGroup(const std::string &name, uint32_t max_concurrency,
+                   const std::vector<ray::FunctionDescriptor> &fds)
+      : name(name), max_concurrency(max_concurrency), function_descriptors(fds) {}
+
+  std::string GetName() const { return name; }
+
+  uint32_t GetMaxConcurrency() const { return max_concurrency; }
+
+  std::vector<ray::FunctionDescriptor> GetFunctionDescriptors() const {
+    return function_descriptors;
+  }
};

static inline rpc::ObjectReference GetReferenceForActorDummyObject(
@@ -2273,11 +2273,21 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
  CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());

  std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr;

+  std::vector<ConcurrencyGroup> defined_concurrency_groups = {};
+  std::string name_of_concurrency_group_to_execute;
+  if (task_spec.IsActorCreationTask()) {
+    defined_concurrency_groups = task_spec.ConcurrencyGroups();
+  } else if (task_spec.IsActorTask()) {
+    name_of_concurrency_group_to_execute = task_spec.ConcurrencyGroupName();
+  }
+
  status = options_.task_execution_callback(
      task_type, task_spec.GetName(), func,
      task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
      return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
-      creation_task_exception_pb_bytes, is_application_level_error);
+      creation_task_exception_pb_bytes, is_application_level_error,
+      defined_concurrency_groups, name_of_concurrency_group_to_execute);

  // Get the reference counts for any IDs that we borrowed during this task,
  // remove the local reference for these IDs, and return the ref count info to
@@ -71,7 +71,15 @@ struct CoreWorkerOptions {
      const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
      std::vector<std::shared_ptr<RayObject>> *results,
      std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
-      bool *is_application_level_error)>;
+      bool *is_application_level_error,
+      // The following 2 parameters, `defined_concurrency_groups` and
+      // `name_of_concurrency_group_to_execute`, are used for the Python
+      // asyncio actor only.
+      //
+      // The defined concurrency groups of this actor. Note this is only
+      // used for the actor creation task.
+      const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
+      const std::string name_of_concurrency_group_to_execute)>;

  CoreWorkerOptions()
      : store_socket(""),
@@ -101,7 +101,13 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
          const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
          std::vector<std::shared_ptr<RayObject>> *results,
          std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
-          bool *is_application_level_error) {
+          bool *is_application_level_error,
+          const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
+          const std::string name_of_concurrency_group_to_execute) {
+        // These 2 parameters are used for Python only; the Java worker
+        // does not use them.
+        RAY_UNUSED(defined_concurrency_groups);
+        RAY_UNUSED(name_of_concurrency_group_to_execute);
        // TODO(jjyao): Support retrying application-level errors for Java
        *is_application_level_error = false;
@@ -509,6 +509,21 @@ class MockWorkerContext : public WorkerContext {
  }
};

+class MockCoreWorkerDirectTaskReceiver : public CoreWorkerDirectTaskReceiver {
+ public:
+  MockCoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
+                                   instrumented_io_context &main_io_service,
+                                   const TaskHandler &task_handler,
+                                   const OnTaskDone &task_done)
+      : CoreWorkerDirectTaskReceiver(worker_context, main_io_service, task_handler,
+                                     task_done) {}
+
+  void UpdateConcurrencyGroupsCache(const ActorID &actor_id,
+                                    const std::vector<ConcurrencyGroup> &cgs) {
+    concurrency_groups_cache_[actor_id] = cgs;
+  }
+};
+
class DirectActorReceiverTest : public ::testing::Test {
 public:
  DirectActorReceiverTest()
@@ -518,7 +533,7 @@ class DirectActorReceiverTest : public ::testing::Test {
    auto execute_task =
        std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
-    receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
+    receiver_ = std::make_unique<MockCoreWorkerDirectTaskReceiver>(
        worker_context_, main_io_service_, execute_task, [] { return Status::OK(); });
    receiver_->Init(std::make_shared<rpc::CoreWorkerClientPool>(
                        [&](const rpc::Address &addr) { return worker_client_; }),
@@ -541,7 +556,7 @@ class DirectActorReceiverTest : public ::testing::Test {
    main_io_service_.stop();
  }

-  std::unique_ptr<CoreWorkerDirectTaskReceiver> receiver_;
+  std::unique_ptr<MockCoreWorkerDirectTaskReceiver> receiver_;

 private:
  rpc::Address rpc_address_;
@@ -575,6 +590,7 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) {
      ++callback_count;
      ASSERT_TRUE(status.ok());
    };
+    receiver_->UpdateConcurrencyGroupsCache(actor_id, {});
    receiver_->HandleTask(request, &reply, reply_callback);
  }
@@ -524,6 +524,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
          return_object->add_nested_inlined_refs()->CopyFrom(nested_ref);
        }
      }
+
  if (task_spec.IsActorCreationTask()) {
    /// The default max concurrency for creating PoolManager should
    /// be 0 if this is an asyncio actor.
@@ -531,6 +532,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
        task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency();
    pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(),
                                                  default_max_concurrency);
+    concurrency_groups_cache_[task_spec.TaskId().ActorId()] =
+        task_spec.ConcurrencyGroups();
    RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId()
                  << ", actor_id: " << task_spec.ActorCreationId();
    // Tell raylet that an actor creation task has finished execution, so that
@@ -573,11 +576,13 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
  if (task_spec.IsActorTask()) {
    auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
    if (it == actor_scheduling_queues_.end()) {
+      auto cg_it = concurrency_groups_cache_.find(task_spec.ActorId());
+      RAY_CHECK(cg_it != concurrency_groups_cache_.end());
      auto result = actor_scheduling_queues_.emplace(
          task_spec.CallerWorkerId(),
-          std::unique_ptr<SchedulingQueue>(
-              new ActorSchedulingQueue(task_main_io_service_, *waiter_, pool_manager_,
-                                       is_asyncio_, fiber_max_concurrency_)));
+          std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
+              task_main_io_service_, *waiter_, pool_manager_, is_asyncio_,
+              fiber_max_concurrency_, cg_it->second)));
      it = result.first;
    }
@@ -283,6 +283,69 @@ class CoreWorkerDirectActorTaskSubmitter
  friend class CoreWorkerTest;
};

+/// The class that manages fiber states for Python asyncio actors.
+///
+/// We'll create one fiber state for every concurrency group, and
+/// create one default fiber state for the default concurrency group if
+/// necessary.
+class FiberStateManager final {
+ public:
+  explicit FiberStateManager(const std::vector<ConcurrencyGroup> &concurrency_groups = {},
+                             const int32_t default_group_max_concurrency = 1000) {
+    for (auto &group : concurrency_groups) {
+      const auto name = group.name;
+      const auto max_concurrency = group.max_concurrency;
+      auto fiber = std::make_shared<FiberState>(max_concurrency);
+      auto &fds = group.function_descriptors;
+      for (auto fd : fds) {
+        functions_to_fiber_index_[fd->ToString()] = fiber;
+      }
+      name_to_fiber_index_[name] = fiber;
+    }
+    /// Create the default fiber state for the default concurrency group.
+    if (default_group_max_concurrency >= 1) {
+      default_fiber_ = std::make_shared<FiberState>(default_group_max_concurrency);
+    }
+  }
+
+  /// Get the corresponding fiber state by the given concurrency group or function
+  /// descriptor.
+  ///
+  /// Return the corresponding fiber state of the concurrency group
+  /// if concurrency_group_name is given.
+  /// Otherwise return the corresponding fiber state by the given function descriptor.
+  std::shared_ptr<FiberState> GetFiber(const std::string &concurrency_group_name,
+                                       ray::FunctionDescriptor fd) {
+    if (!concurrency_group_name.empty()) {
+      auto it = name_to_fiber_index_.find(concurrency_group_name);
+      RAY_CHECK(it != name_to_fiber_index_.end())
+          << "Failed to look up the fiber state of the given concurrency group "
+          << concurrency_group_name << ". It might be that you didn't define "
+          << "the concurrency group " << concurrency_group_name;
+      return it->second;
+    }
+
+    /// This task wasn't dispatched to a concurrency group at call time,
+    /// so use the concurrency group predefined for this function, if any.
+    if (functions_to_fiber_index_.find(fd->ToString()) !=
+        functions_to_fiber_index_.end()) {
+      return functions_to_fiber_index_[fd->ToString()];
+    }
+    return default_fiber_;
+  }
+
+ private:
+  // Map from the concurrency group names to their corresponding fibers.
+  absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> name_to_fiber_index_;
+
+  // Map from the FunctionDescriptors to their corresponding fibers.
+  absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> functions_to_fiber_index_;
+
+  // The fiber for the default concurrency group. It's nullptr if its max
+  // concurrency is less than 1.
+  std::shared_ptr<FiberState> default_fiber_ = nullptr;
+};
+
class BoundedExecutor;

/// A manager that manages a set of thread pools, which will perform
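Each FiberState caps how many fibers of its group run at once. In asyncio terms, the effect of a group's max_concurrency is analogous to gating the group's tasks behind a semaphore, as in this illustrative Python sketch (not Ray code):

```python
import asyncio

async def run_group(name, max_concurrency, num_tasks):
    # One semaphore per concurrency group caps in-flight tasks, mirroring
    # the max_concurrency of that group's FiberState.
    sem = asyncio.Semaphore(max_concurrency)
    running = 0

    async def task(i):
        nonlocal running
        async with sem:
            running += 1
            assert running <= max_concurrency
            await asyncio.sleep(0.01)
            running -= 1

    await asyncio.gather(*(task(i) for i in range(num_tasks)))
    print("group {!r} never exceeded {} concurrent tasks".format(
        name, max_concurrency))

asyncio.run(run_group("io", 2, 10))
```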
@@ -489,6 +552,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
      instrumented_io_context &main_io_service, DependencyWaiter &waiter,
      std::shared_ptr<PoolManager> pool_manager = std::make_shared<PoolManager>(),
      bool is_asyncio = false, int fiber_max_concurrency = 1,
+      const std::vector<ConcurrencyGroup> &concurrency_groups = {},
      int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
      : reorder_wait_seconds_(reorder_wait_seconds),
        wait_timer_(main_io_service),
@@ -497,9 +561,16 @@ class ActorSchedulingQueue : public SchedulingQueue {
        pool_manager_(pool_manager),
        is_asyncio_(is_asyncio) {
    if (is_asyncio_) {
-      RAY_LOG(INFO) << "Setting actor as async with max_concurrency="
-                    << fiber_max_concurrency << ", creating new fiber thread.";
-      fiber_state_ = std::make_unique<FiberState>(fiber_max_concurrency);
+      std::stringstream ss;
+      ss << "Setting actor as asyncio with max_concurrency=" << fiber_max_concurrency
+         << ", and defined concurrency groups are:" << std::endl;
+      for (const auto &concurrency_group : concurrency_groups) {
+        ss << "\t" << concurrency_group.name << " : "
+           << concurrency_group.max_concurrency;
+      }
+      RAY_LOG(INFO) << ss.str();
+      fiber_state_manager_ =
+          std::make_unique<FiberStateManager>(concurrency_groups, fiber_max_concurrency);
    }
  }
@@ -595,7 +666,9 @@ class ActorSchedulingQueue : public SchedulingQueue {

    if (is_asyncio_) {
      // Process async actor task.
-      fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
+      auto fiber = fiber_state_manager_->GetFiber(request.ConcurrencyGroupName(),
+                                                  request.FunctionDescriptor());
+      fiber->EnqueueFiber([request]() mutable { request.Accept(); });
    } else {
      // Process actor tasks.
      RAY_CHECK(pool_manager_ != nullptr);
@@ -661,9 +734,10 @@ class ActorSchedulingQueue : public SchedulingQueue {
  /// Whether we should enqueue requests into asyncio pool. Setting this to true
  /// will instantiate all tasks as fibers that can be yielded.
  bool is_asyncio_ = false;
-  /// If is_asyncio_ is true, fiber_state_ contains the running state required
-  /// to enable continuation and work together with python asyncio.
-  std::unique_ptr<FiberState> fiber_state_;
+  /// Manages the running fiber states of actors in this worker. It works with
+  /// python asyncio if this is an asyncio actor.
+  std::unique_ptr<FiberStateManager> fiber_state_manager_;

  friend class SchedulingQueueTest;
};
@@ -822,6 +896,10 @@ class CoreWorkerDirectTaskReceiver {

  bool CancelQueuedNormalTask(TaskID task_id);

+ protected:
+  /// Cache of the concurrency groups of actors.
+  absl::flat_hash_map<ActorID, std::vector<ConcurrencyGroup>> concurrency_groups_cache_;
+
 private:
  // Worker context.
  WorkerContext &worker_context_;