add test, fix _should_cache

This commit is contained in:
Chris Wong 2022-05-08 22:35:14 -07:00
parent d84c018211
commit ce5290aba9
2 changed files with 31 additions and 1 deletions

View file

@ -430,6 +430,34 @@ def test_disconnect_during_large_put():
assert result.shape == (1024, 1024, 128)
def test_disconnect_during_large_schedule():
    """
    Disconnect during a remote call with a large (multi-chunk) argument.

    An error is raised from the data-request hook on the 8th request seen
    after the hook is armed. The test then checks that the remote call still
    returns the expected shape and that more than 8 requests were observed.
    """
    expected_shape = (1024, 1024, 128)
    seen_requests = 0
    armed = False

    def inject_failure(_):
        # Inject an error halfway through the object transfer. Requests that
        # arrive before the hook is armed are ignored and not counted.
        nonlocal seen_requests, armed
        if armed:
            seen_requests += 1
            if seen_requests == 8:
                raise RuntimeError

    @ray.remote
    def f(arr):
        return arr.shape

    with start_middleman_server(on_data_request=inject_failure):
        # Only start counting once the middleman server is up.
        armed = True
        payload = np.random.random(expected_shape)
        result = ray.get(f.remote(payload))
        # Check that the failure was injected: the counter must have moved
        # past the request on which the error was raised.
        assert seen_requests > 8
        assert result == expected_shape
def test_valid_actor_state():
"""
Repeatedly inject errors in the middle of mutating actor calls. Check

View file

@ -62,8 +62,10 @@ def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
req_type = req.WhichOneof("type")
if req_type == "get" and req.get.asynchronous:
return False
if req_type == "put" or req_type == "task":
if req_type == "put":
return req.put.chunk_id == req.put.total_chunks - 1
if req_type == "task":
return req.task.chunk_id == req.task.total_chunks - 1
return req_type not in ("acknowledge", "connection_cleanup")