[Core] Add private on_completed callback for ObjectRef (#13688)

This commit is contained in:
Simon Mo 2021-01-27 16:32:00 -08:00 committed by GitHub
parent 32ec0d205f
commit 25fa391193
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 36 deletions

View file

@ -1569,12 +1569,13 @@ cdef class CoreWorker:
return ref_counts
def get_async(self, ObjectRef object_ref, future):
cpython.Py_INCREF(future)
def set_get_async_callback(self, ObjectRef object_ref, callback):
cpython.Py_INCREF(callback)
CCoreWorkerProcess.GetCoreWorker().GetAsync(
object_ref.native(),
async_set_result,
<void*>future)
object_ref.native(),
async_callback,
<void*>callback
)
def push_error(self, JobID job_id, error_type, error_message,
double timestamp):
@ -1588,13 +1589,11 @@ cdef class CoreWorker:
resource_name.encode("ascii"), capacity,
CNodeID.FromBinary(client_id.binary()))
cdef void async_set_result(shared_ptr[CRayObject] obj,
CObjectID object_ref,
void *future) with gil:
cdef void async_callback(shared_ptr[CRayObject] obj,
CObjectID object_ref,
void *user_callback) with gil:
cdef:
c_vector[shared_ptr[CRayObject]] objects_to_deserialize
py_future = <object>(future)
loop = py_future._loop
# Object is retrieved from in memory store.
# Here we go through the code path used to deserialize objects.
@ -1605,23 +1604,6 @@ cdef void async_set_result(shared_ptr[CRayObject] obj,
result = ray.worker.global_worker.deserialize_objects(
data_metadata_pairs, ids_to_deserialize)[0]
def set_future():
# Issue #11030, #8841
# If this future has result set already, we just need to
# skip the set result/exception procedure.
if py_future.done():
cpython.Py_DECREF(py_future)
return
if isinstance(result, RayTaskError):
ray.worker.last_task_error_raise_time = time.time()
py_future.set_exception(result.as_instanceof_cause())
elif isinstance(result, RayError):
# Directly raise exception for RayActorError
py_future.set_exception(result)
else:
py_future.set_result(result)
cpython.Py_DECREF(py_future)
loop.call_soon_threadsafe(set_future)
py_callback = <object>user_callback
py_callback(result)
cpython.Py_DECREF(py_callback)

View file

@ -1,6 +1,7 @@
from ray.includes.unique_ids cimport CObjectID
import asyncio
from typing import Callable, Any
import ray
@ -71,10 +72,41 @@ cdef class ObjectRef(BaseID):
def as_future(self):
loop = asyncio.get_event_loop()
core_worker = ray.worker.global_worker.core_worker
py_future = loop.create_future()
def callback(result):
    """Forward a completed object's value into ``py_future``.

    ``result`` is the value handed over once the object is ready; it may
    be a ``RayTaskError`` / ``RayError`` instance instead of a regular
    value. The future is not touched directly here: the update is
    scheduled on the future's own event loop.
    """
    target_loop = py_future._loop

    def set_future():
        # Issue #11030, #8841
        # The future may already hold a result or exception by the time
        # this runs; in that case there is nothing left to do.
        if py_future.done():
            return

        if isinstance(result, RayTaskError):
            ray.worker.last_task_error_raise_time = time.time()
            py_future.set_exception(result.as_instanceof_cause())
            return

        if isinstance(result, RayError):
            # Directly raise exception for RayActorError
            py_future.set_exception(result)
            return

        py_future.set_result(result)

    # NOTE(review): the thread-safe variant is used presumably because
    # this callback is invoked off the event-loop thread -- confirm.
    target_loop.call_soon_threadsafe(set_future)
self._on_completed(callback)
future = loop.create_future()
core_worker.get_async(self, future)
# A hack to keep a reference to the object ref for ref counting.
future.object_ref = self
return future
py_future.object_ref = self
return py_future
def _on_completed(self, py_callback: Callable[[Any], None]):
    """Attach a completion callback to this ObjectRef.

    The callback is registered with the global worker's core worker and
    is called with the object's result as its only argument once the
    object is ready. If the object is ready already, the callback will
    be called soon.

    For a failed task the value passed to the callback can be an
    exception object rather than a regular result.

    Returns this ObjectRef.
    """
    ray.worker.global_worker.core_worker.set_get_async_callback(
        self, py_callback)
    return self

View file

@ -6,7 +6,7 @@ import threading
import pytest
import ray
from ray.test_utils import SignalActor
from ray.test_utils import SignalActor, wait_for_condition
def test_asyncio_actor(ray_start_regular_shared):
@ -224,6 +224,26 @@ async def test_asyncio_exit_actor(ray_start_regular_shared):
ray.get(a.ping.remote())
def test_async_callback(ray_start_regular_shared):
    """Check that ``ObjectRef._on_completed`` callbacks fire.

    Covers two cases: an object that already exists when the callback is
    registered, and a task result that only becomes ready after a signal
    is sent.
    """
    fired = set()

    # Case 1: the object comes from ray.put, so it exists at registration.
    object_ref = ray.put(None)
    object_ref._on_completed(lambda _: fired.add("completed-1"))
    wait_for_condition(lambda: "completed-1" in fired)

    # Case 2: the task blocks on the signal actor, so the callback must
    # not have fired until the signal is sent.
    signal = SignalActor.remote()

    @ray.remote
    def wait():
        ray.get(signal.wait.remote())

    object_ref = wait.remote()
    object_ref._on_completed(lambda _: fired.add("completed-2"))
    assert "completed-2" not in fired

    signal.send.remote()
    wait_for_condition(lambda: "completed-2" in fired)
# When executed as a script, run this file's tests through pytest in
# verbose mode and use pytest's return value as the process exit status.
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))