diff --git a/python/ray/tests/test_client_reconnect.py b/python/ray/tests/test_client_reconnect.py index 5b1985cd4..98feef300 100644 --- a/python/ray/tests/test_client_reconnect.py +++ b/python/ray/tests/test_client_reconnect.py @@ -430,6 +430,34 @@ def test_disconnect_during_large_put(): assert result.shape == (1024, 1024, 128) +def test_disconnect_during_large_schedule(): + """ + Disconnect during a remote call with a large (multi-chunk) argument. + """ + i = 0 + started = False + + def fail_halfway(_): + # Inject an error halfway through the object transfer + nonlocal i, started + if not started: + return + i += 1 + if i == 8: + raise RuntimeError + + @ray.remote + def f(a): + return a.shape + + with start_middleman_server(on_data_request=fail_halfway): + started = True + a = np.random.random((1024, 1024, 128)) + result = ray.get(f.remote(a)) + assert i > 8 # Check that the failure was injected + assert result == (1024, 1024, 128) + + def test_valid_actor_state(): """ Repeatedly inject errors in the middle of mutating actor calls. Check diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 72fb4472b..c459b1e80 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -62,8 +62,10 @@ def _should_cache(req: ray_client_pb2.DataRequest) -> bool: req_type = req.WhichOneof("type") if req_type == "get" and req.get.asynchronous: return False - if req_type == "put" or req_type == "task": + if req_type == "put": return req.put.chunk_id == req.put.total_chunks - 1 + if req_type == "task": + return req.task.chunk_id == req.task.total_chunks - 1 return req_type not in ("acknowledge", "connection_cleanup")