diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 7715063ea..29cdc5611 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -93,6 +93,7 @@ cdef class ObjectRef(BaseID): cdef class ClientObjectRef(ObjectRef): cdef object _mutex cdef object _id_future + cdef object _client_worker_ref cdef _set_id(self, id) cdef inline _wait_for_id(self, timeout=None) @@ -107,6 +108,7 @@ cdef class ActorID(BaseID): cdef class ClientActorRef(ActorID): cdef object _mutex cdef object _id_future + cdef object _client_worker_ref cdef _set_id(self, id) cdef inline _wait_for_id(self, timeout=None) diff --git a/python/ray/includes/object_ref.pxi b/python/ray/includes/object_ref.pxi index 860f2fe28..95073b86e 100644 --- a/python/ray/includes/object_ref.pxi +++ b/python/ray/includes/object_ref.pxi @@ -5,6 +5,7 @@ import concurrent.futures import functools import logging import threading +import weakref from typing import Callable, Any, Union import ray @@ -154,6 +155,9 @@ cdef class ClientObjectRef(ObjectRef): def __init__(self, id: Union[bytes, concurrent.futures.Future]): self.in_core_worker = False self._mutex = threading.Lock() + # client worker might be cleaned up before __dealloc__ is called. + # so use a weakref to check whether it's alive or not. + self._client_worker_ref = weakref.ref(client.ray.get_context().client_worker) if isinstance(id, bytes): self._set_id(id) elif isinstance(id, concurrent.futures.Future): @@ -162,14 +166,8 @@ cdef class ClientObjectRef(ObjectRef): raise TypeError("Unexpected type for id {}".format(id)) def __dealloc__(self): - if client is None or client.ray is None: - # Similar issue as mentioned in ObjectRef.__dealloc__ above. The - # client package or client.ray object might be set - # to None when the script exits. Should be safe to skip - # call_release in this case, since the client should have already - # disconnected at this point. - return - if client.ray.is_connected(): + client_worker = self._client_worker_ref() + if client_worker is not None and client_worker.is_connected(): try: self._wait_for_id() # cython would suppress this exception as well, but it tries to @@ -182,7 +180,7 @@ cdef class ClientObjectRef(ObjectRef): "a method on the actor reference before its destructor " "is run.") if not self.data.IsNil(): - client.ray.call_release(self.id) + client_worker.call_release(self.id) cdef CObjectID native(self): self._wait_for_id() @@ -251,13 +249,16 @@ cdef class ClientObjectRef(ObjectRef): data = loads_from_server(resp.get.data) py_callback(data) - - client.ray._register_callback(self, deserialize_obj) + client_worker = self._client_worker_ref() + assert client_worker is not None + client_worker.register_callback(self, deserialize_obj) cdef _set_id(self, id): check_id(id) self.data = CObjectID.FromBinary(id) - client.ray.call_retain(id) + client_worker = self._client_worker_ref() + assert client_worker is not None + client_worker.call_retain(id) cdef inline _wait_for_id(self, timeout=None): if self._id_future: diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 93822f570..e86462922 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -325,6 +325,9 @@ cdef class ClientActorRef(ActorID): def __init__(self, id: Union[bytes, concurrent.futures.Future]): self._mutex = threading.Lock() + # client worker might be cleaned up before __dealloc__ is called. + # so use a weakref to check whether it's alive or not. + self._client_worker_ref = weakref.ref(client.ray.get_context().client_worker) if isinstance(id, bytes): self._set_id(id) elif isinstance(id, Future): @@ -333,13 +336,8 @@ cdef class ClientActorRef(ActorID): raise TypeError("Unexpected type for id {}".format(id)) def __dealloc__(self): - if client is None or client.ray is None: - # The client package or client.ray object might be set - # to None when the script exits. Should be safe to skip - # call_release in this case, since the client should have already - # disconnected at this point. - return - if client.ray.is_connected(): + client_worker = self._client_worker_ref() + if client_worker is not None and client_worker.is_connected(): try: self._wait_for_id() # cython would suppress this exception as well, but it tries to @@ -352,7 +350,7 @@ cdef class ClientActorRef(ActorID): "a method on the actor reference before its destructor " "is run.") if not self.data.IsNil(): - client.ray.call_release(self.id) + client_worker.call_release(self.id) def binary(self): self._wait_for_id() @@ -381,7 +379,9 @@ cdef class ClientActorRef(ActorID): cdef _set_id(self, id): check_id(id, CActorID.Size()) self.data = CActorID.FromBinary(id) - client.ray.call_retain(id) + client_worker = self._client_worker_ref() + assert client_worker is not None + client_worker.call_retain(id) cdef _wait_for_id(self, timeout=None): if self._id_future: @@ -390,7 +390,6 @@ cdef class ClientActorRef(ActorID): self._set_id(self._id_future.result(timeout=timeout)) self._id_future = None - cdef class FunctionID(UniqueID): def __init__(self, id): diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index d02edeed5..07ba7b9a3 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -758,5 +758,31 @@ def test_init_requires_no_resources(call_ray_start, use_client): ray.get(f.remote()) +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25553 --num-cpus 1"], + indirect=True, +) +def test_object_ref_release(call_ray_start): + """This is to test the release of an object in previous session is + handled correctly. + """ + import ray + + ray.init("ray://localhost:25553") + + a = ray.put("Hello") + + ray.shutdown() + ray.init("ray://localhost:25553") + # a is release in the session which doesn't create it. + del a + + with disable_client_hook(): + # Make sure a doesn't generate a release request. + ref_cnt = ray.util.client.ray.get_context().client_worker.reference_count + assert all(v > 0 for v in ref_cnt.values()) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 2f55135e6..174ddf2de 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -138,8 +138,10 @@ class _ClientContext: def disconnect(self): """Disconnect the Ray Client.""" + if self.client_worker is not None: self.client_worker.close() + self.api.worker = None self.client_worker = None # remote can be called outside of a connection, which is why it