From f321eaec9b1665aa10d5ea1ecd3192724a6d6590 Mon Sep 17 00:00:00 2001
From: Edward Oakes
Date: Fri, 28 Feb 2020 12:57:28 -0600
Subject: [PATCH] Working but not passing test (#7358)

---
Notes (free-form section; ignored by `git am`):

* object_id_serializer now branches on a thread-local "in_band" flag.
  In-band: the object ID is recorded via add_contained_object_id, as
  before.  Out-of-band: core_worker.add_object_id_reference is called
  instead, and this patch adds no matching removal.
* The pickle5 serialize path sets the flag before pickle.dumps and
  resets it in `finally`; if pickle.dumps raises, the contained object
  IDs collected so far are cleared before the exception propagates.
* Review fix: the handler re-raises with a bare `raise` rather than
  `raise e`.  Line counts are unchanged, so the hunk headers and the
  diffstat below are still accurate.
* NOTE(review): this patch was restored from a copy whose newlines and
  indentation had been collapsed.  Leading whitespace was reconstructed
  from Python syntax and the 79-column limit; the context indentation
  in the 2nd and 3rd hunks of serialization.py is a best guess.  Verify
  against the target tree, or apply with
  `git apply --ignore-whitespace`.

 python/ray/serialization.py                 | 31 ++++++++++--
 python/ray/tests/test_reference_counting.py | 55 +++++++++++++++++++++
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/python/ray/serialization.py b/python/ray/serialization.py
index c7a352366..79cf548f6 100644
--- a/python/ray/serialization.py
+++ b/python/ray/serialization.py
@@ -153,7 +153,15 @@ class SerializationContext:
             return serialized_obj[0](*serialized_obj[1])
 
         def object_id_serializer(obj):
-            self.add_contained_object_id(obj)
+            if self.is_in_band_serialization():
+                self.add_contained_object_id(obj)
+            else:
+                # If this serialization is out-of-band (e.g., from a call to
+                # cloudpickle directly or captured in a remote function/actor),
+                # then pin the object for the lifetime of this worker by adding
+                # a local reference that won't ever be removed.
+                ray.worker.get_global_worker(
+                ).core_worker.add_object_id_reference(obj)
             owner_id = ""
             owner_address = ""
             # TODO(swang): Remove this check. Otherwise, we will not be able to
@@ -206,6 +214,15 @@ class SerializationContext:
         # construct a reducer
         pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
 
+    def is_in_band_serialization(self):
+        return getattr(self._thread_local, "in_band", False)
+
+    def set_in_band_serialization(self):
+        self._thread_local.in_band = True
+
+    def set_out_of_band_serialization(self):
+        self._thread_local.in_band = False
+
     def set_outer_object_id(self, outer_object_id):
         self._thread_local.outer_object_id = outer_object_id
 
@@ -349,8 +366,16 @@ class SerializationContext:
         assert ray.cloudpickle.FAST_CLOUDPICKLE_USED
         writer = Pickle5Writer()
         # TODO(swang): Check that contained_object_ids is empty.
-        inband = pickle.dumps(
-            value, protocol=5, buffer_callback=writer.buffer_callback)
+        try:
+            self.set_in_band_serialization()
+            inband = pickle.dumps(
+                value, protocol=5, buffer_callback=writer.buffer_callback)
+        except Exception:
+            self.get_and_clear_contained_object_ids()
+            raise
+        finally:
+            self.set_out_of_band_serialization()
+
         return Pickle5SerializedObject(
             metadata, inband, writer,
             self.get_and_clear_contained_object_ids())
diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py
index 361b054ce..f69cdfbf4 100644
--- a/python/ray/tests/test_reference_counting.py
+++ b/python/ray/tests/test_reference_counting.py
@@ -722,6 +722,61 @@ def test_recursively_return_borrowed_object_id(one_worker_100MiB):
     _fill_object_store_and_get(final_oid_bytes, succeed=False)
 
 
+def test_out_of_band_serialized_object_id(one_worker_100MiB):
+    assert len(
+        ray.worker.global_worker.core_worker.get_all_reference_counts()) == 0
+    oid = ray.put("hello")
+    _check_refcounts({oid: (1, 0)})
+    oid_str = ray.cloudpickle.dumps(oid)
+    _check_refcounts({oid: (2, 0)})
+    del oid
+    assert len(
+        ray.worker.global_worker.core_worker.get_all_reference_counts()) == 1
+    assert ray.get(ray.cloudpickle.loads(oid_str)) == "hello"
+
+
+def test_captured_object_id(one_worker_100MiB):
+    captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
+
+    @ray.remote
+    def f(signal):
+        ray.get(signal.wait.remote())
+        ray.get(captured_id)  # noqa: F821
+
+    signal = SignalActor.remote()
+    oid = f.remote(signal)
+
+    # Delete local references.
+    del f
+    del captured_id
+
+    # Test that the captured object ID is pinned despite having no local
+    # references.
+    ray.get(signal.send.remote())
+    _fill_object_store_and_get(oid)
+
+    captured_id = ray.put(np.zeros(10 * 1024 * 1024, dtype=np.uint8))
+
+    @ray.remote
+    class Actor:
+        def get(self, signal):
+            ray.get(signal.wait.remote())
+            ray.get(captured_id)  # noqa: F821
+
+    signal = SignalActor.remote()
+    actor = Actor.remote()
+    oid = actor.get.remote(signal)
+
+    # Delete local references.
+    del Actor
+    del captured_id
+
+    # Test that the captured object ID is pinned despite having no local
+    # references.
+    ray.get(signal.send.remote())
+    _fill_object_store_and_get(oid)
+
+
 if __name__ == "__main__":
     import sys
     sys.exit(pytest.main(["-v", __file__]))