[Core] Port concurrency groups with asyncio (#18567)

## Why are these changes needed?
This PR ports the concurrency groups feature to asyncio actors in Python.

### API
```python
@ray.remote(concurrency_groups={"io": 2, "compute": 4})
class AsyncActor:
    def __init__(self):
        pass

    @ray.method(concurrency_group="io")
    async def f1(self):
        pass

    @ray.method(concurrency_group="io")
    def f2(self):
        pass

    @ray.method(concurrency_group="compute")
    def f3(self):
        pass

    @ray.method(concurrency_group="compute")
    def f4(self):
        pass

    def f5(self):
        pass
```
The `concurrency_groups` argument on the `AsyncActor` class declares two concurrency groups, `io` and `compute`, with maximum concurrencies of 2 and 4. The actor also has a default concurrency group. Each concurrency group has its own asyncio event loop and its own Python thread, which execute the methods assigned to that group.

Methods `f1` and `f2` are invoked in the `io` concurrency group, and `f3` and `f4` in the `compute` group.
Note that `f5` and `__init__` are invoked in the default concurrency group.
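
As a rough illustration of how a decorator can record the declared group on each method, here is a minimal sketch. The names `method`, `collect_groups`, and `__concurrency_group__` are hypothetical stand-ins chosen for this example and are not Ray's API; the sketch only shows the general tagging pattern.

```python
def method(*, concurrency_group):
    """Hypothetical stand-in for a method decorator: tag a function with its group."""
    def annotate(func):
        # Store the declared group as an attribute on the function object.
        func.__concurrency_group__ = concurrency_group
        return func
    return annotate


def collect_groups(cls):
    """Map each tagged method name to its declared concurrency group."""
    return {
        name: attr.__concurrency_group__
        for name, attr in vars(cls).items()
        if hasattr(attr, "__concurrency_group__")
    }


class Example:
    @method(concurrency_group="io")
    async def f1(self):
        pass

    @method(concurrency_group="compute")
    def f3(self):
        pass

    def f5(self):  # untagged, so it would fall into the default group
        pass


declared = collect_groups(Example)
```

Untagged methods such as `f5` simply do not appear in the mapping, which is what lets them fall through to the default group.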

The following call invokes `f2` in the `compute` concurrency group instead of `io`, because a group specified at call time takes priority over the one declared on the method.
```python
a.f2.options(concurrency_group="compute").remote()
```
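
The lookup order described above can be sketched in plain Python. This is a simplified model, loosely following the `get_event_loop` logic in this PR: `resolve_group`, `DEFAULT_GROUP`, and the argument names are hypothetical, and the function returns a group name rather than an event loop.

```python
DEFAULT_GROUP = "default"  # hypothetical label for the default group


def resolve_group(method_name, declared, known_groups, runtime_group=None):
    """Pick the concurrency group for one call.

    declared: mapping of method name -> group declared on the method.
    known_groups: names of the groups defined on the actor class.
    runtime_group: group passed at call time, or None.
    """
    # The constructor always runs in the default group.
    if method_name == "__init__":
        return DEFAULT_GROUP
    # A known group given at call time wins over the declaration.
    if runtime_group is not None and runtime_group in known_groups:
        return runtime_group
    # Otherwise use the group declared on the method, if any.
    if method_name in declared:
        group = declared[method_name]
        if group not in known_groups:
            raise ValueError(
                f"{method_name} is assigned to undefined group {group!r}")
        return group
    return DEFAULT_GROUP


known = {"io", "compute"}
declared = {"f1": "io", "f2": "io", "f3": "compute", "f4": "compute"}
```

In this sketch an unknown call-time group is silently ignored and the declared group is used instead, which mirrors the fall-through in the diff.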

### Implementation
The main implementation changes are:
- Previously, an asyncio actor had a single event loop bound to a single Python thread. Now each concurrency group of the asyncio actor gets its own event loop bound to its own Python thread.
- Previously, an asyncio actor kept one fiber state per caller. Now it creates a `FiberStateManager` per caller, which manages the fiber states for the concurrency groups.
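
The first point, one event loop bound to one thread per group, can be sketched with the standard library alone. This is a minimal illustration and not the PR's Cython code: `start_group_loops`, `run_in_group`, and `current_thread_id` are hypothetical helpers, and concurrency limits and fiber states are left out.

```python
import asyncio
import threading


def start_group_loops(group_names):
    """Create one event loop and one daemon thread per group name."""
    loops = {}
    for name in group_names:
        loop = asyncio.new_event_loop()
        # Pass the bound method directly so each thread runs its own loop.
        thread = threading.Thread(
            target=loop.run_forever,
            name=f"AsyncIO Thread: {name}",
            daemon=True,
        )
        thread.start()
        loops[name] = (loop, thread)
    return loops


async def current_thread_id():
    """Report which thread is running this coroutine."""
    return threading.get_ident()


def run_in_group(loops, name, coroutine):
    """Schedule a coroutine on the group's loop and wait for its result."""
    loop, _ = loops[name]
    future = asyncio.run_coroutine_threadsafe(coroutine, loop)
    return future.result(timeout=5)


loops = start_group_loops(["default", "io", "compute"])
io_first = run_in_group(loops, "io", current_thread_id())
io_second = run_in_group(loops, "io", current_thread_id())
compute_id = run_in_group(loops, "compute", current_thread_id())

# Stop every loop from outside its thread, then wait for the threads to exit.
for loop, thread in loops.values():
    loop.call_soon_threadsafe(loop.stop)
    thread.join(timeout=5)
    loop.close()
```

Two calls sent to the same group run on the same thread, while calls sent to different groups run on different threads, which is the property the test in this PR checks through thread identifiers.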


## Related issue number
#16047
Qing Wang 2021-10-21 21:46:56 +08:00 committed by GitHub
parent a04b02e2e8
commit 048e7f7d5d
18 changed files with 623 additions and 67 deletions

View file

@ -129,7 +129,9 @@ Status TaskExecutor::ExecuteTask(
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results, std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes, std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error) { bool *is_application_level_error,
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
const std::string name_of_concurrency_group_to_execute) {
RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type); RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
auto function_descriptor = ray_function.GetFunctionDescriptor(); auto function_descriptor = ray_function.GetFunctionDescriptor();

View file

@ -82,7 +82,9 @@ class TaskExecutor {
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results, std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes, std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error); bool *is_application_level_error,
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
const std::string name_of_concurrency_group_to_execute);
virtual ~TaskExecutor(){}; virtual ~TaskExecutor(){};

View file

@ -0,0 +1,101 @@
Limiting Concurrency Per-Method with Concurrency Groups
=======================================================

Besides setting the max concurrency overall for an asyncio actor, Ray allows methods to be separated into *concurrency groups*, each with its own asyncio event loop. This allows you to limit the concurrency per-method, e.g., allow a health-check method to be given its own concurrency quota separate from request serving methods.

.. warning:: Concurrency groups are only supported for asyncio actors, not threaded actors.

.. _defining-concurrency-groups:

Defining Concurrency Groups
---------------------------

You can define concurrency groups for asyncio actors using the ``concurrency_groups`` decorator argument:

.. tabs::
  .. group-tab:: Python

    This defines two concurrency groups, "io" with max_concurrency=2 and
    "compute" with max_concurrency=4. The methods ``f1`` and ``f2`` are
    placed in the "io" group, and the methods ``f3`` and ``f4`` are placed
    into the "compute" group. Note that there is always a default
    concurrency group, which has a default concurrency of 1000.

    .. code-block:: python

        @ray.remote(concurrency_groups={"io": 2, "compute": 4})
        class AsyncIOActor:
            def __init__(self):
                pass

            @ray.method(concurrency_group="io")
            async def f1(self):
                pass

            @ray.method(concurrency_group="io")
            async def f2(self):
                pass

            @ray.method(concurrency_group="compute")
            async def f3(self):
                pass

            @ray.method(concurrency_group="compute")
            async def f4(self):
                pass

            async def f5(self):
                pass

        a = AsyncIOActor.remote()
        a.f1.remote()  # executed in the "io" group.
        a.f2.remote()  # executed in the "io" group.
        a.f3.remote()  # executed in the "compute" group.
        a.f4.remote()  # executed in the "compute" group.
        a.f5.remote()  # executed in the default group.

.. _default-concurrency-group:

Default Concurrency Group
-------------------------

By default, methods are placed in a default concurrency group which has a concurrency limit of 1000.
The concurrency of the default group can be changed by setting the ``max_concurrency`` actor option.

.. tabs::
  .. group-tab:: Python

    The following AsyncIOActor has 2 concurrency groups: "io" and "default".
    The max concurrency of "io" is 2, and the max concurrency of "default" is 10.

    .. code-block:: python

        @ray.remote(concurrency_groups={"io": 2})
        class AsyncIOActor:
            async def f1(self):
                pass

        actor = AsyncIOActor.options(max_concurrency=10).remote()

.. _setting-the-concurrency-group-at-runtime:

Setting the Concurrency Group at Runtime
----------------------------------------

You can also dispatch actor methods into a specific concurrency group at runtime using the ``.options`` method:

.. tabs::
  .. group-tab:: Python

    The following snippet demonstrates setting the concurrency group of the
    ``f2`` method dynamically at runtime.

    .. code-block:: python

        # Executed in the "io" group (as defined in the actor class).
        a.f2.options().remote()

        # Executed in the "compute" group.
        a.f2.options(concurrency_group="compute").remote()

View file

@ -14,6 +14,7 @@ Finally, we've also included some content on using core Ray APIs with `Tensorflo
actors.rst actors.rst
namespaces.rst namespaces.rst
async_api.rst async_api.rst
concurrency_group_api.rst
using-ray-with-gpus.rst using-ray-with-gpus.rst
serialization.rst serialization.rst
memory-management.rst memory-management.rst

View file

@ -17,6 +17,7 @@ from ray.includes.common cimport (
CBuffer, CBuffer,
CRayObject, CRayObject,
CAddress, CAddress,
CConcurrencyGroup,
) )
from ray.includes.libcoreworker cimport ( from ray.includes.libcoreworker cimport (
ActorHandleSharedPtr, ActorHandleSharedPtr,
@ -117,6 +118,11 @@ cdef class CoreWorker:
object current_runtime_env object current_runtime_env
c_bool is_local_mode c_bool is_local_mode
object cgname_to_eventloop_dict
object eventloop_for_default_cg
object thread_for_default_cg
object fd_to_cgname_dict
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
size_t data_size, ObjectRef object_ref, size_t data_size, ObjectRef object_ref,
c_vector[CObjectID] contained_ids, c_vector[CObjectID] contained_ids,
@ -130,6 +136,10 @@ cdef class CoreWorker:
c_vector[shared_ptr[CRayObject]] *returns) c_vector[shared_ptr[CRayObject]] *returns)
cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef yield_current_fiber(self, CFiberEvent &fiber_event)
cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle)
cdef c_function_descriptors_to_python(
self, const c_vector[CFunctionDescriptor] &c_function_descriptors)
cdef initialize_eventloops_for_actor_concurrency_group(
self, const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups)
cdef class FunctionDescriptor: cdef class FunctionDescriptor:
cdef: cdef:

View file

@ -60,6 +60,7 @@ from ray.includes.common cimport (
CRayFunction, CRayFunction,
CWorkerType, CWorkerType,
CJobConfig, CJobConfig,
CConcurrencyGroup,
move, move,
LANGUAGE_CPP, LANGUAGE_CPP,
LANGUAGE_JAVA, LANGUAGE_JAVA,
@ -316,6 +317,34 @@ cdef int prepare_resources(
resource_map[0][key.encode("ascii")] = float(value) resource_map[0][key.encode("ascii")] = float(value)
return 0 return 0
cdef c_vector[CFunctionDescriptor] prepare_function_descriptors(pyfd_list):
cdef:
c_vector[CFunctionDescriptor] fd_list
CRayFunction ray_function
for pyfd in pyfd_list:
fd_list.push_back(CFunctionDescriptorBuilder.BuildPython(
pyfd.module_name, pyfd.class_name, pyfd.function_name, b""))
return fd_list
cdef int prepare_actor_concurrency_groups(
dict concurrency_groups_dict,
c_vector[CConcurrencyGroup] *concurrency_groups):
cdef:
CConcurrencyGroup cg
c_vector[CFunctionDescriptor] c_fd_list
if concurrency_groups_dict is None:
raise ValueError("Must provide it...")
for key, value in concurrency_groups_dict.items():
c_fd_list = prepare_function_descriptors(value["function_descriptors"])
cg = CConcurrencyGroup(
key.encode("ascii"), value["max_concurrency"], c_fd_list)
concurrency_groups.push_back(cg)
return 1
cdef prepare_args( cdef prepare_args(
CoreWorker core_worker, CoreWorker core_worker,
@ -411,7 +440,11 @@ cdef execute_task(
const c_vector[CObjectID] &c_return_ids, const c_vector[CObjectID] &c_return_ids,
const c_string debugger_breakpoint, const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns, c_vector[shared_ptr[CRayObject]] *returns,
c_bool *is_application_level_error): c_bool *is_application_level_error,
# This parameter is only used for actor creation task to define
# the concurrency groups of this actor.
const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups,
const c_string c_name_of_concurrency_group_to_execute):
is_application_level_error[0] = False is_application_level_error[0] = False
@ -462,6 +495,11 @@ cdef execute_task(
print(actor_magic_token) print(actor_magic_token)
print(actor_magic_token, file=sys.stderr) print(actor_magic_token, file=sys.stderr)
# Initial eventloops for asyncio for this actor.
if core_worker.current_actor_is_asyncio():
core_worker.initialize_eventloops_for_actor_concurrency_group(
c_defined_concurrency_groups)
execution_info = execution_infos.get(function_descriptor) execution_info = execution_infos.get(function_descriptor)
if not execution_info: if not execution_info:
execution_info = manager.get_execution_info( execution_info = manager.get_execution_info(
@ -473,6 +511,8 @@ cdef execute_task(
b' "task_id": ' + task_id.hex().encode("ascii") + b'}') b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
task_name = name.decode("utf-8") task_name = name.decode("utf-8")
name_of_concurrency_group_to_execute = \
c_name_of_concurrency_group_to_execute.decode("ascii")
title = f"ray::{task_name}" title = f"ray::{task_name}"
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK: if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
@ -520,7 +560,9 @@ cdef execute_task(
async_function = sync_to_async(function) async_function = sync_to_async(function)
return core_worker.run_async_func_in_event_loop( return core_worker.run_async_func_in_event_loop(
async_function, actor, *arguments, **kwarguments) async_function, function_descriptor,
name_of_concurrency_group_to_execute, actor,
*arguments, **kwarguments)
return function(actor, *arguments, **kwarguments) return function(actor, *arguments, **kwarguments)
@ -546,7 +588,8 @@ cdef execute_task(
.deserialize_objects( .deserialize_objects(
metadata_pairs, object_refs)) metadata_pairs, object_refs))
args = core_worker.run_async_func_in_event_loop( args = core_worker.run_async_func_in_event_loop(
deserialize_args) deserialize_args, function_descriptor,
name_of_concurrency_group_to_execute)
else: else:
args = ray.worker.global_worker.deserialize_objects( args = ray.worker.global_worker.deserialize_objects(
metadata_pairs, object_refs) metadata_pairs, object_refs)
@ -692,7 +735,9 @@ cdef CRayStatus task_execution_handler(
const c_string debugger_breakpoint, const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns, c_vector[shared_ptr[CRayObject]] *returns,
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes, shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
c_bool *is_application_level_error) nogil: c_bool *is_application_level_error,
const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
const c_string name_of_concurrency_group_to_execute) nogil:
with gil, disable_client_hook(): with gil, disable_client_hook():
try: try:
try: try:
@ -701,7 +746,9 @@ cdef CRayStatus task_execution_handler(
execute_task(task_type, task_name, ray_function, c_resources, execute_task(task_type, task_name, ray_function, c_resources,
c_args, c_arg_refs, c_return_ids, c_args, c_arg_refs, c_return_ids,
debugger_breakpoint, returns, debugger_breakpoint, returns,
is_application_level_error) is_application_level_error,
defined_concurrency_groups,
name_of_concurrency_group_to_execute)
except Exception as e: except Exception as e:
sys_exit = SystemExit() sys_exit = SystemExit()
if isinstance(e, RayActorError) and \ if isinstance(e, RayActorError) and \
@ -1020,6 +1067,10 @@ cdef class CoreWorker:
options.startup_token = startup_token options.startup_token = startup_token
CCoreWorkerProcess.Initialize(options) CCoreWorkerProcess.Initialize(options)
self.cgname_to_eventloop_dict = None
self.fd_to_cgname_dict = None
self.eventloop_for_default_cg = None
def shutdown(self): def shutdown(self):
with nogil: with nogil:
# If it's a worker, the core worker process should have been # If it's a worker, the core worker process should have been
@ -1414,6 +1465,7 @@ cdef class CoreWorker:
c_string extension_data, c_string extension_data,
c_string serialized_runtime_env, c_string serialized_runtime_env,
runtime_env_uris, runtime_env_uris,
concurrency_groups_dict,
): ):
cdef: cdef:
CRayFunction ray_function CRayFunction ray_function
@ -1425,6 +1477,7 @@ cdef class CoreWorker:
CPlacementGroupID c_placement_group_id = \ CPlacementGroupID c_placement_group_id = \
placement_group_id.native() placement_group_id.native()
c_vector[c_string] c_runtime_env_uris = runtime_env_uris c_vector[c_string] c_runtime_env_uris = runtime_env_uris
c_vector[CConcurrencyGroup] c_concurrency_groups
with self.profile_event(b"submit_task"): with self.profile_event(b"submit_task"):
prepare_resources(resources, &c_resources) prepare_resources(resources, &c_resources)
@ -1432,6 +1485,8 @@ cdef class CoreWorker:
ray_function = CRayFunction( ray_function = CRayFunction(
language.lang, function_descriptor.descriptor) language.lang, function_descriptor.descriptor)
prepare_args(self, language, args, &args_vector) prepare_args(self, language, args, &args_vector)
prepare_actor_concurrency_groups(
concurrency_groups_dict, &c_concurrency_groups)
with nogil: with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor( check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
@ -1447,7 +1502,8 @@ cdef class CoreWorker:
placement_group_bundle_index), placement_group_bundle_index),
placement_group_capture_child_tasks, placement_group_capture_child_tasks,
serialized_runtime_env, serialized_runtime_env,
c_runtime_env_uris), c_runtime_env_uris,
c_concurrency_groups),
extension_data, extension_data,
&c_actor_id)) &c_actor_id))
@ -1814,32 +1870,96 @@ cdef class CoreWorker:
CCoreWorkerProcess.GetCoreWorker().SealReturnObject( CCoreWorkerProcess.GetCoreWorker().SealReturnObject(
return_id, returns[0][i])) return_id, returns[0][i]))
def create_or_get_event_loop(self): cdef c_function_descriptors_to_python(
if self.async_event_loop is None: self,
self.async_event_loop = get_new_event_loop() const c_vector[CFunctionDescriptor] &c_function_descriptors):
asyncio.set_event_loop(self.async_event_loop)
if self.async_thread is None: ret = []
self.async_thread = threading.Thread( for i in range(c_function_descriptors.size()):
target=lambda: self.async_event_loop.run_forever(), ret.append(CFunctionDescriptorToPython(c_function_descriptors[i]))
name="AsyncIO Thread" return ret
cdef initialize_eventloops_for_actor_concurrency_group(
self,
const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups):
cdef:
CConcurrencyGroup c_concurrency_group
c_vector[CFunctionDescriptor] c_function_descriptors
self.cgname_to_eventloop_dict = {}
self.fd_to_cgname_dict = {}
self.eventloop_for_default_cg = get_new_event_loop()
self.thread_for_default_cg = threading.Thread(
target=lambda: self.eventloop_for_default_cg.run_forever(),
name="AsyncIO Thread: default"
)
# Making the thread as daemon to let it exit
# when the main thread exits.
self.thread_for_default_cg.daemon = True
self.thread_for_default_cg.start()
for i in range(c_defined_concurrency_groups.size()):
c_concurrency_group = c_defined_concurrency_groups[i]
cg_name = c_concurrency_group.GetName().decode("ascii")
function_descriptors = self.c_function_descriptors_to_python(
c_concurrency_group.GetFunctionDescriptors())
async_eventloop = get_new_event_loop()
async_thread = threading.Thread(
target=lambda: async_eventloop.run_forever(),
name="AsyncIO Thread: {}".format(cg_name)
) )
# Making the thread a daemon causes it to exit # Making the thread a daemon causes it to exit
# when the main thread exits. # when the main thread exits.
self.async_thread.daemon = True async_thread.daemon = True
self.async_thread.start() async_thread.start()
return self.async_event_loop self.cgname_to_eventloop_dict[cg_name] = {
"eventloop": async_eventloop,
"thread": async_thread,
}
for fd in function_descriptors:
self.fd_to_cgname_dict[fd] = cg_name
def get_event_loop(self, function_descriptor, specified_cgname):
# __init__ will be invoked in default eventloop
if function_descriptor.function_name == "__init__":
return self.eventloop_for_default_cg, self.thread_for_default_cg
if specified_cgname is not None:
if specified_cgname in self.cgname_to_eventloop_dict:
this_group = self.cgname_to_eventloop_dict[specified_cgname]
return (this_group["eventloop"], this_group["thread"])
if function_descriptor in self.fd_to_cgname_dict:
curr_cgname = self.fd_to_cgname_dict[function_descriptor]
if curr_cgname in self.cgname_to_eventloop_dict:
return (
self.cgname_to_eventloop_dict[curr_cgname]["eventloop"],
self.cgname_to_eventloop_dict[curr_cgname]["thread"])
else:
raise ValueError(
"The function {} is defined to be executed "
"in the concurrency group {} . But there is no this group."
.format(function_descriptor, curr_cgname))
return self.eventloop_for_default_cg, self.thread_for_default_cg
def run_async_func_in_event_loop(
self, func, function_descriptor, specified_cgname, *args, **kwargs):
def run_async_func_in_event_loop(self, func, *args, **kwargs):
cdef: cdef:
CFiberEvent event CFiberEvent event
loop = self.create_or_get_event_loop() eventloop, async_thread = self.get_event_loop(
function_descriptor, specified_cgname)
coroutine = func(*args, **kwargs) coroutine = func(*args, **kwargs)
if threading.get_ident() == self.async_thread.ident: if threading.get_ident() == async_thread.ident:
future = asyncio.ensure_future(coroutine, loop) future = asyncio.ensure_future(coroutine, eventloop)
else: else:
future = asyncio.run_coroutine_threadsafe(coroutine, loop) future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
future.add_done_callback(lambda _: event.Notify()) future.add_done_callback(lambda _: event.Notify())
with nogil: with nogil:
(CCoreWorkerProcess.GetCoreWorker() (CCoreWorkerProcess.GetCoreWorker()

View file

@ -54,11 +54,14 @@ def method(*args, **kwargs):
""" """
assert len(args) == 0 assert len(args) == 0
assert len(kwargs) == 1 assert len(kwargs) == 1
assert "num_returns" in kwargs
num_returns = kwargs["num_returns"] assert "num_returns" in kwargs or "concurrency_group" in kwargs
def annotate_method(method): def annotate_method(method):
method.__ray_num_returns__ = num_returns if "num_returns" in kwargs:
method.__ray_num_returns__ = kwargs["num_returns"]
if "concurrency_group" in kwargs:
method.__ray_concurrency_group__ = kwargs["concurrency_group"]
return method return method
return annotate_method return annotate_method
@ -138,7 +141,12 @@ class ActorMethod:
return FuncWrapper() return FuncWrapper()
@_tracing_actor_method_invocation @_tracing_actor_method_invocation
def _remote(self, args=None, kwargs=None, name="", num_returns=None): def _remote(self,
args=None,
kwargs=None,
name="",
num_returns=None,
concurrency_group=None):
if num_returns is None: if num_returns is None:
num_returns = self._num_returns num_returns = self._num_returns
@ -222,6 +230,8 @@ class ActorClassMethodMetadata(object):
self.decorators = {} self.decorators = {}
self.signatures = {} self.signatures = {}
self.num_returns = {} self.num_returns = {}
self.concurrency_group_for_methods = {}
for method_name, method in actor_methods: for method_name, method in actor_methods:
# Whether or not this method requires binding of its first # Whether or not this method requires binding of its first
# argument. For class and static methods, we do not want to bind # argument. For class and static methods, we do not want to bind
@ -247,6 +257,10 @@ class ActorClassMethodMetadata(object):
self.decorators[method_name] = ( self.decorators[method_name] = (
method.__ray_invocation_decorator__) method.__ray_invocation_decorator__)
if hasattr(method, "__ray_concurrency_group__"):
self.concurrency_group_for_methods[method_name] = (
method.__ray_concurrency_group__)
# Update cache. # Update cache.
cls._cache[actor_creation_function_descriptor] = self cls._cache[actor_creation_function_descriptor] = self
return self return self
@ -285,8 +299,8 @@ class ActorClassMetadata:
def __init__(self, language, modified_class, def __init__(self, language, modified_class,
actor_creation_function_descriptor, class_id, max_restarts, actor_creation_function_descriptor, class_id, max_restarts,
max_task_retries, num_cpus, num_gpus, memory, max_task_retries, num_cpus, num_gpus, memory,
object_store_memory, resources, accelerator_type, object_store_memory, resources, accelerator_type, runtime_env,
runtime_env): concurrency_groups):
self.language = language self.language = language
self.modified_class = modified_class self.modified_class = modified_class
self.actor_creation_function_descriptor = \ self.actor_creation_function_descriptor = \
@ -303,6 +317,7 @@ class ActorClassMetadata:
self.resources = resources self.resources = resources
self.accelerator_type = accelerator_type self.accelerator_type = accelerator_type
self.runtime_env = runtime_env self.runtime_env = runtime_env
self.concurrency_groups = concurrency_groups
self.last_export_session_and_job = None self.last_export_session_and_job = None
self.method_meta = ActorClassMethodMetadata.create( self.method_meta = ActorClassMethodMetadata.create(
modified_class, actor_creation_function_descriptor) modified_class, actor_creation_function_descriptor)
@ -358,10 +373,10 @@ class ActorClass:
f"use '{self.__ray_metadata__.class_name}.remote()'.") f"use '{self.__ray_metadata__.class_name}.remote()'.")
@classmethod @classmethod
def _ray_from_modified_class(cls, modified_class, class_id, max_restarts, def _ray_from_modified_class(
max_task_retries, num_cpus, num_gpus, memory, cls, modified_class, class_id, max_restarts, max_task_retries,
object_store_memory, resources, num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type, runtime_env): accelerator_type, runtime_env, concurrency_groups):
for attribute in [ for attribute in [
"remote", "remote",
"_remote", "_remote",
@ -398,7 +413,7 @@ class ActorClass:
Language.PYTHON, modified_class, Language.PYTHON, modified_class,
actor_creation_function_descriptor, class_id, max_restarts, actor_creation_function_descriptor, class_id, max_restarts,
max_task_retries, num_cpus, num_gpus, memory, object_store_memory, max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
resources, accelerator_type, new_runtime_env) resources, accelerator_type, new_runtime_env, concurrency_groups)
return self return self
@ -413,10 +428,12 @@ class ActorClass:
# .remote(), it would get run in the Ray Client server, which runs on # .remote(), it would get run in the Ray Client server, which runs on
# a remote node where the files aren't available. # a remote node where the files aren't available.
new_runtime_env = ParsedRuntimeEnv(runtime_env or {}) new_runtime_env = ParsedRuntimeEnv(runtime_env or {})
self.__ray_metadata__ = ActorClassMetadata( self.__ray_metadata__ = ActorClassMetadata(
language, None, actor_creation_function_descriptor, None, language, None, actor_creation_function_descriptor, None,
max_restarts, max_task_retries, num_cpus, num_gpus, memory, max_restarts, max_task_retries, num_cpus, num_gpus, memory,
object_store_memory, resources, accelerator_type, new_runtime_env) object_store_memory, resources, accelerator_type, new_runtime_env,
[])
return self return self
@ -740,6 +757,25 @@ class ActorClass:
parsed_runtime_env = override_task_or_actor_runtime_env( parsed_runtime_env = override_task_or_actor_runtime_env(
runtime_env, parent_runtime_env) runtime_env, parent_runtime_env)
concurrency_groups_dict = {}
for cg_name in meta.concurrency_groups:
concurrency_groups_dict[cg_name] = {
"name": cg_name,
"max_concurrency": meta.concurrency_groups[cg_name],
"function_descriptors": [],
}
# Update methods
for method_name in meta.method_meta.concurrency_group_for_methods:
cg_name = meta.method_meta.concurrency_group_for_methods[
method_name]
assert cg_name in concurrency_groups_dict
module_name = meta.actor_creation_function_descriptor.module_name
class_name = meta.actor_creation_function_descriptor.class_name
concurrency_groups_dict[cg_name]["function_descriptors"].append(
PythonFunctionDescriptor(module_name, method_name, class_name))
actor_id = worker.core_worker.create_actor( actor_id = worker.core_worker.create_actor(
meta.language, meta.language,
meta.actor_creation_function_descriptor, meta.actor_creation_function_descriptor,
@ -759,7 +795,8 @@ class ActorClass:
# Store actor_method_cpu in actor handle's extension data. # Store actor_method_cpu in actor handle's extension data.
extension_data=str(actor_method_cpu), extension_data=str(actor_method_cpu),
serialized_runtime_env=parsed_runtime_env.serialize(), serialized_runtime_env=parsed_runtime_env.serialize(),
runtime_env_uris=parsed_runtime_env.get_uris()) runtime_env_uris=parsed_runtime_env.get_uris(),
concurrency_groups_dict=concurrency_groups_dict or dict())
actor_handle = ActorHandle( actor_handle = ActorHandle(
meta.language, meta.language,
@ -1060,7 +1097,8 @@ def modify_class(cls):
def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources, def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type, max_restarts, max_task_retries, runtime_env): accelerator_type, max_restarts, max_task_retries, runtime_env,
concurrency_groups):
Class = modify_class(cls) Class = modify_class(cls)
_inject_tracing_into_class(Class) _inject_tracing_into_class(Class)
@ -1068,6 +1106,8 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
max_restarts = 0 max_restarts = 0
if max_task_retries is None: if max_task_retries is None:
max_task_retries = 0 max_task_retries = 0
if concurrency_groups is None:
concurrency_groups = []
infinite_restart = max_restarts == -1 infinite_restart = max_restarts == -1
if not infinite_restart: if not infinite_restart:
@ -1086,7 +1126,7 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
return ActorClass._ray_from_modified_class( return ActorClass._ray_from_modified_class(
Class, ActorClassID.from_random(), max_restarts, max_task_retries, Class, ActorClassID.from_random(), max_restarts, max_task_retries,
num_cpus, num_gpus, memory, object_store_memory, resources, num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type, runtime_env) accelerator_type, runtime_env, concurrency_groups)
def exit_actor(): def exit_actor():

View file

@ -2,7 +2,7 @@ from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.memory cimport shared_ptr, unique_ptr
from libcpp.string cimport string as c_string from libcpp.string cimport string as c_string
from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t, uint32_t
from libcpp.unordered_map cimport unordered_map from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector as c_vector from libcpp.vector cimport vector as c_vector
from libcpp.pair cimport pair as c_pair from libcpp.pair cimport pair as c_pair
@ -255,7 +255,8 @@ cdef extern from "ray/core_worker/common.h" nogil:
c_pair[CPlacementGroupID, int64_t] placement_options, c_pair[CPlacementGroupID, int64_t] placement_options,
c_bool placement_group_capture_child_tasks, c_bool placement_group_capture_child_tasks,
c_string serialized_runtime_env, c_string serialized_runtime_env,
c_vector[c_string] runtime_env_uris) c_vector[c_string] runtime_env_uris,
const c_vector[CConcurrencyGroup] &concurrency_groups)
cdef cppclass CPlacementGroupCreationOptions \ cdef cppclass CPlacementGroupCreationOptions \
"ray::core::PlacementGroupCreationOptions": "ray::core::PlacementGroupCreationOptions":
@ -283,3 +284,14 @@ cdef extern from "ray/gcs/gcs_client.h" nogil:
cdef extern from "src/ray/protobuf/gcs.pb.h" nogil: cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
cdef cppclass CJobConfig "ray::rpc::JobConfig": cdef cppclass CJobConfig "ray::rpc::JobConfig":
const c_string &SerializeAsString() const c_string &SerializeAsString()
cdef extern from "ray/common/task/task_spec.h" nogil:
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
CConcurrencyGroup(
const c_string &name,
uint32_t max_concurrency,
const c_vector[CFunctionDescriptor] &c_fds)
CConcurrencyGroup()
c_string GetName() const
uint32_t GetMaxConcurrency() const
c_vector[CFunctionDescriptor] GetFunctionDescriptors() const

View file

@ -43,6 +43,7 @@ from ray.includes.common cimport (
CGcsClientOptions, CGcsClientOptions,
LocalMemoryBuffer, LocalMemoryBuffer,
CJobConfig, CJobConfig,
CConcurrencyGroup,
) )
from ray.includes.function_descriptor cimport ( from ray.includes.function_descriptor cimport (
CFunctionDescriptor, CFunctionDescriptor,
@ -284,7 +285,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_vector[shared_ptr[CRayObject]] *returns, c_vector[shared_ptr[CRayObject]] *returns,
shared_ptr[LocalMemoryBuffer] shared_ptr[LocalMemoryBuffer]
&creation_task_exception_pb_bytes, &creation_task_exception_pb_bytes,
c_bool *is_application_level_error) nogil c_bool *is_application_level_error,
const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
const c_string name_of_concurrency_group_to_execute) nogil
) task_execution_callback ) task_execution_callback
(void(const CWorkerID &) nogil) on_worker_shutdown (void(const CWorkerID &) nogil) on_worker_shutdown
(CRayStatus() nogil) check_signals (CRayStatus() nogil) check_signals

View file

@ -0,0 +1,114 @@
# coding: utf-8
import asyncio
import sys
import threading
import pytest
import ray
# This tests the methods are executed in the correct eventloop.
def test_basic():
@ray.remote(concurrency_groups={"io": 2, "compute": 4})
class AsyncActor:
def __init__(self):
self.eventloop_f1 = None
self.eventloop_f2 = None
self.eventloop_f3 = None
self.eventloop_f4 = None
self.default_eventloop = asyncio.get_event_loop()
@ray.method(concurrency_group="io")
async def f1(self):
self.eventloop_f1 = asyncio.get_event_loop()
return threading.current_thread().ident
@ray.method(concurrency_group="io")
def f2(self):
self.eventloop_f2 = asyncio.get_event_loop()
return threading.current_thread().ident
@ray.method(concurrency_group="compute")
def f3(self):
self.eventloop_f3 = asyncio.get_event_loop()
return threading.current_thread().ident
@ray.method(concurrency_group="compute")
def f4(self):
self.eventloop_f4 = asyncio.get_event_loop()
return threading.current_thread().ident
def f5(self):
# Check that this method is executed in the default event loop.
assert asyncio.get_event_loop() == self.default_eventloop
return threading.current_thread().ident
@ray.method(concurrency_group="io")
def do_assert(self):
if self.eventloop_f1 != self.eventloop_f2:
return False
if self.eventloop_f3 != self.eventloop_f4:
return False
if self.eventloop_f1 == self.eventloop_f3:
return False
if self.eventloop_f1 == self.eventloop_f4:
return False
return True
###############################################
a = AsyncActor.remote()
f1_thread_id = ray.get(a.f1.remote()) # executed in the "io" group.
f2_thread_id = ray.get(a.f2.remote()) # executed in the "io" group.
f3_thread_id = ray.get(a.f3.remote()) # executed in the "compute" group.
f4_thread_id = ray.get(a.f4.remote()) # executed in the "compute" group.
assert f1_thread_id == f2_thread_id
assert f3_thread_id == f4_thread_id
assert f1_thread_id != f3_thread_id
assert ray.get(a.do_assert.remote())
assert ray.get(a.f5.remote()) # executed in the default group.
# The concurrency group can also be specified at call time.
# This task will be invoked in the `compute` concurrency group.
a.f2.options(concurrency_group="compute").remote()
# This case tests that an asyncio countdown (several calls waiting on one
# asyncio.Event) works within a single concurrency group.
def test_async_methods_in_concurrency_group():
@ray.remote(concurrency_groups={"async": 3})
class AsyncBatcher:
def __init__(self):
self.batch = []
self.event = None
@ray.method(concurrency_group="async")
def init_event(self):
self.event = asyncio.Event()
return True
@ray.method(concurrency_group="async")
async def add(self, x):
self.batch.append(x)
if len(self.batch) >= 3:
self.event.set()
else:
await self.event.wait()
return sorted(self.batch)
a = AsyncBatcher.remote()
ray.get(a.init_event.remote())
x1 = a.add.remote(1)
x2 = a.add.remote(2)
x3 = a.add.remote(3)
r1 = ray.get(x1)
r2 = ray.get(x2)
r3 = ray.get(x3)
assert r1 == [1, 2, 3]
assert r1 == r2 == r3
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
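
The thread-id and event-loop checks in `test_basic` rely on each concurrency group owning one event loop bound to one thread. The following is a minimal, Ray-free sketch of that arrangement using only the standard library. `GroupLoop`, `where_am_i`, and `demo` are hypothetical names introduced for illustration; this is not how Ray wires its loops internally.

```python
import asyncio
import threading


class GroupLoop:
    """Hypothetical helper: one event loop running on one dedicated thread."""

    def __init__(self, name):
        self.name = name
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run, name=name, daemon=True)
        self.thread.start()

    def _run(self):
        # Bind the loop to this thread and serve it until stop() is called.
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def submit(self, coro):
        # Schedule the coroutine on this group's loop from any thread.
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

    def stop(self):
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        self.loop.close()


async def where_am_i():
    # Report the thread and the loop that actually run this coroutine.
    return threading.current_thread().ident, id(asyncio.get_running_loop())


def demo():
    io, compute = GroupLoop("io"), GroupLoop("compute")
    try:
        f1 = io.submit(where_am_i()).result(timeout=5)
        f2 = io.submit(where_am_i()).result(timeout=5)
        f3 = compute.submit(where_am_i()).result(timeout=5)
    finally:
        io.stop()
        compute.stop()
    return f1, f2, f3
```

In this sketch, two coroutines submitted to the same group report the same thread and loop, while a coroutine submitted to another group reports different ones, which mirrors the assertions on `f1`, `f2`, and `f3` in the test.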

@ -1928,7 +1928,8 @@ def make_decorator(num_returns=None,
runtime_env=None, runtime_env=None,
placement_group="default", placement_group="default",
worker=None, worker=None,
retry_exceptions=None): retry_exceptions=None,
concurrency_groups=None):
def decorator(function_or_class): def decorator(function_or_class):
if (inspect.isfunction(function_or_class) if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)): or is_cython(function_or_class)):
@ -1983,10 +1984,10 @@ def make_decorator(num_returns=None,
raise ValueError( raise ValueError(
"The keyword 'max_task_retries' only accepts -1, 0 or a" "The keyword 'max_task_retries' only accepts -1, 0 or a"
" positive integer") " positive integer")
return ray.actor.make_actor(function_or_class, num_cpus, num_gpus, return ray.actor.make_actor(
memory, object_store_memory, resources, function_or_class, num_cpus, num_gpus, memory,
accelerator_type, max_restarts, object_store_memory, resources, accelerator_type, max_restarts,
max_task_retries, runtime_env) max_task_retries, runtime_env, concurrency_groups)
raise TypeError("The @ray.remote decorator must be applied to " raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.") "either a function or to a class.")
@ -2105,10 +2106,21 @@ def remote(*args, **kwargs):
# Parse the keyword arguments from the decorator. # Parse the keyword arguments from the decorator.
valid_kwargs = [ valid_kwargs = [
"num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory", "num_returns",
"resources", "accelerator_type", "max_calls", "max_restarts", "num_cpus",
"max_task_retries", "max_retries", "runtime_env", "retry_exceptions", "num_gpus",
"placement_group" "memory",
"object_store_memory",
"resources",
"accelerator_type",
"max_calls",
"max_restarts",
"max_task_retries",
"max_retries",
"runtime_env",
"retry_exceptions",
"placement_group",
"concurrency_groups",
] ]
error_string = ("The @ray.remote decorator must be applied either " error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example " "with no arguments and no parentheses, for example "
@ -2143,6 +2155,7 @@ def remote(*args, **kwargs):
runtime_env = kwargs.get("runtime_env") runtime_env = kwargs.get("runtime_env")
placement_group = kwargs.get("placement_group", "default") placement_group = kwargs.get("placement_group", "default")
retry_exceptions = kwargs.get("retry_exceptions") retry_exceptions = kwargs.get("retry_exceptions")
concurrency_groups = kwargs.get("concurrency_groups")
return make_decorator( return make_decorator(
num_returns=num_returns, num_returns=num_returns,
@ -2159,4 +2172,5 @@ def remote(*args, **kwargs):
runtime_env=runtime_env, runtime_env=runtime_env,
placement_group=placement_group, placement_group=placement_group,
worker=worker, worker=worker,
retry_exceptions=retry_exceptions) retry_exceptions=retry_exceptions,
concurrency_groups=concurrency_groups or [])

@ -43,6 +43,20 @@ struct ConcurrencyGroup {
uint32_t max_concurrency; uint32_t max_concurrency;
// Function descriptors of the actor methods in this group. // Function descriptors of the actor methods in this group.
std::vector<ray::FunctionDescriptor> function_descriptors; std::vector<ray::FunctionDescriptor> function_descriptors;
ConcurrencyGroup() = default;
ConcurrencyGroup(const std::string &name, uint32_t max_concurrency,
const std::vector<ray::FunctionDescriptor> &fds)
: name(name), max_concurrency(max_concurrency), function_descriptors(fds) {}
std::string GetName() const { return name; }
uint32_t GetMaxConcurrency() const { return max_concurrency; }
std::vector<ray::FunctionDescriptor> GetFunctionDescriptors() const {
return function_descriptors;
}
}; };
static inline rpc::ObjectReference GetReferenceForActorDummyObject( static inline rpc::ObjectReference GetReferenceForActorDummyObject(

@ -2273,11 +2273,21 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID()); CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());
std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr; std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr;
std::vector<ConcurrencyGroup> defined_concurrency_groups = {};
std::string name_of_concurrency_group_to_execute;
if (task_spec.IsActorCreationTask()) {
defined_concurrency_groups = task_spec.ConcurrencyGroups();
} else if (task_spec.IsActorTask()) {
name_of_concurrency_group_to_execute = task_spec.ConcurrencyGroupName();
}
status = options_.task_execution_callback( status = options_.task_execution_callback(
task_type, task_spec.GetName(), func, task_type, task_spec.GetName(), func,
task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs, task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
return_ids, task_spec.GetDebuggerBreakpoint(), return_objects, return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
creation_task_exception_pb_bytes, is_application_level_error); creation_task_exception_pb_bytes, is_application_level_error,
defined_concurrency_groups, name_of_concurrency_group_to_execute);
// Get the reference counts for any IDs that we borrowed during this task, // Get the reference counts for any IDs that we borrowed during this task,
// remove the local reference for these IDs, and return the ref count info to // remove the local reference for these IDs, and return the ref count info to

@ -71,7 +71,15 @@ struct CoreWorkerOptions {
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<RayObject>> *results, std::vector<std::shared_ptr<RayObject>> *results,
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes, std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error)>; bool *is_application_level_error,
// The following 2 parameters `defined_concurrency_groups` and
// `name_of_concurrency_group_to_execute` are used for Python
// asyncio actors only.
//
// Defined concurrency groups of this actor. Note this is only
// used for the actor creation task.
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
// Name of the concurrency group to execute this task in. Note this
// is only set for actor tasks.
const std::string name_of_concurrency_group_to_execute)>;
CoreWorkerOptions() CoreWorkerOptions()
: store_socket(""), : store_socket(""),

@ -101,7 +101,13 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<RayObject>> *results, std::vector<std::shared_ptr<RayObject>> *results,
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb, std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
bool *is_application_level_error) { bool *is_application_level_error,
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
const std::string name_of_concurrency_group_to_execute) {
// These 2 parameters are used for Python only; the Java worker
// does not use them.
RAY_UNUSED(defined_concurrency_groups);
RAY_UNUSED(name_of_concurrency_group_to_execute);
// TODO(jjyao): Support retrying application-level errors for Java // TODO(jjyao): Support retrying application-level errors for Java
*is_application_level_error = false; *is_application_level_error = false;

@ -509,6 +509,21 @@ class MockWorkerContext : public WorkerContext {
} }
}; };
class MockCoreWorkerDirectTaskReceiver : public CoreWorkerDirectTaskReceiver {
public:
MockCoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
instrumented_io_context &main_io_service,
const TaskHandler &task_handler,
const OnTaskDone &task_done)
: CoreWorkerDirectTaskReceiver(worker_context, main_io_service, task_handler,
task_done) {}
void UpdateConcurrencyGroupsCache(const ActorID &actor_id,
const std::vector<ConcurrencyGroup> &cgs) {
concurrency_groups_cache_[actor_id] = cgs;
}
};
class DirectActorReceiverTest : public ::testing::Test { class DirectActorReceiverTest : public ::testing::Test {
public: public:
DirectActorReceiverTest() DirectActorReceiverTest()
@ -518,7 +533,7 @@ class DirectActorReceiverTest : public ::testing::Test {
auto execute_task = auto execute_task =
std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1, std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>( receiver_ = std::make_unique<MockCoreWorkerDirectTaskReceiver>(
worker_context_, main_io_service_, execute_task, [] { return Status::OK(); }); worker_context_, main_io_service_, execute_task, [] { return Status::OK(); });
receiver_->Init(std::make_shared<rpc::CoreWorkerClientPool>( receiver_->Init(std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &addr) { return worker_client_; }), [&](const rpc::Address &addr) { return worker_client_; }),
@ -541,7 +556,7 @@ class DirectActorReceiverTest : public ::testing::Test {
main_io_service_.stop(); main_io_service_.stop();
} }
std::unique_ptr<CoreWorkerDirectTaskReceiver> receiver_; std::unique_ptr<MockCoreWorkerDirectTaskReceiver> receiver_;
private: private:
rpc::Address rpc_address_; rpc::Address rpc_address_;
@ -575,6 +590,7 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) {
++callback_count; ++callback_count;
ASSERT_TRUE(status.ok()); ASSERT_TRUE(status.ok());
}; };
receiver_->UpdateConcurrencyGroupsCache(actor_id, {});
receiver_->HandleTask(request, &reply, reply_callback); receiver_->HandleTask(request, &reply, reply_callback);
} }

@ -524,6 +524,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
return_object->add_nested_inlined_refs()->CopyFrom(nested_ref); return_object->add_nested_inlined_refs()->CopyFrom(nested_ref);
} }
} }
if (task_spec.IsActorCreationTask()) { if (task_spec.IsActorCreationTask()) {
/// The default max concurrency for creating PoolManager should /// The default max concurrency for creating PoolManager should
/// be 0 if this is an asyncio actor. /// be 0 if this is an asyncio actor.
@ -531,6 +532,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency(); task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency();
pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(), pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(),
default_max_concurrency); default_max_concurrency);
concurrency_groups_cache_[task_spec.TaskId().ActorId()] =
task_spec.ConcurrencyGroups();
RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId() RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId()
<< ", actor_id: " << task_spec.ActorCreationId(); << ", actor_id: " << task_spec.ActorCreationId();
// Tell raylet that an actor creation task has finished execution, so that // Tell raylet that an actor creation task has finished execution, so that
@ -573,11 +576,13 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
if (task_spec.IsActorTask()) { if (task_spec.IsActorTask()) {
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId()); auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
if (it == actor_scheduling_queues_.end()) { if (it == actor_scheduling_queues_.end()) {
auto cg_it = concurrency_groups_cache_.find(task_spec.ActorId());
RAY_CHECK(cg_it != concurrency_groups_cache_.end());
auto result = actor_scheduling_queues_.emplace( auto result = actor_scheduling_queues_.emplace(
task_spec.CallerWorkerId(), task_spec.CallerWorkerId(),
std::unique_ptr<SchedulingQueue>( std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
new ActorSchedulingQueue(task_main_io_service_, *waiter_, pool_manager_, task_main_io_service_, *waiter_, pool_manager_, is_asyncio_,
is_asyncio_, fiber_max_concurrency_))); fiber_max_concurrency_, cg_it->second)));
it = result.first; it = result.first;
} }
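
The two hunks above cache an actor's concurrency groups when its creation task is handled, then read that cache when a scheduling queue is first created for a caller. A rough Python model of that flow is sketched below. `ReceiverModel` and its methods are hypothetical stand-ins, and the `RuntimeError` only approximates the `RAY_CHECK` in the C++ code, which aborts the process instead.

```python
class ReceiverModel:
    """Hypothetical model of the cache-then-lookup flow in HandleTask."""

    def __init__(self):
        self.concurrency_groups_cache = {}  # actor_id -> list of groups
        self.scheduling_queues = {}         # caller_id -> queue description

    def handle_actor_creation(self, actor_id, groups):
        # Remember the groups declared on the actor class.
        self.concurrency_groups_cache[actor_id] = list(groups)

    def handle_actor_task(self, actor_id, caller_id):
        queue = self.scheduling_queues.get(caller_id)
        if queue is None:
            # The first task from this caller creates its queue lazily,
            # so the actor's groups must already be cached.
            if actor_id not in self.concurrency_groups_cache:
                raise RuntimeError(f"no cached concurrency groups for {actor_id}")
            queue = {
                "caller": caller_id,
                "groups": self.concurrency_groups_cache[actor_id],
            }
            self.scheduling_queues[caller_id] = queue
        return queue
```

Under this model, an actor task that arrives without a prior creation task fails the lookup, which is presumably why the unit test seeds the cache through `UpdateConcurrencyGroupsCache` before calling `HandleTask`.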

@ -283,6 +283,69 @@ class CoreWorkerDirectActorTaskSubmitter
friend class CoreWorkerTest; friend class CoreWorkerTest;
}; };
/// The class that manages fiber states for Python asyncio actors.
///
/// One fiber state is created for every concurrency group, plus one
/// default fiber state for the default concurrency group if
/// necessary.
class FiberStateManager final {
public:
explicit FiberStateManager(const std::vector<ConcurrencyGroup> &concurrency_groups = {},
const int32_t default_group_max_concurrency = 1000) {
for (auto &group : concurrency_groups) {
const auto name = group.name;
const auto max_concurrency = group.max_concurrency;
auto fiber = std::make_shared<FiberState>(max_concurrency);
auto &fds = group.function_descriptors;
for (auto fd : fds) {
functions_to_fiber_index_[fd->ToString()] = fiber;
}
name_to_fiber_index_[name] = fiber;
}
/// Create default fiber state for default concurrency group.
if (default_group_max_concurrency >= 1) {
default_fiber_ = std::make_shared<FiberState>(default_group_max_concurrency);
}
}
/// Get the corresponding fiber state by the given concurrency group or function
/// descriptor.
///
/// Return the fiber state of the concurrency group if concurrency_group_name
/// is given. Otherwise return the fiber state that corresponds to the given
/// function descriptor.
std::shared_ptr<FiberState> GetFiber(const std::string &concurrency_group_name,
ray::FunctionDescriptor fd) {
if (!concurrency_group_name.empty()) {
auto it = name_to_fiber_index_.find(concurrency_group_name);
RAY_CHECK(it != name_to_fiber_index_.end())
<< "Failed to look up the fiber state of the given concurrency group "
<< concurrency_group_name << " . It might be that you didn't define "
<< "the concurrency group " << concurrency_group_name;
return it->second;
}
/// Code path for a task that was not given a concurrency group at call time.
/// Use the concurrency group predefined for its function descriptor.
if (functions_to_fiber_index_.find(fd->ToString()) !=
functions_to_fiber_index_.end()) {
return functions_to_fiber_index_[fd->ToString()];
}
return default_fiber_;
}
private:
// Map from concurrency group names to their corresponding fibers.
absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> name_to_fiber_index_;
// Map from the FunctionDescriptors to their corresponding fibers.
absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> functions_to_fiber_index_;
// The fiber for the default concurrency group. It's nullptr if its max
// concurrency is less than 1.
std::shared_ptr<FiberState> default_fiber_ = nullptr;
};
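
The lookup order in `GetFiber` (group name given at call time first, then the group declared for the method, then the default group) can be modelled in a few lines of Python. This is an illustrative sketch only: `FiberLookup` is a hypothetical name, fibers are represented as plain tuples, and a missing group raises `KeyError` here whereas the C++ code fails a `RAY_CHECK`.

```python
class FiberLookup:
    """Hypothetical Python model of the FiberStateManager lookup order."""

    def __init__(self, groups, default_max_concurrency=1000):
        # groups: {group name: (max_concurrency, [function descriptor strings])}
        self.by_name = {}
        self.by_function = {}
        for name, (max_concurrency, functions) in groups.items():
            fiber = ("fiber", name, max_concurrency)  # stand-in for FiberState
            self.by_name[name] = fiber
            for function in functions:
                self.by_function[function] = fiber
        # The default fiber only exists when its max concurrency is >= 1.
        self.default = (
            ("fiber", "default", default_max_concurrency)
            if default_max_concurrency >= 1
            else None
        )

    def get_fiber(self, group_name, function):
        # 1. A group name given at call time has the highest priority.
        if group_name:
            if group_name not in self.by_name:
                raise KeyError(f"undefined concurrency group: {group_name}")
            return self.by_name[group_name]
        # 2. Otherwise use the group declared for the method, if any.
        # 3. Fall back to the default group.
        return self.by_function.get(function, self.default)
```

With groups `io` (methods `f1`, `f2`) and `compute` (methods `f3`, `f4`), calling `get_fiber("", "f2")` resolves to `io`, `get_fiber("compute", "f2")` resolves to `compute`, and an undeclared method such as `f5` resolves to the default group.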
class BoundedExecutor; class BoundedExecutor;
/// A manager that manages a set of thread pool. which will perform /// A manager that manages a set of thread pool. which will perform
@ -489,6 +552,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
instrumented_io_context &main_io_service, DependencyWaiter &waiter, instrumented_io_context &main_io_service, DependencyWaiter &waiter,
std::shared_ptr<PoolManager> pool_manager = std::make_shared<PoolManager>(), std::shared_ptr<PoolManager> pool_manager = std::make_shared<PoolManager>(),
bool is_asyncio = false, int fiber_max_concurrency = 1, bool is_asyncio = false, int fiber_max_concurrency = 1,
const std::vector<ConcurrencyGroup> &concurrency_groups = {},
int64_t reorder_wait_seconds = kMaxReorderWaitSeconds) int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
: reorder_wait_seconds_(reorder_wait_seconds), : reorder_wait_seconds_(reorder_wait_seconds),
wait_timer_(main_io_service), wait_timer_(main_io_service),
@ -497,9 +561,16 @@ class ActorSchedulingQueue : public SchedulingQueue {
pool_manager_(pool_manager), pool_manager_(pool_manager),
is_asyncio_(is_asyncio) { is_asyncio_(is_asyncio) {
if (is_asyncio_) { if (is_asyncio_) {
RAY_LOG(INFO) << "Setting actor as async with max_concurrency=" std::stringstream ss;
<< fiber_max_concurrency << ", creating new fiber thread."; ss << "Setting actor as asyncio with max_concurrency=" << fiber_max_concurrency
fiber_state_ = std::make_unique<FiberState>(fiber_max_concurrency); << ", and defined concurrency groups are:" << std::endl;
for (const auto &concurrency_group : concurrency_groups) {
ss << "\t" << concurrency_group.name << " : "
<< concurrency_group.max_concurrency;
}
RAY_LOG(INFO) << ss.str();
fiber_state_manager_ =
std::make_unique<FiberStateManager>(concurrency_groups, fiber_max_concurrency);
} }
} }
@ -595,7 +666,9 @@ class ActorSchedulingQueue : public SchedulingQueue {
if (is_asyncio_) { if (is_asyncio_) {
// Process async actor task. // Process async actor task.
fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); }); auto fiber = fiber_state_manager_->GetFiber(request.ConcurrencyGroupName(),
request.FunctionDescriptor());
fiber->EnqueueFiber([request]() mutable { request.Accept(); });
} else { } else {
// Process actor tasks. // Process actor tasks.
RAY_CHECK(pool_manager_ != nullptr); RAY_CHECK(pool_manager_ != nullptr);
@ -661,9 +734,10 @@ class ActorSchedulingQueue : public SchedulingQueue {
/// Whether we should enqueue requests into asyncio pool. Setting this to true /// Whether we should enqueue requests into asyncio pool. Setting this to true
/// will instantiate all tasks as fibers that can be yielded. /// will instantiate all tasks as fibers that can be yielded.
bool is_asyncio_ = false; bool is_asyncio_ = false;
/// If is_asyncio_ is true, fiber_state_ contains the running state required /// Manage the running fiber states of actors in this worker. It works with
/// to enable continuation and work together with python asyncio. /// python asyncio if this is an asyncio actor.
std::unique_ptr<FiberState> fiber_state_; std::unique_ptr<FiberStateManager> fiber_state_manager_;
friend class SchedulingQueueTest; friend class SchedulingQueueTest;
}; };
@ -822,6 +896,10 @@ class CoreWorkerDirectTaskReceiver {
bool CancelQueuedNormalTask(TaskID task_id); bool CancelQueuedNormalTask(TaskID task_id);
protected:
/// Cache the concurrency groups of actors.
absl::flat_hash_map<ActorID, std::vector<ConcurrencyGroup>> concurrency_groups_cache_;
private: private:
// Worker context. // Worker context.
WorkerContext &worker_context_; WorkerContext &worker_context_;