mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Core] Port concurrency groups with asyncio (#18567)
## Why are these changes needed? This PR aims to port concurrency groups functionality with asyncio for Python. ### API ```python @ray.remote(concurrency_groups={"io": 2, "compute": 4}) class AsyncActor: def __init__(self): pass @ray.method(concurrency_group="io") async def f1(self): pass @ray.method(concurrency_group="io") def f2(self): pass @ray.method(concurrency_group="compute") def f3(self): pass @ray.method(concurrency_group="compute") def f4(self): pass def f5(self): pass ``` The annotation above the actor class `AsyncActor` defines this actor will have 2 concurrency groups and defines their max concurrencies, and it has a default concurrency group. Every concurrency group has an async eventloop and a pythread to execute the methods which is defined on them. Method `f1` will be invoked in the `io` concurrency group. `f2` in `io`, `f3` in `compute` and etc. TO BE NOTICED, `f5` and `__init__` will be invoked in the default concurrency. The following method `f2` will be invoked in the concurrency group `compute` since the dynamic specifying has a higher priority. ```python a.f2.options(concurrency_group="compute").remote() ``` ### Implementation The straightforward implementation details are: - Before we only have 1 eventloop binding 1 pythread for an asyncio actor. Now we create 1 eventloop binding 1 pythread for every concurrency group of the asyncio actor. - Before we have 1 fiber state for every caller in the asyncio actor. Now we create a FiberStateManager for every caller in the asyncio actor. And the FiberStateManager manages the fiber states for concurrency groups. ## Related issue number #16047
This commit is contained in:
parent
a04b02e2e8
commit
048e7f7d5d
18 changed files with 623 additions and 67 deletions
|
@ -129,7 +129,9 @@ Status TaskExecutor::ExecuteTask(
|
||||||
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
||||||
std::vector<std::shared_ptr<ray::RayObject>> *results,
|
std::vector<std::shared_ptr<ray::RayObject>> *results,
|
||||||
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
||||||
bool *is_application_level_error) {
|
bool *is_application_level_error,
|
||||||
|
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
|
||||||
|
const std::string name_of_concurrency_group_to_execute) {
|
||||||
RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
|
RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
|
||||||
RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
|
RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
|
||||||
auto function_descriptor = ray_function.GetFunctionDescriptor();
|
auto function_descriptor = ray_function.GetFunctionDescriptor();
|
||||||
|
|
|
@ -82,7 +82,9 @@ class TaskExecutor {
|
||||||
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
||||||
std::vector<std::shared_ptr<ray::RayObject>> *results,
|
std::vector<std::shared_ptr<ray::RayObject>> *results,
|
||||||
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
||||||
bool *is_application_level_error);
|
bool *is_application_level_error,
|
||||||
|
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
|
||||||
|
const std::string name_of_concurrency_group_to_execute);
|
||||||
|
|
||||||
virtual ~TaskExecutor(){};
|
virtual ~TaskExecutor(){};
|
||||||
|
|
||||||
|
|
101
doc/source/concurrency_group_api.rst
Normal file
101
doc/source/concurrency_group_api.rst
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
Limiting Concurrency Per-Method with Concurrency Groups
|
||||||
|
=======================================================
|
||||||
|
|
||||||
|
Besides setting the max concurrency overall for an asyncio actor, Ray allows methods to be separated into *concurrency groups*, each with its own asyncio event loop. This allows you to limit the concurrency per-method, e.g., allow a health-check method to be given its own concurrency quota separate from request serving methods.
|
||||||
|
|
||||||
|
.. warning:: Concurrency groups are only supported for asyncio actors, not threaded actors.
|
||||||
|
|
||||||
|
.. _defining-concurrency-groups:
|
||||||
|
|
||||||
|
Defining Concurrency Groups
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
You can define concurrency groups for asyncio actors using the ``concurrency_groups`` decorator argument:
|
||||||
|
|
||||||
|
.. tabs::
|
||||||
|
.. group-tab:: Python
|
||||||
|
|
||||||
|
This defines two concurrency groups, "io" with max_concurrency=2 and
|
||||||
|
"compute" with max_concurrency=4. The methods ``f1`` and ``f2`` are
|
||||||
|
placed in the "io" group, and the methods ``f3`` and ``f4`` are placed
|
||||||
|
into the "compute" group. Note that there is always a default
|
||||||
|
concurrency group, which has a default concurrency of 1000.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ray.remote(concurrency_groups={"io": 2, "compute": 4})
|
||||||
|
class AsyncIOActor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="io")
|
||||||
|
async def f1(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="io")
|
||||||
|
async def f2(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="compute")
|
||||||
|
async def f3(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="compute")
|
||||||
|
async def f4(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def f5(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
a = AsyncIOActor.remote()
|
||||||
|
a.f1.remote() # executed in the "io" group.
|
||||||
|
a.f2.remote() # executed in the "io" group.
|
||||||
|
a.f3.remote() # executed in the "compute" group.
|
||||||
|
a.f4.remote() # executed in the "compute" group.
|
||||||
|
a.f5.remote() # executed in the default group.
|
||||||
|
|
||||||
|
|
||||||
|
.. _default-concurrency-group:
|
||||||
|
|
||||||
|
Default Concurrency Group
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
By default, methods are placed in a default concurrency group which has a concurrency limit of 1000.
|
||||||
|
The concurrency of the default group can be changed by setting the ``max_concurrency`` actor option.
|
||||||
|
|
||||||
|
.. tabs::
|
||||||
|
.. group-tab:: Python
|
||||||
|
|
||||||
|
The following AsyncIOActor has 2 concurrency groups: "io" and "default".
|
||||||
|
The max concurrency of "io" is 2, and the max concurrency of "default" is 10.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ray.remote(concurrency_groups={"io": 2)
|
||||||
|
class AsyncIOActor:
|
||||||
|
async def f1(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
actor = AsyncIOActor.options(max_concurrency=10).remote()
|
||||||
|
|
||||||
|
|
||||||
|
.. _setting-the-concurrency-group-at-runtime:
|
||||||
|
|
||||||
|
Setting the Concurrency Group at Runtime
|
||||||
|
----------------------------------------
|
||||||
|
|
||||||
|
You can also dispatch actor methods into a specific concurrency group at runtime using the ``.options`` method:
|
||||||
|
|
||||||
|
.. tabs::
|
||||||
|
.. group-tab:: Python
|
||||||
|
|
||||||
|
The following snippet demonstrates setting the concurrency group of the
|
||||||
|
``f2`` method dynamically at runtime.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Executed in the "io" group (as defined in the actor class).
|
||||||
|
a.f2.options().remote()
|
||||||
|
|
||||||
|
# Executed in the "compute" group.
|
||||||
|
a.f2.options(concurrency_group="compute").remote()
|
|
@ -14,6 +14,7 @@ Finally, we've also included some content on using core Ray APIs with `Tensorflo
|
||||||
actors.rst
|
actors.rst
|
||||||
namespaces.rst
|
namespaces.rst
|
||||||
async_api.rst
|
async_api.rst
|
||||||
|
concurrency_group_api.rst
|
||||||
using-ray-with-gpus.rst
|
using-ray-with-gpus.rst
|
||||||
serialization.rst
|
serialization.rst
|
||||||
memory-management.rst
|
memory-management.rst
|
||||||
|
|
|
@ -17,6 +17,7 @@ from ray.includes.common cimport (
|
||||||
CBuffer,
|
CBuffer,
|
||||||
CRayObject,
|
CRayObject,
|
||||||
CAddress,
|
CAddress,
|
||||||
|
CConcurrencyGroup,
|
||||||
)
|
)
|
||||||
from ray.includes.libcoreworker cimport (
|
from ray.includes.libcoreworker cimport (
|
||||||
ActorHandleSharedPtr,
|
ActorHandleSharedPtr,
|
||||||
|
@ -117,6 +118,11 @@ cdef class CoreWorker:
|
||||||
object current_runtime_env
|
object current_runtime_env
|
||||||
c_bool is_local_mode
|
c_bool is_local_mode
|
||||||
|
|
||||||
|
object cgname_to_eventloop_dict
|
||||||
|
object eventloop_for_default_cg
|
||||||
|
object thread_for_default_cg
|
||||||
|
object fd_to_cgname_dict
|
||||||
|
|
||||||
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
|
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
|
||||||
size_t data_size, ObjectRef object_ref,
|
size_t data_size, ObjectRef object_ref,
|
||||||
c_vector[CObjectID] contained_ids,
|
c_vector[CObjectID] contained_ids,
|
||||||
|
@ -130,6 +136,10 @@ cdef class CoreWorker:
|
||||||
c_vector[shared_ptr[CRayObject]] *returns)
|
c_vector[shared_ptr[CRayObject]] *returns)
|
||||||
cdef yield_current_fiber(self, CFiberEvent &fiber_event)
|
cdef yield_current_fiber(self, CFiberEvent &fiber_event)
|
||||||
cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle)
|
cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle)
|
||||||
|
cdef c_function_descriptors_to_python(
|
||||||
|
self, const c_vector[CFunctionDescriptor] &c_function_descriptors)
|
||||||
|
cdef initialize_eventloops_for_actor_concurrency_group(
|
||||||
|
self, const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups)
|
||||||
|
|
||||||
cdef class FunctionDescriptor:
|
cdef class FunctionDescriptor:
|
||||||
cdef:
|
cdef:
|
||||||
|
|
|
@ -60,6 +60,7 @@ from ray.includes.common cimport (
|
||||||
CRayFunction,
|
CRayFunction,
|
||||||
CWorkerType,
|
CWorkerType,
|
||||||
CJobConfig,
|
CJobConfig,
|
||||||
|
CConcurrencyGroup,
|
||||||
move,
|
move,
|
||||||
LANGUAGE_CPP,
|
LANGUAGE_CPP,
|
||||||
LANGUAGE_JAVA,
|
LANGUAGE_JAVA,
|
||||||
|
@ -316,6 +317,34 @@ cdef int prepare_resources(
|
||||||
resource_map[0][key.encode("ascii")] = float(value)
|
resource_map[0][key.encode("ascii")] = float(value)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
cdef c_vector[CFunctionDescriptor] prepare_function_descriptors(pyfd_list):
|
||||||
|
cdef:
|
||||||
|
c_vector[CFunctionDescriptor] fd_list
|
||||||
|
CRayFunction ray_function
|
||||||
|
|
||||||
|
for pyfd in pyfd_list:
|
||||||
|
fd_list.push_back(CFunctionDescriptorBuilder.BuildPython(
|
||||||
|
pyfd.module_name, pyfd.class_name, pyfd.function_name, b""))
|
||||||
|
return fd_list
|
||||||
|
|
||||||
|
|
||||||
|
cdef int prepare_actor_concurrency_groups(
|
||||||
|
dict concurrency_groups_dict,
|
||||||
|
c_vector[CConcurrencyGroup] *concurrency_groups):
|
||||||
|
|
||||||
|
cdef:
|
||||||
|
CConcurrencyGroup cg
|
||||||
|
c_vector[CFunctionDescriptor] c_fd_list
|
||||||
|
|
||||||
|
if concurrency_groups_dict is None:
|
||||||
|
raise ValueError("Must provide it...")
|
||||||
|
|
||||||
|
for key, value in concurrency_groups_dict.items():
|
||||||
|
c_fd_list = prepare_function_descriptors(value["function_descriptors"])
|
||||||
|
cg = CConcurrencyGroup(
|
||||||
|
key.encode("ascii"), value["max_concurrency"], c_fd_list)
|
||||||
|
concurrency_groups.push_back(cg)
|
||||||
|
return 1
|
||||||
|
|
||||||
cdef prepare_args(
|
cdef prepare_args(
|
||||||
CoreWorker core_worker,
|
CoreWorker core_worker,
|
||||||
|
@ -411,7 +440,11 @@ cdef execute_task(
|
||||||
const c_vector[CObjectID] &c_return_ids,
|
const c_vector[CObjectID] &c_return_ids,
|
||||||
const c_string debugger_breakpoint,
|
const c_string debugger_breakpoint,
|
||||||
c_vector[shared_ptr[CRayObject]] *returns,
|
c_vector[shared_ptr[CRayObject]] *returns,
|
||||||
c_bool *is_application_level_error):
|
c_bool *is_application_level_error,
|
||||||
|
# This parameter is only used for actor creation task to define
|
||||||
|
# the concurrency groups of this actor.
|
||||||
|
const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups,
|
||||||
|
const c_string c_name_of_concurrency_group_to_execute):
|
||||||
|
|
||||||
is_application_level_error[0] = False
|
is_application_level_error[0] = False
|
||||||
|
|
||||||
|
@ -462,6 +495,11 @@ cdef execute_task(
|
||||||
print(actor_magic_token)
|
print(actor_magic_token)
|
||||||
print(actor_magic_token, file=sys.stderr)
|
print(actor_magic_token, file=sys.stderr)
|
||||||
|
|
||||||
|
# Initial eventloops for asyncio for this actor.
|
||||||
|
if core_worker.current_actor_is_asyncio():
|
||||||
|
core_worker.initialize_eventloops_for_actor_concurrency_group(
|
||||||
|
c_defined_concurrency_groups)
|
||||||
|
|
||||||
execution_info = execution_infos.get(function_descriptor)
|
execution_info = execution_infos.get(function_descriptor)
|
||||||
if not execution_info:
|
if not execution_info:
|
||||||
execution_info = manager.get_execution_info(
|
execution_info = manager.get_execution_info(
|
||||||
|
@ -473,6 +511,8 @@ cdef execute_task(
|
||||||
b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
|
b' "task_id": ' + task_id.hex().encode("ascii") + b'}')
|
||||||
|
|
||||||
task_name = name.decode("utf-8")
|
task_name = name.decode("utf-8")
|
||||||
|
name_of_concurrency_group_to_execute = \
|
||||||
|
c_name_of_concurrency_group_to_execute.decode("ascii")
|
||||||
title = f"ray::{task_name}"
|
title = f"ray::{task_name}"
|
||||||
|
|
||||||
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
|
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
|
||||||
|
@ -520,7 +560,9 @@ cdef execute_task(
|
||||||
async_function = sync_to_async(function)
|
async_function = sync_to_async(function)
|
||||||
|
|
||||||
return core_worker.run_async_func_in_event_loop(
|
return core_worker.run_async_func_in_event_loop(
|
||||||
async_function, actor, *arguments, **kwarguments)
|
async_function, function_descriptor,
|
||||||
|
name_of_concurrency_group_to_execute, actor,
|
||||||
|
*arguments, **kwarguments)
|
||||||
|
|
||||||
return function(actor, *arguments, **kwarguments)
|
return function(actor, *arguments, **kwarguments)
|
||||||
|
|
||||||
|
@ -546,7 +588,8 @@ cdef execute_task(
|
||||||
.deserialize_objects(
|
.deserialize_objects(
|
||||||
metadata_pairs, object_refs))
|
metadata_pairs, object_refs))
|
||||||
args = core_worker.run_async_func_in_event_loop(
|
args = core_worker.run_async_func_in_event_loop(
|
||||||
deserialize_args)
|
deserialize_args, function_descriptor,
|
||||||
|
name_of_concurrency_group_to_execute)
|
||||||
else:
|
else:
|
||||||
args = ray.worker.global_worker.deserialize_objects(
|
args = ray.worker.global_worker.deserialize_objects(
|
||||||
metadata_pairs, object_refs)
|
metadata_pairs, object_refs)
|
||||||
|
@ -692,7 +735,9 @@ cdef CRayStatus task_execution_handler(
|
||||||
const c_string debugger_breakpoint,
|
const c_string debugger_breakpoint,
|
||||||
c_vector[shared_ptr[CRayObject]] *returns,
|
c_vector[shared_ptr[CRayObject]] *returns,
|
||||||
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
|
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
|
||||||
c_bool *is_application_level_error) nogil:
|
c_bool *is_application_level_error,
|
||||||
|
const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
|
||||||
|
const c_string name_of_concurrency_group_to_execute) nogil:
|
||||||
with gil, disable_client_hook():
|
with gil, disable_client_hook():
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
|
@ -701,7 +746,9 @@ cdef CRayStatus task_execution_handler(
|
||||||
execute_task(task_type, task_name, ray_function, c_resources,
|
execute_task(task_type, task_name, ray_function, c_resources,
|
||||||
c_args, c_arg_refs, c_return_ids,
|
c_args, c_arg_refs, c_return_ids,
|
||||||
debugger_breakpoint, returns,
|
debugger_breakpoint, returns,
|
||||||
is_application_level_error)
|
is_application_level_error,
|
||||||
|
defined_concurrency_groups,
|
||||||
|
name_of_concurrency_group_to_execute)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
sys_exit = SystemExit()
|
sys_exit = SystemExit()
|
||||||
if isinstance(e, RayActorError) and \
|
if isinstance(e, RayActorError) and \
|
||||||
|
@ -1020,6 +1067,10 @@ cdef class CoreWorker:
|
||||||
options.startup_token = startup_token
|
options.startup_token = startup_token
|
||||||
CCoreWorkerProcess.Initialize(options)
|
CCoreWorkerProcess.Initialize(options)
|
||||||
|
|
||||||
|
self.cgname_to_eventloop_dict = None
|
||||||
|
self.fd_to_cgname_dict = None
|
||||||
|
self.eventloop_for_default_cg = None
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
with nogil:
|
with nogil:
|
||||||
# If it's a worker, the core worker process should have been
|
# If it's a worker, the core worker process should have been
|
||||||
|
@ -1414,6 +1465,7 @@ cdef class CoreWorker:
|
||||||
c_string extension_data,
|
c_string extension_data,
|
||||||
c_string serialized_runtime_env,
|
c_string serialized_runtime_env,
|
||||||
runtime_env_uris,
|
runtime_env_uris,
|
||||||
|
concurrency_groups_dict,
|
||||||
):
|
):
|
||||||
cdef:
|
cdef:
|
||||||
CRayFunction ray_function
|
CRayFunction ray_function
|
||||||
|
@ -1425,6 +1477,7 @@ cdef class CoreWorker:
|
||||||
CPlacementGroupID c_placement_group_id = \
|
CPlacementGroupID c_placement_group_id = \
|
||||||
placement_group_id.native()
|
placement_group_id.native()
|
||||||
c_vector[c_string] c_runtime_env_uris = runtime_env_uris
|
c_vector[c_string] c_runtime_env_uris = runtime_env_uris
|
||||||
|
c_vector[CConcurrencyGroup] c_concurrency_groups
|
||||||
|
|
||||||
with self.profile_event(b"submit_task"):
|
with self.profile_event(b"submit_task"):
|
||||||
prepare_resources(resources, &c_resources)
|
prepare_resources(resources, &c_resources)
|
||||||
|
@ -1432,6 +1485,8 @@ cdef class CoreWorker:
|
||||||
ray_function = CRayFunction(
|
ray_function = CRayFunction(
|
||||||
language.lang, function_descriptor.descriptor)
|
language.lang, function_descriptor.descriptor)
|
||||||
prepare_args(self, language, args, &args_vector)
|
prepare_args(self, language, args, &args_vector)
|
||||||
|
prepare_actor_concurrency_groups(
|
||||||
|
concurrency_groups_dict, &c_concurrency_groups)
|
||||||
|
|
||||||
with nogil:
|
with nogil:
|
||||||
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
|
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
|
||||||
|
@ -1447,7 +1502,8 @@ cdef class CoreWorker:
|
||||||
placement_group_bundle_index),
|
placement_group_bundle_index),
|
||||||
placement_group_capture_child_tasks,
|
placement_group_capture_child_tasks,
|
||||||
serialized_runtime_env,
|
serialized_runtime_env,
|
||||||
c_runtime_env_uris),
|
c_runtime_env_uris,
|
||||||
|
c_concurrency_groups),
|
||||||
extension_data,
|
extension_data,
|
||||||
&c_actor_id))
|
&c_actor_id))
|
||||||
|
|
||||||
|
@ -1814,32 +1870,96 @@ cdef class CoreWorker:
|
||||||
CCoreWorkerProcess.GetCoreWorker().SealReturnObject(
|
CCoreWorkerProcess.GetCoreWorker().SealReturnObject(
|
||||||
return_id, returns[0][i]))
|
return_id, returns[0][i]))
|
||||||
|
|
||||||
def create_or_get_event_loop(self):
|
cdef c_function_descriptors_to_python(
|
||||||
if self.async_event_loop is None:
|
self,
|
||||||
self.async_event_loop = get_new_event_loop()
|
const c_vector[CFunctionDescriptor] &c_function_descriptors):
|
||||||
asyncio.set_event_loop(self.async_event_loop)
|
|
||||||
|
|
||||||
if self.async_thread is None:
|
ret = []
|
||||||
self.async_thread = threading.Thread(
|
for i in range(c_function_descriptors.size()):
|
||||||
target=lambda: self.async_event_loop.run_forever(),
|
ret.append(CFunctionDescriptorToPython(c_function_descriptors[i]))
|
||||||
name="AsyncIO Thread"
|
return ret
|
||||||
|
|
||||||
|
cdef initialize_eventloops_for_actor_concurrency_group(
|
||||||
|
self,
|
||||||
|
const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups):
|
||||||
|
|
||||||
|
cdef:
|
||||||
|
CConcurrencyGroup c_concurrency_group
|
||||||
|
c_vector[CFunctionDescriptor] c_function_descriptors
|
||||||
|
|
||||||
|
self.cgname_to_eventloop_dict = {}
|
||||||
|
self.fd_to_cgname_dict = {}
|
||||||
|
|
||||||
|
self.eventloop_for_default_cg = get_new_event_loop()
|
||||||
|
self.thread_for_default_cg = threading.Thread(
|
||||||
|
target=lambda: self.eventloop_for_default_cg.run_forever(),
|
||||||
|
name="AsyncIO Thread: default"
|
||||||
|
)
|
||||||
|
# Making the thread as daemon to let it exit
|
||||||
|
# when the main thread exits.
|
||||||
|
self.thread_for_default_cg.daemon = True
|
||||||
|
self.thread_for_default_cg.start()
|
||||||
|
|
||||||
|
for i in range(c_defined_concurrency_groups.size()):
|
||||||
|
c_concurrency_group = c_defined_concurrency_groups[i]
|
||||||
|
cg_name = c_concurrency_group.GetName().decode("ascii")
|
||||||
|
function_descriptors = self.c_function_descriptors_to_python(
|
||||||
|
c_concurrency_group.GetFunctionDescriptors())
|
||||||
|
|
||||||
|
async_eventloop = get_new_event_loop()
|
||||||
|
async_thread = threading.Thread(
|
||||||
|
target=lambda: async_eventloop.run_forever(),
|
||||||
|
name="AsyncIO Thread: {}".format(cg_name)
|
||||||
)
|
)
|
||||||
# Making the thread a daemon causes it to exit
|
# Making the thread a daemon causes it to exit
|
||||||
# when the main thread exits.
|
# when the main thread exits.
|
||||||
self.async_thread.daemon = True
|
async_thread.daemon = True
|
||||||
self.async_thread.start()
|
async_thread.start()
|
||||||
|
|
||||||
return self.async_event_loop
|
self.cgname_to_eventloop_dict[cg_name] = {
|
||||||
|
"eventloop": async_eventloop,
|
||||||
|
"thread": async_thread,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fd in function_descriptors:
|
||||||
|
self.fd_to_cgname_dict[fd] = cg_name
|
||||||
|
|
||||||
|
def get_event_loop(self, function_descriptor, specified_cgname):
|
||||||
|
# __init__ will be invoked in default eventloop
|
||||||
|
if function_descriptor.function_name == "__init__":
|
||||||
|
return self.eventloop_for_default_cg, self.thread_for_default_cg
|
||||||
|
|
||||||
|
if specified_cgname is not None:
|
||||||
|
if specified_cgname in self.cgname_to_eventloop_dict:
|
||||||
|
this_group = self.cgname_to_eventloop_dict[specified_cgname]
|
||||||
|
return (this_group["eventloop"], this_group["thread"])
|
||||||
|
|
||||||
|
if function_descriptor in self.fd_to_cgname_dict:
|
||||||
|
curr_cgname = self.fd_to_cgname_dict[function_descriptor]
|
||||||
|
if curr_cgname in self.cgname_to_eventloop_dict:
|
||||||
|
return (
|
||||||
|
self.cgname_to_eventloop_dict[curr_cgname]["eventloop"],
|
||||||
|
self.cgname_to_eventloop_dict[curr_cgname]["thread"])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The function {} is defined to be executed "
|
||||||
|
"in the concurrency group {} . But there is no this group."
|
||||||
|
.format(function_descriptor, curr_cgname))
|
||||||
|
|
||||||
|
return self.eventloop_for_default_cg, self.thread_for_default_cg
|
||||||
|
|
||||||
|
def run_async_func_in_event_loop(
|
||||||
|
self, func, function_descriptor, specified_cgname, *args, **kwargs):
|
||||||
|
|
||||||
def run_async_func_in_event_loop(self, func, *args, **kwargs):
|
|
||||||
cdef:
|
cdef:
|
||||||
CFiberEvent event
|
CFiberEvent event
|
||||||
loop = self.create_or_get_event_loop()
|
eventloop, async_thread = self.get_event_loop(
|
||||||
|
function_descriptor, specified_cgname)
|
||||||
coroutine = func(*args, **kwargs)
|
coroutine = func(*args, **kwargs)
|
||||||
if threading.get_ident() == self.async_thread.ident:
|
if threading.get_ident() == async_thread.ident:
|
||||||
future = asyncio.ensure_future(coroutine, loop)
|
future = asyncio.ensure_future(coroutine, eventloop)
|
||||||
else:
|
else:
|
||||||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
|
||||||
future.add_done_callback(lambda _: event.Notify())
|
future.add_done_callback(lambda _: event.Notify())
|
||||||
with nogil:
|
with nogil:
|
||||||
(CCoreWorkerProcess.GetCoreWorker()
|
(CCoreWorkerProcess.GetCoreWorker()
|
||||||
|
|
|
@ -54,11 +54,14 @@ def method(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
assert len(args) == 0
|
assert len(args) == 0
|
||||||
assert len(kwargs) == 1
|
assert len(kwargs) == 1
|
||||||
assert "num_returns" in kwargs
|
|
||||||
num_returns = kwargs["num_returns"]
|
assert "num_returns" in kwargs or "concurrency_group" in kwargs
|
||||||
|
|
||||||
def annotate_method(method):
|
def annotate_method(method):
|
||||||
method.__ray_num_returns__ = num_returns
|
if "num_returns" in kwargs:
|
||||||
|
method.__ray_num_returns__ = kwargs["num_returns"]
|
||||||
|
if "concurrency_group" in kwargs:
|
||||||
|
method.__ray_concurrency_group__ = kwargs["concurrency_group"]
|
||||||
return method
|
return method
|
||||||
|
|
||||||
return annotate_method
|
return annotate_method
|
||||||
|
@ -138,7 +141,12 @@ class ActorMethod:
|
||||||
return FuncWrapper()
|
return FuncWrapper()
|
||||||
|
|
||||||
@_tracing_actor_method_invocation
|
@_tracing_actor_method_invocation
|
||||||
def _remote(self, args=None, kwargs=None, name="", num_returns=None):
|
def _remote(self,
|
||||||
|
args=None,
|
||||||
|
kwargs=None,
|
||||||
|
name="",
|
||||||
|
num_returns=None,
|
||||||
|
concurrency_group=None):
|
||||||
if num_returns is None:
|
if num_returns is None:
|
||||||
num_returns = self._num_returns
|
num_returns = self._num_returns
|
||||||
|
|
||||||
|
@ -222,6 +230,8 @@ class ActorClassMethodMetadata(object):
|
||||||
self.decorators = {}
|
self.decorators = {}
|
||||||
self.signatures = {}
|
self.signatures = {}
|
||||||
self.num_returns = {}
|
self.num_returns = {}
|
||||||
|
self.concurrency_group_for_methods = {}
|
||||||
|
|
||||||
for method_name, method in actor_methods:
|
for method_name, method in actor_methods:
|
||||||
# Whether or not this method requires binding of its first
|
# Whether or not this method requires binding of its first
|
||||||
# argument. For class and static methods, we do not want to bind
|
# argument. For class and static methods, we do not want to bind
|
||||||
|
@ -247,6 +257,10 @@ class ActorClassMethodMetadata(object):
|
||||||
self.decorators[method_name] = (
|
self.decorators[method_name] = (
|
||||||
method.__ray_invocation_decorator__)
|
method.__ray_invocation_decorator__)
|
||||||
|
|
||||||
|
if hasattr(method, "__ray_concurrency_group__"):
|
||||||
|
self.concurrency_group_for_methods[method_name] = (
|
||||||
|
method.__ray_concurrency_group__)
|
||||||
|
|
||||||
# Update cache.
|
# Update cache.
|
||||||
cls._cache[actor_creation_function_descriptor] = self
|
cls._cache[actor_creation_function_descriptor] = self
|
||||||
return self
|
return self
|
||||||
|
@ -285,8 +299,8 @@ class ActorClassMetadata:
|
||||||
def __init__(self, language, modified_class,
|
def __init__(self, language, modified_class,
|
||||||
actor_creation_function_descriptor, class_id, max_restarts,
|
actor_creation_function_descriptor, class_id, max_restarts,
|
||||||
max_task_retries, num_cpus, num_gpus, memory,
|
max_task_retries, num_cpus, num_gpus, memory,
|
||||||
object_store_memory, resources, accelerator_type,
|
object_store_memory, resources, accelerator_type, runtime_env,
|
||||||
runtime_env):
|
concurrency_groups):
|
||||||
self.language = language
|
self.language = language
|
||||||
self.modified_class = modified_class
|
self.modified_class = modified_class
|
||||||
self.actor_creation_function_descriptor = \
|
self.actor_creation_function_descriptor = \
|
||||||
|
@ -303,6 +317,7 @@ class ActorClassMetadata:
|
||||||
self.resources = resources
|
self.resources = resources
|
||||||
self.accelerator_type = accelerator_type
|
self.accelerator_type = accelerator_type
|
||||||
self.runtime_env = runtime_env
|
self.runtime_env = runtime_env
|
||||||
|
self.concurrency_groups = concurrency_groups
|
||||||
self.last_export_session_and_job = None
|
self.last_export_session_and_job = None
|
||||||
self.method_meta = ActorClassMethodMetadata.create(
|
self.method_meta = ActorClassMethodMetadata.create(
|
||||||
modified_class, actor_creation_function_descriptor)
|
modified_class, actor_creation_function_descriptor)
|
||||||
|
@ -358,10 +373,10 @@ class ActorClass:
|
||||||
f"use '{self.__ray_metadata__.class_name}.remote()'.")
|
f"use '{self.__ray_metadata__.class_name}.remote()'.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _ray_from_modified_class(cls, modified_class, class_id, max_restarts,
|
def _ray_from_modified_class(
|
||||||
max_task_retries, num_cpus, num_gpus, memory,
|
cls, modified_class, class_id, max_restarts, max_task_retries,
|
||||||
object_store_memory, resources,
|
num_cpus, num_gpus, memory, object_store_memory, resources,
|
||||||
accelerator_type, runtime_env):
|
accelerator_type, runtime_env, concurrency_groups):
|
||||||
for attribute in [
|
for attribute in [
|
||||||
"remote",
|
"remote",
|
||||||
"_remote",
|
"_remote",
|
||||||
|
@ -398,7 +413,7 @@ class ActorClass:
|
||||||
Language.PYTHON, modified_class,
|
Language.PYTHON, modified_class,
|
||||||
actor_creation_function_descriptor, class_id, max_restarts,
|
actor_creation_function_descriptor, class_id, max_restarts,
|
||||||
max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
|
max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
|
||||||
resources, accelerator_type, new_runtime_env)
|
resources, accelerator_type, new_runtime_env, concurrency_groups)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -413,10 +428,12 @@ class ActorClass:
|
||||||
# .remote(), it would get run in the Ray Client server, which runs on
|
# .remote(), it would get run in the Ray Client server, which runs on
|
||||||
# a remote node where the files aren't available.
|
# a remote node where the files aren't available.
|
||||||
new_runtime_env = ParsedRuntimeEnv(runtime_env or {})
|
new_runtime_env = ParsedRuntimeEnv(runtime_env or {})
|
||||||
|
|
||||||
self.__ray_metadata__ = ActorClassMetadata(
|
self.__ray_metadata__ = ActorClassMetadata(
|
||||||
language, None, actor_creation_function_descriptor, None,
|
language, None, actor_creation_function_descriptor, None,
|
||||||
max_restarts, max_task_retries, num_cpus, num_gpus, memory,
|
max_restarts, max_task_retries, num_cpus, num_gpus, memory,
|
||||||
object_store_memory, resources, accelerator_type, new_runtime_env)
|
object_store_memory, resources, accelerator_type, new_runtime_env,
|
||||||
|
[])
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -740,6 +757,25 @@ class ActorClass:
|
||||||
parsed_runtime_env = override_task_or_actor_runtime_env(
|
parsed_runtime_env = override_task_or_actor_runtime_env(
|
||||||
runtime_env, parent_runtime_env)
|
runtime_env, parent_runtime_env)
|
||||||
|
|
||||||
|
concurrency_groups_dict = {}
|
||||||
|
for cg_name in meta.concurrency_groups:
|
||||||
|
concurrency_groups_dict[cg_name] = {
|
||||||
|
"name": cg_name,
|
||||||
|
"max_concurrency": meta.concurrency_groups[cg_name],
|
||||||
|
"function_descriptors": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update methods
|
||||||
|
for method_name in meta.method_meta.concurrency_group_for_methods:
|
||||||
|
cg_name = meta.method_meta.concurrency_group_for_methods[
|
||||||
|
method_name]
|
||||||
|
assert cg_name in concurrency_groups_dict
|
||||||
|
|
||||||
|
module_name = meta.actor_creation_function_descriptor.module_name
|
||||||
|
class_name = meta.actor_creation_function_descriptor.class_name
|
||||||
|
concurrency_groups_dict[cg_name]["function_descriptors"].append(
|
||||||
|
PythonFunctionDescriptor(module_name, method_name, class_name))
|
||||||
|
|
||||||
actor_id = worker.core_worker.create_actor(
|
actor_id = worker.core_worker.create_actor(
|
||||||
meta.language,
|
meta.language,
|
||||||
meta.actor_creation_function_descriptor,
|
meta.actor_creation_function_descriptor,
|
||||||
|
@ -759,7 +795,8 @@ class ActorClass:
|
||||||
# Store actor_method_cpu in actor handle's extension data.
|
# Store actor_method_cpu in actor handle's extension data.
|
||||||
extension_data=str(actor_method_cpu),
|
extension_data=str(actor_method_cpu),
|
||||||
serialized_runtime_env=parsed_runtime_env.serialize(),
|
serialized_runtime_env=parsed_runtime_env.serialize(),
|
||||||
runtime_env_uris=parsed_runtime_env.get_uris())
|
runtime_env_uris=parsed_runtime_env.get_uris(),
|
||||||
|
concurrency_groups_dict=concurrency_groups_dict or dict())
|
||||||
|
|
||||||
actor_handle = ActorHandle(
|
actor_handle = ActorHandle(
|
||||||
meta.language,
|
meta.language,
|
||||||
|
@ -1060,7 +1097,8 @@ def modify_class(cls):
|
||||||
|
|
||||||
|
|
||||||
def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
|
def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
|
||||||
accelerator_type, max_restarts, max_task_retries, runtime_env):
|
accelerator_type, max_restarts, max_task_retries, runtime_env,
|
||||||
|
concurrency_groups):
|
||||||
Class = modify_class(cls)
|
Class = modify_class(cls)
|
||||||
_inject_tracing_into_class(Class)
|
_inject_tracing_into_class(Class)
|
||||||
|
|
||||||
|
@ -1068,6 +1106,8 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
|
||||||
max_restarts = 0
|
max_restarts = 0
|
||||||
if max_task_retries is None:
|
if max_task_retries is None:
|
||||||
max_task_retries = 0
|
max_task_retries = 0
|
||||||
|
if concurrency_groups is None:
|
||||||
|
concurrency_groups = []
|
||||||
|
|
||||||
infinite_restart = max_restarts == -1
|
infinite_restart = max_restarts == -1
|
||||||
if not infinite_restart:
|
if not infinite_restart:
|
||||||
|
@ -1086,7 +1126,7 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
|
||||||
return ActorClass._ray_from_modified_class(
|
return ActorClass._ray_from_modified_class(
|
||||||
Class, ActorClassID.from_random(), max_restarts, max_task_retries,
|
Class, ActorClassID.from_random(), max_restarts, max_task_retries,
|
||||||
num_cpus, num_gpus, memory, object_store_memory, resources,
|
num_cpus, num_gpus, memory, object_store_memory, resources,
|
||||||
accelerator_type, runtime_env)
|
accelerator_type, runtime_env, concurrency_groups)
|
||||||
|
|
||||||
|
|
||||||
def exit_actor():
|
def exit_actor():
|
||||||
|
|
|
@ -2,7 +2,7 @@ from libcpp cimport bool as c_bool
|
||||||
from libcpp.memory cimport shared_ptr, unique_ptr
|
from libcpp.memory cimport shared_ptr, unique_ptr
|
||||||
from libcpp.string cimport string as c_string
|
from libcpp.string cimport string as c_string
|
||||||
|
|
||||||
from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t
|
from libc.stdint cimport uint8_t, int32_t, uint64_t, int64_t, uint32_t
|
||||||
from libcpp.unordered_map cimport unordered_map
|
from libcpp.unordered_map cimport unordered_map
|
||||||
from libcpp.vector cimport vector as c_vector
|
from libcpp.vector cimport vector as c_vector
|
||||||
from libcpp.pair cimport pair as c_pair
|
from libcpp.pair cimport pair as c_pair
|
||||||
|
@ -255,7 +255,8 @@ cdef extern from "ray/core_worker/common.h" nogil:
|
||||||
c_pair[CPlacementGroupID, int64_t] placement_options,
|
c_pair[CPlacementGroupID, int64_t] placement_options,
|
||||||
c_bool placement_group_capture_child_tasks,
|
c_bool placement_group_capture_child_tasks,
|
||||||
c_string serialized_runtime_env,
|
c_string serialized_runtime_env,
|
||||||
c_vector[c_string] runtime_env_uris)
|
c_vector[c_string] runtime_env_uris,
|
||||||
|
const c_vector[CConcurrencyGroup] &concurrency_groups)
|
||||||
|
|
||||||
cdef cppclass CPlacementGroupCreationOptions \
|
cdef cppclass CPlacementGroupCreationOptions \
|
||||||
"ray::core::PlacementGroupCreationOptions":
|
"ray::core::PlacementGroupCreationOptions":
|
||||||
|
@ -283,3 +284,14 @@ cdef extern from "ray/gcs/gcs_client.h" nogil:
|
||||||
cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
|
cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
|
||||||
cdef cppclass CJobConfig "ray::rpc::JobConfig":
|
cdef cppclass CJobConfig "ray::rpc::JobConfig":
|
||||||
const c_string &SerializeAsString()
|
const c_string &SerializeAsString()
|
||||||
|
|
||||||
|
cdef extern from "ray/common/task/task_spec.h" nogil:
|
||||||
|
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
|
||||||
|
CConcurrencyGroup(
|
||||||
|
const c_string &name,
|
||||||
|
uint32_t max_concurrency,
|
||||||
|
const c_vector[CFunctionDescriptor] &c_fds)
|
||||||
|
CConcurrencyGroup()
|
||||||
|
c_string GetName() const
|
||||||
|
uint32_t GetMaxConcurrency() const
|
||||||
|
c_vector[CFunctionDescriptor] GetFunctionDescriptors() const
|
||||||
|
|
|
@ -43,6 +43,7 @@ from ray.includes.common cimport (
|
||||||
CGcsClientOptions,
|
CGcsClientOptions,
|
||||||
LocalMemoryBuffer,
|
LocalMemoryBuffer,
|
||||||
CJobConfig,
|
CJobConfig,
|
||||||
|
CConcurrencyGroup,
|
||||||
)
|
)
|
||||||
from ray.includes.function_descriptor cimport (
|
from ray.includes.function_descriptor cimport (
|
||||||
CFunctionDescriptor,
|
CFunctionDescriptor,
|
||||||
|
@ -284,7 +285,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||||
c_vector[shared_ptr[CRayObject]] *returns,
|
c_vector[shared_ptr[CRayObject]] *returns,
|
||||||
shared_ptr[LocalMemoryBuffer]
|
shared_ptr[LocalMemoryBuffer]
|
||||||
&creation_task_exception_pb_bytes,
|
&creation_task_exception_pb_bytes,
|
||||||
c_bool *is_application_level_error) nogil
|
c_bool *is_application_level_error,
|
||||||
|
const c_vector[CConcurrencyGroup] &defined_concurrency_groups,
|
||||||
|
const c_string name_of_concurrency_group_to_execute) nogil
|
||||||
) task_execution_callback
|
) task_execution_callback
|
||||||
(void(const CWorkerID &) nogil) on_worker_shutdown
|
(void(const CWorkerID &) nogil) on_worker_shutdown
|
||||||
(CRayStatus() nogil) check_signals
|
(CRayStatus() nogil) check_signals
|
||||||
|
|
114
python/ray/tests/test_concurrency_group.py
Normal file
114
python/ray/tests/test_concurrency_group.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
# coding: utf-8
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import pytest
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
# This tests the methods are executed in the correct eventloop.
|
||||||
|
def test_basic():
|
||||||
|
@ray.remote(concurrency_groups={"io": 2, "compute": 4})
|
||||||
|
class AsyncActor:
|
||||||
|
def __init__(self):
|
||||||
|
self.eventloop_f1 = None
|
||||||
|
self.eventloop_f2 = None
|
||||||
|
self.eventloop_f3 = None
|
||||||
|
self.eventloop_f4 = None
|
||||||
|
self.default_eventloop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="io")
|
||||||
|
async def f1(self):
|
||||||
|
self.eventloop_f1 = asyncio.get_event_loop()
|
||||||
|
return threading.current_thread().ident
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="io")
|
||||||
|
def f2(self):
|
||||||
|
self.eventloop_f2 = asyncio.get_event_loop()
|
||||||
|
return threading.current_thread().ident
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="compute")
|
||||||
|
def f3(self):
|
||||||
|
self.eventloop_f3 = asyncio.get_event_loop()
|
||||||
|
return threading.current_thread().ident
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="compute")
|
||||||
|
def f4(self):
|
||||||
|
self.eventloop_f4 = asyncio.get_event_loop()
|
||||||
|
return threading.current_thread().ident
|
||||||
|
|
||||||
|
def f5(self):
|
||||||
|
# If this method is executed in default eventloop.
|
||||||
|
assert asyncio.get_event_loop() == self.default_eventloop
|
||||||
|
return threading.current_thread().ident
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="io")
|
||||||
|
def do_assert(self):
|
||||||
|
if self.eventloop_f1 != self.eventloop_f2:
|
||||||
|
return False
|
||||||
|
if self.eventloop_f3 != self.eventloop_f4:
|
||||||
|
return False
|
||||||
|
if self.eventloop_f1 == self.eventloop_f3:
|
||||||
|
return False
|
||||||
|
if self.eventloop_f1 == self.eventloop_f4:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
###############################################
|
||||||
|
a = AsyncActor.remote()
|
||||||
|
f1_thread_id = ray.get(a.f1.remote()) # executed in the "io" group.
|
||||||
|
f2_thread_id = ray.get(a.f2.remote()) # executed in the "io" group.
|
||||||
|
f3_thread_id = ray.get(a.f3.remote()) # executed in the "compute" group.
|
||||||
|
f4_thread_id = ray.get(a.f4.remote()) # executed in the "compute" group.
|
||||||
|
|
||||||
|
assert f1_thread_id == f2_thread_id
|
||||||
|
assert f3_thread_id == f4_thread_id
|
||||||
|
assert f1_thread_id != f3_thread_id
|
||||||
|
|
||||||
|
assert ray.get(a.do_assert.remote())
|
||||||
|
|
||||||
|
assert ray.get(a.f5.remote()) # executed in the default group.
|
||||||
|
|
||||||
|
# It also has the ability to specify it at runtime.
|
||||||
|
# This task will be invoked in the `compute` thread pool.
|
||||||
|
a.f2.options(concurrency_group="compute").remote()
|
||||||
|
|
||||||
|
|
||||||
|
# The case tests that the asyncio count down works well in one concurrency
|
||||||
|
# group.
|
||||||
|
def test_async_methods_in_concurrency_group():
|
||||||
|
@ray.remote(concurrency_groups={"async": 3})
|
||||||
|
class AsyncBatcher:
|
||||||
|
def __init__(self):
|
||||||
|
self.batch = []
|
||||||
|
self.event = None
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="async")
|
||||||
|
def init_event(self):
|
||||||
|
self.event = asyncio.Event()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="async")
|
||||||
|
async def add(self, x):
|
||||||
|
self.batch.append(x)
|
||||||
|
if len(self.batch) >= 3:
|
||||||
|
self.event.set()
|
||||||
|
else:
|
||||||
|
await self.event.wait()
|
||||||
|
return sorted(self.batch)
|
||||||
|
|
||||||
|
a = AsyncBatcher.remote()
|
||||||
|
ray.get(a.init_event.remote())
|
||||||
|
|
||||||
|
x1 = a.add.remote(1)
|
||||||
|
x2 = a.add.remote(2)
|
||||||
|
x3 = a.add.remote(3)
|
||||||
|
r1 = ray.get(x1)
|
||||||
|
r2 = ray.get(x2)
|
||||||
|
r3 = ray.get(x3)
|
||||||
|
assert r1 == [1, 2, 3]
|
||||||
|
assert r1 == r2 == r3
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -1928,7 +1928,8 @@ def make_decorator(num_returns=None,
|
||||||
runtime_env=None,
|
runtime_env=None,
|
||||||
placement_group="default",
|
placement_group="default",
|
||||||
worker=None,
|
worker=None,
|
||||||
retry_exceptions=None):
|
retry_exceptions=None,
|
||||||
|
concurrency_groups=None):
|
||||||
def decorator(function_or_class):
|
def decorator(function_or_class):
|
||||||
if (inspect.isfunction(function_or_class)
|
if (inspect.isfunction(function_or_class)
|
||||||
or is_cython(function_or_class)):
|
or is_cython(function_or_class)):
|
||||||
|
@ -1983,10 +1984,10 @@ def make_decorator(num_returns=None,
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The keyword 'max_task_retries' only accepts -1, 0 or a"
|
"The keyword 'max_task_retries' only accepts -1, 0 or a"
|
||||||
" positive integer")
|
" positive integer")
|
||||||
return ray.actor.make_actor(function_or_class, num_cpus, num_gpus,
|
return ray.actor.make_actor(
|
||||||
memory, object_store_memory, resources,
|
function_or_class, num_cpus, num_gpus, memory,
|
||||||
accelerator_type, max_restarts,
|
object_store_memory, resources, accelerator_type, max_restarts,
|
||||||
max_task_retries, runtime_env)
|
max_task_retries, runtime_env, concurrency_groups)
|
||||||
|
|
||||||
raise TypeError("The @ray.remote decorator must be applied to "
|
raise TypeError("The @ray.remote decorator must be applied to "
|
||||||
"either a function or to a class.")
|
"either a function or to a class.")
|
||||||
|
@ -2105,10 +2106,21 @@ def remote(*args, **kwargs):
|
||||||
|
|
||||||
# Parse the keyword arguments from the decorator.
|
# Parse the keyword arguments from the decorator.
|
||||||
valid_kwargs = [
|
valid_kwargs = [
|
||||||
"num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
|
"num_returns",
|
||||||
"resources", "accelerator_type", "max_calls", "max_restarts",
|
"num_cpus",
|
||||||
"max_task_retries", "max_retries", "runtime_env", "retry_exceptions",
|
"num_gpus",
|
||||||
"placement_group"
|
"memory",
|
||||||
|
"object_store_memory",
|
||||||
|
"resources",
|
||||||
|
"accelerator_type",
|
||||||
|
"max_calls",
|
||||||
|
"max_restarts",
|
||||||
|
"max_task_retries",
|
||||||
|
"max_retries",
|
||||||
|
"runtime_env",
|
||||||
|
"retry_exceptions",
|
||||||
|
"placement_group",
|
||||||
|
"concurrency_groups",
|
||||||
]
|
]
|
||||||
error_string = ("The @ray.remote decorator must be applied either "
|
error_string = ("The @ray.remote decorator must be applied either "
|
||||||
"with no arguments and no parentheses, for example "
|
"with no arguments and no parentheses, for example "
|
||||||
|
@ -2143,6 +2155,7 @@ def remote(*args, **kwargs):
|
||||||
runtime_env = kwargs.get("runtime_env")
|
runtime_env = kwargs.get("runtime_env")
|
||||||
placement_group = kwargs.get("placement_group", "default")
|
placement_group = kwargs.get("placement_group", "default")
|
||||||
retry_exceptions = kwargs.get("retry_exceptions")
|
retry_exceptions = kwargs.get("retry_exceptions")
|
||||||
|
concurrency_groups = kwargs.get("concurrency_groups")
|
||||||
|
|
||||||
return make_decorator(
|
return make_decorator(
|
||||||
num_returns=num_returns,
|
num_returns=num_returns,
|
||||||
|
@ -2159,4 +2172,5 @@ def remote(*args, **kwargs):
|
||||||
runtime_env=runtime_env,
|
runtime_env=runtime_env,
|
||||||
placement_group=placement_group,
|
placement_group=placement_group,
|
||||||
worker=worker,
|
worker=worker,
|
||||||
retry_exceptions=retry_exceptions)
|
retry_exceptions=retry_exceptions,
|
||||||
|
concurrency_groups=concurrency_groups or [])
|
||||||
|
|
|
@ -43,6 +43,20 @@ struct ConcurrencyGroup {
|
||||||
uint32_t max_concurrency;
|
uint32_t max_concurrency;
|
||||||
// Function descriptors of the actor methods in this group.
|
// Function descriptors of the actor methods in this group.
|
||||||
std::vector<ray::FunctionDescriptor> function_descriptors;
|
std::vector<ray::FunctionDescriptor> function_descriptors;
|
||||||
|
|
||||||
|
ConcurrencyGroup() = default;
|
||||||
|
|
||||||
|
ConcurrencyGroup(const std::string &name, uint32_t max_concurrency,
|
||||||
|
const std::vector<ray::FunctionDescriptor> &fds)
|
||||||
|
: name(name), max_concurrency(max_concurrency), function_descriptors(fds) {}
|
||||||
|
|
||||||
|
std::string GetName() const { return name; }
|
||||||
|
|
||||||
|
uint32_t GetMaxConcurrency() const { return max_concurrency; }
|
||||||
|
|
||||||
|
std::vector<ray::FunctionDescriptor> GetFunctionDescriptors() const {
|
||||||
|
return function_descriptors;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline rpc::ObjectReference GetReferenceForActorDummyObject(
|
static inline rpc::ObjectReference GetReferenceForActorDummyObject(
|
||||||
|
|
|
@ -2273,11 +2273,21 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
||||||
CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());
|
CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID());
|
||||||
|
|
||||||
std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr;
|
std::shared_ptr<LocalMemoryBuffer> creation_task_exception_pb_bytes = nullptr;
|
||||||
|
|
||||||
|
std::vector<ConcurrencyGroup> defined_concurrency_groups = {};
|
||||||
|
std::string name_of_concurrency_group_to_execute;
|
||||||
|
if (task_spec.IsActorCreationTask()) {
|
||||||
|
defined_concurrency_groups = task_spec.ConcurrencyGroups();
|
||||||
|
} else if (task_spec.IsActorTask()) {
|
||||||
|
name_of_concurrency_group_to_execute = task_spec.ConcurrencyGroupName();
|
||||||
|
}
|
||||||
|
|
||||||
status = options_.task_execution_callback(
|
status = options_.task_execution_callback(
|
||||||
task_type, task_spec.GetName(), func,
|
task_type, task_spec.GetName(), func,
|
||||||
task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
|
task_spec.GetRequiredResources().GetResourceUnorderedMap(), args, arg_refs,
|
||||||
return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
|
return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
|
||||||
creation_task_exception_pb_bytes, is_application_level_error);
|
creation_task_exception_pb_bytes, is_application_level_error,
|
||||||
|
defined_concurrency_groups, name_of_concurrency_group_to_execute);
|
||||||
|
|
||||||
// Get the reference counts for any IDs that we borrowed during this task,
|
// Get the reference counts for any IDs that we borrowed during this task,
|
||||||
// remove the local reference for these IDs, and return the ref count info to
|
// remove the local reference for these IDs, and return the ref count info to
|
||||||
|
|
|
@ -71,7 +71,15 @@ struct CoreWorkerOptions {
|
||||||
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
||||||
std::vector<std::shared_ptr<RayObject>> *results,
|
std::vector<std::shared_ptr<RayObject>> *results,
|
||||||
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
|
||||||
bool *is_application_level_error)>;
|
bool *is_application_level_error,
|
||||||
|
// The following 2 parameters `defined_concurrency_groups` and
|
||||||
|
// `name_of_concurrency_group_to_execute` are used for Python
|
||||||
|
// asyncio actor only.
|
||||||
|
//
|
||||||
|
// Defined concurrency groups of this actor. Note this is only
|
||||||
|
// used for actor creation task.
|
||||||
|
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
|
||||||
|
const std::string name_of_concurrency_group_to_execute)>;
|
||||||
|
|
||||||
CoreWorkerOptions()
|
CoreWorkerOptions()
|
||||||
: store_socket(""),
|
: store_socket(""),
|
||||||
|
|
|
@ -101,7 +101,13 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||||
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
|
||||||
std::vector<std::shared_ptr<RayObject>> *results,
|
std::vector<std::shared_ptr<RayObject>> *results,
|
||||||
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
|
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
|
||||||
bool *is_application_level_error) {
|
bool *is_application_level_error,
|
||||||
|
const std::vector<ConcurrencyGroup> &defined_concurrency_groups,
|
||||||
|
const std::string name_of_concurrency_group_to_execute) {
|
||||||
|
// These 2 parameters are used for Python only, and Java worker
|
||||||
|
// will not use them.
|
||||||
|
RAY_UNUSED(defined_concurrency_groups);
|
||||||
|
RAY_UNUSED(name_of_concurrency_group_to_execute);
|
||||||
// TODO(jjyao): Support retrying application-level errors for Java
|
// TODO(jjyao): Support retrying application-level errors for Java
|
||||||
*is_application_level_error = false;
|
*is_application_level_error = false;
|
||||||
|
|
||||||
|
|
|
@ -509,6 +509,21 @@ class MockWorkerContext : public WorkerContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class MockCoreWorkerDirectTaskReceiver : public CoreWorkerDirectTaskReceiver {
|
||||||
|
public:
|
||||||
|
MockCoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
|
||||||
|
instrumented_io_context &main_io_service,
|
||||||
|
const TaskHandler &task_handler,
|
||||||
|
const OnTaskDone &task_done)
|
||||||
|
: CoreWorkerDirectTaskReceiver(worker_context, main_io_service, task_handler,
|
||||||
|
task_done) {}
|
||||||
|
|
||||||
|
void UpdateConcurrencyGroupsCache(const ActorID &actor_id,
|
||||||
|
const std::vector<ConcurrencyGroup> &cgs) {
|
||||||
|
concurrency_groups_cache_[actor_id] = cgs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class DirectActorReceiverTest : public ::testing::Test {
|
class DirectActorReceiverTest : public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
DirectActorReceiverTest()
|
DirectActorReceiverTest()
|
||||||
|
@ -518,7 +533,7 @@ class DirectActorReceiverTest : public ::testing::Test {
|
||||||
auto execute_task =
|
auto execute_task =
|
||||||
std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1,
|
std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1,
|
||||||
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
||||||
receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
|
receiver_ = std::make_unique<MockCoreWorkerDirectTaskReceiver>(
|
||||||
worker_context_, main_io_service_, execute_task, [] { return Status::OK(); });
|
worker_context_, main_io_service_, execute_task, [] { return Status::OK(); });
|
||||||
receiver_->Init(std::make_shared<rpc::CoreWorkerClientPool>(
|
receiver_->Init(std::make_shared<rpc::CoreWorkerClientPool>(
|
||||||
[&](const rpc::Address &addr) { return worker_client_; }),
|
[&](const rpc::Address &addr) { return worker_client_; }),
|
||||||
|
@ -541,7 +556,7 @@ class DirectActorReceiverTest : public ::testing::Test {
|
||||||
main_io_service_.stop();
|
main_io_service_.stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<CoreWorkerDirectTaskReceiver> receiver_;
|
std::unique_ptr<MockCoreWorkerDirectTaskReceiver> receiver_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
rpc::Address rpc_address_;
|
rpc::Address rpc_address_;
|
||||||
|
@ -575,6 +590,7 @@ TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) {
|
||||||
++callback_count;
|
++callback_count;
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
};
|
};
|
||||||
|
receiver_->UpdateConcurrencyGroupsCache(actor_id, {});
|
||||||
receiver_->HandleTask(request, &reply, reply_callback);
|
receiver_->HandleTask(request, &reply, reply_callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -524,6 +524,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
return_object->add_nested_inlined_refs()->CopyFrom(nested_ref);
|
return_object->add_nested_inlined_refs()->CopyFrom(nested_ref);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (task_spec.IsActorCreationTask()) {
|
if (task_spec.IsActorCreationTask()) {
|
||||||
/// The default max concurrency for creating PoolManager should
|
/// The default max concurrency for creating PoolManager should
|
||||||
/// be 0 if this is an asyncio actor.
|
/// be 0 if this is an asyncio actor.
|
||||||
|
@ -531,6 +532,8 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency();
|
task_spec.IsAsyncioActor() ? 0 : task_spec.MaxActorConcurrency();
|
||||||
pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(),
|
pool_manager_ = std::make_shared<PoolManager>(task_spec.ConcurrencyGroups(),
|
||||||
default_max_concurrency);
|
default_max_concurrency);
|
||||||
|
concurrency_groups_cache_[task_spec.TaskId().ActorId()] =
|
||||||
|
task_spec.ConcurrencyGroups();
|
||||||
RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId()
|
RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId()
|
||||||
<< ", actor_id: " << task_spec.ActorCreationId();
|
<< ", actor_id: " << task_spec.ActorCreationId();
|
||||||
// Tell raylet that an actor creation task has finished execution, so that
|
// Tell raylet that an actor creation task has finished execution, so that
|
||||||
|
@ -573,11 +576,13 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
|
||||||
if (task_spec.IsActorTask()) {
|
if (task_spec.IsActorTask()) {
|
||||||
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
|
auto it = actor_scheduling_queues_.find(task_spec.CallerWorkerId());
|
||||||
if (it == actor_scheduling_queues_.end()) {
|
if (it == actor_scheduling_queues_.end()) {
|
||||||
|
auto cg_it = concurrency_groups_cache_.find(task_spec.ActorId());
|
||||||
|
RAY_CHECK(cg_it != concurrency_groups_cache_.end());
|
||||||
auto result = actor_scheduling_queues_.emplace(
|
auto result = actor_scheduling_queues_.emplace(
|
||||||
task_spec.CallerWorkerId(),
|
task_spec.CallerWorkerId(),
|
||||||
std::unique_ptr<SchedulingQueue>(
|
std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
|
||||||
new ActorSchedulingQueue(task_main_io_service_, *waiter_, pool_manager_,
|
task_main_io_service_, *waiter_, pool_manager_, is_asyncio_,
|
||||||
is_asyncio_, fiber_max_concurrency_)));
|
fiber_max_concurrency_, cg_it->second)));
|
||||||
it = result.first;
|
it = result.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -283,6 +283,69 @@ class CoreWorkerDirectActorTaskSubmitter
|
||||||
friend class CoreWorkerTest;
|
friend class CoreWorkerTest;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// The class that manages fiber states for Python asyncio actors.
|
||||||
|
///
|
||||||
|
/// We'll create one fiber state for every concurrency group. And
|
||||||
|
/// create one default fiber state for default concurrency group if
|
||||||
|
/// necessary.
|
||||||
|
class FiberStateManager final {
|
||||||
|
public:
|
||||||
|
explicit FiberStateManager(const std::vector<ConcurrencyGroup> &concurrency_groups = {},
|
||||||
|
const int32_t default_group_max_concurrency = 1000) {
|
||||||
|
for (auto &group : concurrency_groups) {
|
||||||
|
const auto name = group.name;
|
||||||
|
const auto max_concurrency = group.max_concurrency;
|
||||||
|
auto fiber = std::make_shared<FiberState>(max_concurrency);
|
||||||
|
auto &fds = group.function_descriptors;
|
||||||
|
for (auto fd : fds) {
|
||||||
|
functions_to_fiber_index_[fd->ToString()] = fiber;
|
||||||
|
}
|
||||||
|
name_to_fiber_index_[name] = fiber;
|
||||||
|
}
|
||||||
|
/// Create default fiber state for default concurrency group.
|
||||||
|
if (default_group_max_concurrency >= 1) {
|
||||||
|
default_fiber_ = std::make_shared<FiberState>(default_group_max_concurrency);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the corresponding fiber state by the give concurrency group or function
|
||||||
|
/// descriptor.
|
||||||
|
///
|
||||||
|
/// Return the corresponding fiber state of the concurrency group
|
||||||
|
/// if concurrency_group_name is given.
|
||||||
|
/// Otherwise return the corresponding fiber state by the given function descriptor.
|
||||||
|
std::shared_ptr<FiberState> GetFiber(const std::string &concurrency_group_name,
|
||||||
|
ray::FunctionDescriptor fd) {
|
||||||
|
if (!concurrency_group_name.empty()) {
|
||||||
|
auto it = name_to_fiber_index_.find(concurrency_group_name);
|
||||||
|
RAY_CHECK(it != name_to_fiber_index_.end())
|
||||||
|
<< "Failed to look up the fiber state of the given concurrency group "
|
||||||
|
<< concurrency_group_name << " . It might be that you didn't define "
|
||||||
|
<< "the concurrency group " << concurrency_group_name;
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Code path of that this task wasn't specified in a concurrency group addtionally.
|
||||||
|
/// Use the predefined concurrency group.
|
||||||
|
if (functions_to_fiber_index_.find(fd->ToString()) !=
|
||||||
|
functions_to_fiber_index_.end()) {
|
||||||
|
return functions_to_fiber_index_[fd->ToString()];
|
||||||
|
}
|
||||||
|
return default_fiber_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Map from the name to their corresponding fibers.
|
||||||
|
absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> name_to_fiber_index_;
|
||||||
|
|
||||||
|
// Map from the FunctionDescriptors to their corresponding fibers.
|
||||||
|
absl::flat_hash_map<std::string, std::shared_ptr<FiberState>> functions_to_fiber_index_;
|
||||||
|
|
||||||
|
// The fiber for default concurrency group. It's nullptr if its max concurrency
|
||||||
|
// is 1.
|
||||||
|
std::shared_ptr<FiberState> default_fiber_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
class BoundedExecutor;
|
class BoundedExecutor;
|
||||||
|
|
||||||
/// A manager that manages a set of thread pool. which will perform
|
/// A manager that manages a set of thread pool. which will perform
|
||||||
|
@ -489,6 +552,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
instrumented_io_context &main_io_service, DependencyWaiter &waiter,
|
instrumented_io_context &main_io_service, DependencyWaiter &waiter,
|
||||||
std::shared_ptr<PoolManager> pool_manager = std::make_shared<PoolManager>(),
|
std::shared_ptr<PoolManager> pool_manager = std::make_shared<PoolManager>(),
|
||||||
bool is_asyncio = false, int fiber_max_concurrency = 1,
|
bool is_asyncio = false, int fiber_max_concurrency = 1,
|
||||||
|
const std::vector<ConcurrencyGroup> &concurrency_groups = {},
|
||||||
int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
|
int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
|
||||||
: reorder_wait_seconds_(reorder_wait_seconds),
|
: reorder_wait_seconds_(reorder_wait_seconds),
|
||||||
wait_timer_(main_io_service),
|
wait_timer_(main_io_service),
|
||||||
|
@ -497,9 +561,16 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
pool_manager_(pool_manager),
|
pool_manager_(pool_manager),
|
||||||
is_asyncio_(is_asyncio) {
|
is_asyncio_(is_asyncio) {
|
||||||
if (is_asyncio_) {
|
if (is_asyncio_) {
|
||||||
RAY_LOG(INFO) << "Setting actor as async with max_concurrency="
|
std::stringstream ss;
|
||||||
<< fiber_max_concurrency << ", creating new fiber thread.";
|
ss << "Setting actor as asyncio with max_concurrency=" << fiber_max_concurrency
|
||||||
fiber_state_ = std::make_unique<FiberState>(fiber_max_concurrency);
|
<< ", and defined concurrency groups are:" << std::endl;
|
||||||
|
for (const auto &concurrency_group : concurrency_groups) {
|
||||||
|
ss << "\t" << concurrency_group.name << " : "
|
||||||
|
<< concurrency_group.max_concurrency;
|
||||||
|
}
|
||||||
|
RAY_LOG(INFO) << ss.str();
|
||||||
|
fiber_state_manager_ =
|
||||||
|
std::make_unique<FiberStateManager>(concurrency_groups, fiber_max_concurrency);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -595,7 +666,9 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
|
|
||||||
if (is_asyncio_) {
|
if (is_asyncio_) {
|
||||||
// Process async actor task.
|
// Process async actor task.
|
||||||
fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
|
auto fiber = fiber_state_manager_->GetFiber(request.ConcurrencyGroupName(),
|
||||||
|
request.FunctionDescriptor());
|
||||||
|
fiber->EnqueueFiber([request]() mutable { request.Accept(); });
|
||||||
} else {
|
} else {
|
||||||
// Process actor tasks.
|
// Process actor tasks.
|
||||||
RAY_CHECK(pool_manager_ != nullptr);
|
RAY_CHECK(pool_manager_ != nullptr);
|
||||||
|
@ -661,9 +734,10 @@ class ActorSchedulingQueue : public SchedulingQueue {
|
||||||
/// Whether we should enqueue requests into asyncio pool. Setting this to true
|
/// Whether we should enqueue requests into asyncio pool. Setting this to true
|
||||||
/// will instantiate all tasks as fibers that can be yielded.
|
/// will instantiate all tasks as fibers that can be yielded.
|
||||||
bool is_asyncio_ = false;
|
bool is_asyncio_ = false;
|
||||||
/// If is_asyncio_ is true, fiber_state_ contains the running state required
|
/// Manage the running fiber states of actors in this worker. It works with
|
||||||
/// to enable continuation and work together with python asyncio.
|
/// python asyncio if this is an asyncio actor.
|
||||||
std::unique_ptr<FiberState> fiber_state_;
|
std::unique_ptr<FiberStateManager> fiber_state_manager_;
|
||||||
|
|
||||||
friend class SchedulingQueueTest;
|
friend class SchedulingQueueTest;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -822,6 +896,10 @@ class CoreWorkerDirectTaskReceiver {
|
||||||
|
|
||||||
bool CancelQueuedNormalTask(TaskID task_id);
|
bool CancelQueuedNormalTask(TaskID task_id);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Cache the concurrency groups of actors.
|
||||||
|
absl::flat_hash_map<ActorID, std::vector<ConcurrencyGroup>> concurrency_groups_cache_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Worker context.
|
// Worker context.
|
||||||
WorkerContext &worker_context_;
|
WorkerContext &worker_context_;
|
||||||
|
|
Loading…
Add table
Reference in a new issue