Working but not passing test (#7358)

This commit is contained in:
Edward Oakes 2020-02-28 12:57:28 -06:00 committed by GitHub
parent 34488f52f3
commit f321eaec9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 3 deletions

View file

@ -153,7 +153,15 @@ class SerializationContext:
return serialized_obj[0](*serialized_obj[1])
def object_id_serializer(obj):
self.add_contained_object_id(obj)
if self.is_in_band_serialization():
self.add_contained_object_id(obj)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray.worker.get_global_worker(
).core_worker.add_object_id_reference(obj)
owner_id = ""
owner_address = ""
# TODO(swang): Remove this check. Otherwise, we will not be able to
@ -206,6 +214,15 @@ class SerializationContext:
# construct a reducer
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
def is_in_band_serialization(self):
return getattr(self._thread_local, "in_band", False)
def set_in_band_serialization(self):
self._thread_local.in_band = True
def set_out_of_band_serialization(self):
self._thread_local.in_band = False
def set_outer_object_id(self, outer_object_id):
self._thread_local.outer_object_id = outer_object_id
@ -349,8 +366,16 @@ class SerializationContext:
assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
writer = Pickle5Writer()
# TODO(swang): Check that contained_object_ids is empty.
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
try:
self.set_in_band_serialization()
inband = pickle.dumps(
value, protocol=5, buffer_callback=writer.buffer_callback)
except Exception as e:
self.get_and_clear_contained_object_ids()
raise e
finally:
self.set_out_of_band_serialization()
return Pickle5SerializedObject(
metadata, inband, writer,
self.get_and_clear_contained_object_ids())

View file

@ -722,6 +722,61 @@ def test_recursively_return_borrowed_object_id(one_worker_100MiB):
_fill_object_store_and_get(final_oid_bytes, succeed=False)
def test_out_of_band_serialized_object_id(one_worker_100MiB):
assert len(
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 0
oid = ray.put("hello")
_check_refcounts({oid: (1, 0)})
oid_str = ray.cloudpickle.dumps(oid)
_check_refcounts({oid: (2, 0)})
del oid
assert len(
ray.worker.global_worker.core_worker.get_all_reference_counts()) == 1
assert ray.get(ray.cloudpickle.loads(oid_str)) == "hello"
def test_captured_object_id(one_worker_100MiB):
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
@ray.remote
def f(signal):
ray.get(signal.wait.remote())
ray.get(captured_id) # noqa: F821
signal = SignalActor.remote()
oid = f.remote(signal)
# Delete local references.
del f
del captured_id
# Test that the captured object ID is pinned despite having no local
# references.
ray.get(signal.send.remote())
_fill_object_store_and_get(oid)
captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
@ray.remote
class Actor:
def get(self, signal):
ray.get(signal.wait.remote())
ray.get(captured_id) # noqa: F821
signal = SignalActor.remote()
actor = Actor.remote()
oid = actor.get.remote(signal)
# Delete local references.
del Actor
del captured_id
# Test that the captured object ID is pinned despite having no local
# references.
ray.get(signal.send.remote())
_fill_object_store_and_get(oid)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))