mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
Working but not passing test (#7358)
This commit is contained in:
parent
34488f52f3
commit
f321eaec9b
2 changed files with 83 additions and 3 deletions
|
@ -153,7 +153,15 @@ class SerializationContext:
|
||||||
return serialized_obj[0](*serialized_obj[1])
|
return serialized_obj[0](*serialized_obj[1])
|
||||||
|
|
||||||
def object_id_serializer(obj):
|
def object_id_serializer(obj):
|
||||||
self.add_contained_object_id(obj)
|
if self.is_in_band_serialization():
|
||||||
|
self.add_contained_object_id(obj)
|
||||||
|
else:
|
||||||
|
# If this serialization is out-of-band (e.g., from a call to
|
||||||
|
# cloudpickle directly or captured in a remote function/actor),
|
||||||
|
# then pin the object for the lifetime of this worker by adding
|
||||||
|
# a local reference that won't ever be removed.
|
||||||
|
ray.worker.get_global_worker(
|
||||||
|
).core_worker.add_object_id_reference(obj)
|
||||||
owner_id = ""
|
owner_id = ""
|
||||||
owner_address = ""
|
owner_address = ""
|
||||||
# TODO(swang): Remove this check. Otherwise, we will not be able to
|
# TODO(swang): Remove this check. Otherwise, we will not be able to
|
||||||
|
@ -206,6 +214,15 @@ class SerializationContext:
|
||||||
# construct a reducer
|
# construct a reducer
|
||||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||||
|
|
||||||
|
def is_in_band_serialization(self):
|
||||||
|
return getattr(self._thread_local, "in_band", False)
|
||||||
|
|
||||||
|
def set_in_band_serialization(self):
|
||||||
|
self._thread_local.in_band = True
|
||||||
|
|
||||||
|
def set_out_of_band_serialization(self):
|
||||||
|
self._thread_local.in_band = False
|
||||||
|
|
||||||
def set_outer_object_id(self, outer_object_id):
|
def set_outer_object_id(self, outer_object_id):
|
||||||
self._thread_local.outer_object_id = outer_object_id
|
self._thread_local.outer_object_id = outer_object_id
|
||||||
|
|
||||||
|
@ -349,8 +366,16 @@ class SerializationContext:
|
||||||
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
|
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
|
||||||
writer = Pickle5Writer()
|
writer = Pickle5Writer()
|
||||||
# TODO(swang): Check that contained_object_ids is empty.
|
# TODO(swang): Check that contained_object_ids is empty.
|
||||||
inband = pickle.dumps(
|
try:
|
||||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
self.set_in_band_serialization()
|
||||||
|
inband = pickle.dumps(
|
||||||
|
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||||
|
except Exception as e:
|
||||||
|
self.get_and_clear_contained_object_ids()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self.set_out_of_band_serialization()
|
||||||
|
|
||||||
return Pickle5SerializedObject(
|
return Pickle5SerializedObject(
|
||||||
metadata, inband, writer,
|
metadata, inband, writer,
|
||||||
self.get_and_clear_contained_object_ids())
|
self.get_and_clear_contained_object_ids())
|
||||||
|
|
|
@ -722,6 +722,61 @@ def test_recursively_return_borrowed_object_id(one_worker_100MiB):
|
||||||
_fill_object_store_and_get(final_oid_bytes, succeed=False)
|
_fill_object_store_and_get(final_oid_bytes, succeed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_out_of_band_serialized_object_id(one_worker_100MiB):
|
||||||
|
assert len(
|
||||||
|
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 0
|
||||||
|
oid = ray.put("hello")
|
||||||
|
_check_refcounts({oid: (1, 0)})
|
||||||
|
oid_str = ray.cloudpickle.dumps(oid)
|
||||||
|
_check_refcounts({oid: (2, 0)})
|
||||||
|
del oid
|
||||||
|
assert len(
|
||||||
|
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 1
|
||||||
|
assert ray.get(ray.cloudpickle.loads(oid_str)) == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_captured_object_id(one_worker_100MiB):
|
||||||
|
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def f(signal):
|
||||||
|
ray.get(signal.wait.remote())
|
||||||
|
ray.get(captured_id) # noqa: F821
|
||||||
|
|
||||||
|
signal = SignalActor.remote()
|
||||||
|
oid = f.remote(signal)
|
||||||
|
|
||||||
|
# Delete local references.
|
||||||
|
del f
|
||||||
|
del captured_id
|
||||||
|
|
||||||
|
# Test that the captured object ID is pinned despite having no local
|
||||||
|
# references.
|
||||||
|
ray.get(signal.send.remote())
|
||||||
|
_fill_object_store_and_get(oid)
|
||||||
|
|
||||||
|
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class Actor:
|
||||||
|
def get(self, signal):
|
||||||
|
ray.get(signal.wait.remote())
|
||||||
|
ray.get(captured_id) # noqa: F821
|
||||||
|
|
||||||
|
signal = SignalActor.remote()
|
||||||
|
actor = Actor.remote()
|
||||||
|
oid = actor.get.remote(signal)
|
||||||
|
|
||||||
|
# Delete local references.
|
||||||
|
del Actor
|
||||||
|
del captured_id
|
||||||
|
|
||||||
|
# Test that the captured object ID is pinned despite having no local
|
||||||
|
# references.
|
||||||
|
ray.get(signal.send.remote())
|
||||||
|
_fill_object_store_and_get(oid)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
Loading…
Add table
Reference in a new issue