[client] Fix ray client object ref releasing in wrong context. (#22025)

This commit is contained in:
Yi Cheng 2022-02-01 22:42:39 -08:00 committed by GitHub
parent 54fe2f80bb
commit 588d540b68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 22 deletions

View file

@ -93,6 +93,7 @@ cdef class ObjectRef(BaseID):
cdef class ClientObjectRef(ObjectRef): cdef class ClientObjectRef(ObjectRef):
cdef object _mutex cdef object _mutex
cdef object _id_future cdef object _id_future
cdef object _client_worker_ref
cdef _set_id(self, id) cdef _set_id(self, id)
cdef inline _wait_for_id(self, timeout=None) cdef inline _wait_for_id(self, timeout=None)
@ -107,6 +108,7 @@ cdef class ActorID(BaseID):
cdef class ClientActorRef(ActorID): cdef class ClientActorRef(ActorID):
cdef object _mutex cdef object _mutex
cdef object _id_future cdef object _id_future
cdef object _client_worker_ref
cdef _set_id(self, id) cdef _set_id(self, id)
cdef inline _wait_for_id(self, timeout=None) cdef inline _wait_for_id(self, timeout=None)

View file

@ -5,6 +5,7 @@ import concurrent.futures
import functools import functools
import logging import logging
import threading import threading
import weakref
from typing import Callable, Any, Union from typing import Callable, Any, Union
import ray import ray
@ -154,6 +155,9 @@ cdef class ClientObjectRef(ObjectRef):
def __init__(self, id: Union[bytes, concurrent.futures.Future]): def __init__(self, id: Union[bytes, concurrent.futures.Future]):
self.in_core_worker = False self.in_core_worker = False
self._mutex = threading.Lock() self._mutex = threading.Lock()
# client worker might be cleaned up before __dealloc__ is called.
# so use a weakref to check whether it's alive or not.
self._client_worker_ref = weakref.ref(client.ray.get_context().client_worker)
if isinstance(id, bytes): if isinstance(id, bytes):
self._set_id(id) self._set_id(id)
elif isinstance(id, concurrent.futures.Future): elif isinstance(id, concurrent.futures.Future):
@ -162,14 +166,8 @@ cdef class ClientObjectRef(ObjectRef):
raise TypeError("Unexpected type for id {}".format(id)) raise TypeError("Unexpected type for id {}".format(id))
def __dealloc__(self): def __dealloc__(self):
if client is None or client.ray is None: client_worker = self._client_worker_ref()
# Similar issue as mentioned in ObjectRef.__dealloc__ above. The if client_worker is not None and client_worker.is_connected():
# client package or client.ray object might be set
# to None when the script exits. Should be safe to skip
# call_release in this case, since the client should have already
# disconnected at this point.
return
if client.ray.is_connected():
try: try:
self._wait_for_id() self._wait_for_id()
# cython would suppress this exception as well, but it tries to # cython would suppress this exception as well, but it tries to
@ -182,7 +180,7 @@ cdef class ClientObjectRef(ObjectRef):
"a method on the actor reference before its destructor " "a method on the actor reference before its destructor "
"is run.") "is run.")
if not self.data.IsNil(): if not self.data.IsNil():
client.ray.call_release(self.id) client_worker.call_release(self.id)
cdef CObjectID native(self): cdef CObjectID native(self):
self._wait_for_id() self._wait_for_id()
@ -251,13 +249,16 @@ cdef class ClientObjectRef(ObjectRef):
data = loads_from_server(resp.get.data) data = loads_from_server(resp.get.data)
py_callback(data) py_callback(data)
client_worker = self._client_worker_ref()
client.ray._register_callback(self, deserialize_obj) assert client_worker is not None
client_worker.register_callback(self, deserialize_obj)
cdef _set_id(self, id): cdef _set_id(self, id):
check_id(id) check_id(id)
self.data = CObjectID.FromBinary(<c_string>id) self.data = CObjectID.FromBinary(<c_string>id)
client.ray.call_retain(id) client_worker = self._client_worker_ref()
assert client_worker is not None
client_worker.call_retain(id)
cdef inline _wait_for_id(self, timeout=None): cdef inline _wait_for_id(self, timeout=None):
if self._id_future: if self._id_future:

View file

@ -325,6 +325,9 @@ cdef class ClientActorRef(ActorID):
def __init__(self, id: Union[bytes, concurrent.futures.Future]): def __init__(self, id: Union[bytes, concurrent.futures.Future]):
self._mutex = threading.Lock() self._mutex = threading.Lock()
# client worker might be cleaned up before __dealloc__ is called.
# so use a weakref to check whether it's alive or not.
self._client_worker_ref = weakref.ref(client.ray.get_context().client_worker)
if isinstance(id, bytes): if isinstance(id, bytes):
self._set_id(id) self._set_id(id)
elif isinstance(id, Future): elif isinstance(id, Future):
@ -333,13 +336,8 @@ cdef class ClientActorRef(ActorID):
raise TypeError("Unexpected type for id {}".format(id)) raise TypeError("Unexpected type for id {}".format(id))
def __dealloc__(self): def __dealloc__(self):
if client is None or client.ray is None: client_worker = self._client_worker_ref()
# The client package or client.ray object might be set if client_worker is not None and client_worker.is_connected():
# to None when the script exits. Should be safe to skip
# call_release in this case, since the client should have already
# disconnected at this point.
return
if client.ray.is_connected():
try: try:
self._wait_for_id() self._wait_for_id()
# cython would suppress this exception as well, but it tries to # cython would suppress this exception as well, but it tries to
@ -352,7 +350,7 @@ cdef class ClientActorRef(ActorID):
"a method on the actor reference before its destructor " "a method on the actor reference before its destructor "
"is run.") "is run.")
if not self.data.IsNil(): if not self.data.IsNil():
client.ray.call_release(self.id) client_worker.call_release(self.id)
def binary(self): def binary(self):
self._wait_for_id() self._wait_for_id()
@ -381,7 +379,9 @@ cdef class ClientActorRef(ActorID):
cdef _set_id(self, id): cdef _set_id(self, id):
check_id(id, CActorID.Size()) check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id) self.data = CActorID.FromBinary(<c_string>id)
client.ray.call_retain(id) client_worker = self._client_worker_ref()
assert client_worker is not None
client_worker.call_retain(id)
cdef _wait_for_id(self, timeout=None): cdef _wait_for_id(self, timeout=None):
if self._id_future: if self._id_future:
@ -390,7 +390,6 @@ cdef class ClientActorRef(ActorID):
self._set_id(self._id_future.result(timeout=timeout)) self._set_id(self._id_future.result(timeout=timeout))
self._id_future = None self._id_future = None
cdef class FunctionID(UniqueID): cdef class FunctionID(UniqueID):
def __init__(self, id): def __init__(self, id):

View file

@ -758,5 +758,31 @@ def test_init_requires_no_resources(call_ray_start, use_client):
ray.get(f.remote()) ray.get(f.remote())
@pytest.mark.parametrize(
"call_ray_start",
["ray start --head --ray-client-server-port 25553 --num-cpus 1"],
indirect=True,
)
def test_object_ref_release(call_ray_start):
"""This is to test the release of an object in previous session is
handled correctly.
"""
import ray
ray.init("ray://localhost:25553")
a = ray.put("Hello")
ray.shutdown()
ray.init("ray://localhost:25553")
# a is release in the session which doesn't create it.
del a
with disable_client_hook():
# Make sure a doesn't generate a release request.
ref_cnt = ray.util.client.ray.get_context().client_worker.reference_count
assert all(v > 0 for v in ref_cnt.values())
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__])) sys.exit(pytest.main(["-v", __file__]))

View file

@ -138,8 +138,10 @@ class _ClientContext:
def disconnect(self): def disconnect(self):
"""Disconnect the Ray Client.""" """Disconnect the Ray Client."""
if self.client_worker is not None: if self.client_worker is not None:
self.client_worker.close() self.client_worker.close()
self.api.worker = None
self.client_worker = None self.client_worker = None
# remote can be called outside of a connection, which is why it # remote can be called outside of a connection, which is why it