mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Core] Add private on_completed callback for ObjectRef (#13688)
This commit is contained in:
parent
32ec0d205f
commit
25fa391193
3 changed files with 70 additions and 36 deletions
|
@ -1569,12 +1569,13 @@ cdef class CoreWorker:
|
|||
|
||||
return ref_counts
|
||||
|
||||
def get_async(self, ObjectRef object_ref, future):
|
||||
cpython.Py_INCREF(future)
|
||||
def set_get_async_callback(self, ObjectRef object_ref, callback):
|
||||
cpython.Py_INCREF(callback)
|
||||
CCoreWorkerProcess.GetCoreWorker().GetAsync(
|
||||
object_ref.native(),
|
||||
async_set_result,
|
||||
<void*>future)
|
||||
object_ref.native(),
|
||||
async_callback,
|
||||
<void*>callback
|
||||
)
|
||||
|
||||
def push_error(self, JobID job_id, error_type, error_message,
|
||||
double timestamp):
|
||||
|
@ -1588,13 +1589,11 @@ cdef class CoreWorker:
|
|||
resource_name.encode("ascii"), capacity,
|
||||
CNodeID.FromBinary(client_id.binary()))
|
||||
|
||||
cdef void async_set_result(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_ref,
|
||||
void *future) with gil:
|
||||
cdef void async_callback(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_ref,
|
||||
void *user_callback) with gil:
|
||||
cdef:
|
||||
c_vector[shared_ptr[CRayObject]] objects_to_deserialize
|
||||
py_future = <object>(future)
|
||||
loop = py_future._loop
|
||||
|
||||
# Object is retrieved from in memory store.
|
||||
# Here we go through the code path used to deserialize objects.
|
||||
|
@ -1605,23 +1604,6 @@ cdef void async_set_result(shared_ptr[CRayObject] obj,
|
|||
result = ray.worker.global_worker.deserialize_objects(
|
||||
data_metadata_pairs, ids_to_deserialize)[0]
|
||||
|
||||
def set_future():
|
||||
# Issue #11030, #8841
|
||||
# If this future has result set already, we just need to
|
||||
# skip the set result/exception procedure.
|
||||
if py_future.done():
|
||||
cpython.Py_DECREF(py_future)
|
||||
return
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
py_future.set_exception(result.as_instanceof_cause())
|
||||
elif isinstance(result, RayError):
|
||||
# Directly raise exception for RayActorError
|
||||
py_future.set_exception(result)
|
||||
else:
|
||||
py_future.set_result(result)
|
||||
|
||||
cpython.Py_DECREF(py_future)
|
||||
|
||||
loop.call_soon_threadsafe(set_future)
|
||||
py_callback = <object>user_callback
|
||||
py_callback(result)
|
||||
cpython.Py_DECREF(py_callback)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from ray.includes.unique_ids cimport CObjectID
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Any
|
||||
|
||||
import ray
|
||||
|
||||
|
@ -71,10 +72,41 @@ cdef class ObjectRef(BaseID):
|
|||
|
||||
def as_future(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
py_future = loop.create_future()
|
||||
|
||||
def callback(result):
|
||||
loop = py_future._loop
|
||||
|
||||
def set_future():
|
||||
# Issue #11030, #8841
|
||||
# If this future has result set already, we just need to
|
||||
# skip the set result/exception procedure.
|
||||
if py_future.done():
|
||||
return
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
py_future.set_exception(result.as_instanceof_cause())
|
||||
elif isinstance(result, RayError):
|
||||
# Directly raise exception for RayActorError
|
||||
py_future.set_exception(result)
|
||||
else:
|
||||
py_future.set_result(result)
|
||||
|
||||
loop.call_soon_threadsafe(set_future)
|
||||
|
||||
self._on_completed(callback)
|
||||
|
||||
future = loop.create_future()
|
||||
core_worker.get_async(self, future)
|
||||
# A hack to keep a reference to the object ref for ref counting.
|
||||
future.object_ref = self
|
||||
return future
|
||||
py_future.object_ref = self
|
||||
return py_future
|
||||
|
||||
def _on_completed(self, py_callback: Callable[[Any], None]):
|
||||
"""Register a callback that will be called after Object is ready.
|
||||
If the ObjectRef is already ready, the callback will be called soon.
|
||||
The callback should take the result as the only argument. The result
|
||||
can be an exception object in case of task error.
|
||||
"""
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
core_worker.set_get_async_callback(self, py_callback)
|
||||
return self
|
||||
|
|
|
@ -6,7 +6,7 @@ import threading
|
|||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.test_utils import SignalActor
|
||||
from ray.test_utils import SignalActor, wait_for_condition
|
||||
|
||||
|
||||
def test_asyncio_actor(ray_start_regular_shared):
|
||||
|
@ -224,6 +224,26 @@ async def test_asyncio_exit_actor(ray_start_regular_shared):
|
|||
ray.get(a.ping.remote())
|
||||
|
||||
|
||||
def test_async_callback(ray_start_regular_shared):
|
||||
global_set = set()
|
||||
|
||||
ref = ray.put(None)
|
||||
ref._on_completed(lambda _: global_set.add("completed-1"))
|
||||
wait_for_condition(lambda: "completed-1" in global_set)
|
||||
|
||||
signal = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def wait():
|
||||
ray.get(signal.wait.remote())
|
||||
|
||||
ref = wait.remote()
|
||||
ref._on_completed(lambda _: global_set.add("completed-2"))
|
||||
assert "completed-2" not in global_set
|
||||
signal.send.remote()
|
||||
wait_for_condition(lambda: "completed-2" in global_set)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue