mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Working but not passing test (#7358)
This commit is contained in:
parent
34488f52f3
commit
f321eaec9b
2 changed files with 83 additions and 3 deletions
|
@ -153,7 +153,15 @@ class SerializationContext:
|
|||
return serialized_obj[0](*serialized_obj[1])
|
||||
|
||||
def object_id_serializer(obj):
|
||||
self.add_contained_object_id(obj)
|
||||
if self.is_in_band_serialization():
|
||||
self.add_contained_object_id(obj)
|
||||
else:
|
||||
# If this serialization is out-of-band (e.g., from a call to
|
||||
# cloudpickle directly or captured in a remote function/actor),
|
||||
# then pin the object for the lifetime of this worker by adding
|
||||
# a local reference that won't ever be removed.
|
||||
ray.worker.get_global_worker(
|
||||
).core_worker.add_object_id_reference(obj)
|
||||
owner_id = ""
|
||||
owner_address = ""
|
||||
# TODO(swang): Remove this check. Otherwise, we will not be able to
|
||||
|
@ -206,6 +214,15 @@ class SerializationContext:
|
|||
# construct a reducer
|
||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||
|
||||
def is_in_band_serialization(self):
|
||||
return getattr(self._thread_local, "in_band", False)
|
||||
|
||||
def set_in_band_serialization(self):
|
||||
self._thread_local.in_band = True
|
||||
|
||||
def set_out_of_band_serialization(self):
|
||||
self._thread_local.in_band = False
|
||||
|
||||
def set_outer_object_id(self, outer_object_id):
|
||||
self._thread_local.outer_object_id = outer_object_id
|
||||
|
||||
|
@ -349,8 +366,16 @@ class SerializationContext:
|
|||
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
|
|
|
@ -722,6 +722,61 @@ def test_recursively_return_borrowed_object_id(one_worker_100MiB):
|
|||
_fill_object_store_and_get(final_oid_bytes, succeed=False)
|
||||
|
||||
|
||||
def test_out_of_band_serialized_object_id(one_worker_100MiB):
|
||||
assert len(
|
||||
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 0
|
||||
oid = ray.put("hello")
|
||||
_check_refcounts({oid: (1, 0)})
|
||||
oid_str = ray.cloudpickle.dumps(oid)
|
||||
_check_refcounts({oid: (2, 0)})
|
||||
del oid
|
||||
assert len(
|
||||
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 1
|
||||
assert ray.get(ray.cloudpickle.loads(oid_str)) == "hello"
|
||||
|
||||
|
||||
def test_captured_object_id(one_worker_100MiB):
|
||||
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
|
||||
|
||||
@ray.remote
|
||||
def f(signal):
|
||||
ray.get(signal.wait.remote())
|
||||
ray.get(captured_id) # noqa: F821
|
||||
|
||||
signal = SignalActor.remote()
|
||||
oid = f.remote(signal)
|
||||
|
||||
# Delete local references.
|
||||
del f
|
||||
del captured_id
|
||||
|
||||
# Test that the captured object ID is pinned despite having no local
|
||||
# references.
|
||||
ray.get(signal.send.remote())
|
||||
_fill_object_store_and_get(oid)
|
||||
|
||||
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def get(self, signal):
|
||||
ray.get(signal.wait.remote())
|
||||
ray.get(captured_id) # noqa: F821
|
||||
|
||||
signal = SignalActor.remote()
|
||||
actor = Actor.remote()
|
||||
oid = actor.get.remote(signal)
|
||||
|
||||
# Delete local references.
|
||||
del Actor
|
||||
del captured_id
|
||||
|
||||
# Test that the captured object ID is pinned despite having no local
|
||||
# references.
|
||||
ray.get(signal.send.remote())
|
||||
_fill_object_store_and_get(oid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue