From 25fa391193caf86f1f08daedccde5216a986c302 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 27 Jan 2021 16:32:00 -0800 Subject: [PATCH] [Core] Add private on_completed callback for ObjectRef (#13688) --- python/ray/_raylet.pyx | 42 +++++++++--------------------- python/ray/includes/object_ref.pxi | 42 ++++++++++++++++++++++++++---- python/ray/tests/test_asyncio.py | 22 +++++++++++++++- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0fc3f4bf2..dc9fceaca 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1569,12 +1569,13 @@ cdef class CoreWorker: return ref_counts - def get_async(self, ObjectRef object_ref, future): - cpython.Py_INCREF(future) + def set_get_async_callback(self, ObjectRef object_ref, callback): + cpython.Py_INCREF(callback) CCoreWorkerProcess.GetCoreWorker().GetAsync( - object_ref.native(), - async_set_result, - future) + object_ref.native(), + async_callback, + callback + ) def push_error(self, JobID job_id, error_type, error_message, double timestamp): @@ -1588,13 +1589,11 @@ cdef class CoreWorker: resource_name.encode("ascii"), capacity, CNodeID.FromBinary(client_id.binary())) -cdef void async_set_result(shared_ptr[CRayObject] obj, - CObjectID object_ref, - void *future) with gil: +cdef void async_callback(shared_ptr[CRayObject] obj, + CObjectID object_ref, + void *user_callback) with gil: cdef: c_vector[shared_ptr[CRayObject]] objects_to_deserialize - py_future = (future) - loop = py_future._loop # Object is retrieved from in memory store. # Here we go through the code path used to deserialize objects. @@ -1605,23 +1604,6 @@ cdef void async_set_result(shared_ptr[CRayObject] obj, result = ray.worker.global_worker.deserialize_objects( data_metadata_pairs, ids_to_deserialize)[0] - def set_future(): - # Issue #11030, #8841 - # If this future has result set already, we just need to - # skip the set result/exception procedure. - if py_future.done(): - cpython.Py_DECREF(py_future) - return - - if isinstance(result, RayTaskError): - ray.worker.last_task_error_raise_time = time.time() - py_future.set_exception(result.as_instanceof_cause()) - elif isinstance(result, RayError): - # Directly raise exception for RayActorError - py_future.set_exception(result) - else: - py_future.set_result(result) - - cpython.Py_DECREF(py_future) - - loop.call_soon_threadsafe(set_future) + py_callback = user_callback + py_callback(result) + cpython.Py_DECREF(py_callback) diff --git a/python/ray/includes/object_ref.pxi b/python/ray/includes/object_ref.pxi index 3353e696e..31c59d08b 100644 --- a/python/ray/includes/object_ref.pxi +++ b/python/ray/includes/object_ref.pxi @@ -1,6 +1,7 @@ from ray.includes.unique_ids cimport CObjectID import asyncio +from typing import Callable, Any import ray @@ -71,10 +72,41 @@ cdef class ObjectRef(BaseID): def as_future(self): loop = asyncio.get_event_loop() - core_worker = ray.worker.global_worker.core_worker + py_future = loop.create_future() + + def callback(result): + loop = py_future._loop + + def set_future(): + # Issue #11030, #8841 + # If this future has result set already, we just need to + # skip the set result/exception procedure. + if py_future.done(): + return + + if isinstance(result, RayTaskError): + ray.worker.last_task_error_raise_time = time.time() + py_future.set_exception(result.as_instanceof_cause()) + elif isinstance(result, RayError): + # Directly raise exception for RayActorError + py_future.set_exception(result) + else: + py_future.set_result(result) + + loop.call_soon_threadsafe(set_future) + + self._on_completed(callback) - future = loop.create_future() - core_worker.get_async(self, future) # A hack to keep a reference to the object ref for ref counting. - future.object_ref = self - return future + py_future.object_ref = self + return py_future + + def _on_completed(self, py_callback: Callable[[Any], None]): + """Register a callback that will be called after Object is ready. + If the ObjectRef is already ready, the callback will be called soon. + The callback should take the result as the only argument. The result + can be an exception object in case of task error. + """ + core_worker = ray.worker.global_worker.core_worker + core_worker.set_get_async_callback(self, py_callback) + return self diff --git a/python/ray/tests/test_asyncio.py b/python/ray/tests/test_asyncio.py index 18dd63a22..31f03aefa 100644 --- a/python/ray/tests/test_asyncio.py +++ b/python/ray/tests/test_asyncio.py @@ -6,7 +6,7 @@ import threading import pytest import ray -from ray.test_utils import SignalActor +from ray.test_utils import SignalActor, wait_for_condition def test_asyncio_actor(ray_start_regular_shared): @@ -224,6 +224,26 @@ async def test_asyncio_exit_actor(ray_start_regular_shared): ray.get(a.ping.remote()) +def test_async_callback(ray_start_regular_shared): + global_set = set() + + ref = ray.put(None) + ref._on_completed(lambda _: global_set.add("completed-1")) + wait_for_condition(lambda: "completed-1" in global_set) + + signal = SignalActor.remote() + + @ray.remote + def wait(): + ray.get(signal.wait.remote()) + + ref = wait.remote() + ref._on_completed(lambda _: global_set.add("completed-2")) + assert "completed-2" not in global_set + signal.send.remote() + wait_for_condition(lambda: "completed-2" in global_set) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__]))