[ray client] Fix ctrl-c for ray.get() by setting a short-server side timeout (#14425)

This commit is contained in:
Eric Liang 2021-03-04 10:36:42 -08:00 committed by GitHub
parent 190ab40645
commit 2cf4c7253c
2 changed files with 63 additions and 2 deletions

View file

@ -3,12 +3,45 @@ import time
import sys
import logging
import threading
import _thread
import ray.util.client.server.server as ray_client_server
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_interrupt_ray_get(call_ray_stop_only):
import ray
ray.init(num_cpus=2)
with ray_start_client_server() as ray:
@ray.remote
def block():
print("blocking run")
time.sleep(99)
@ray.remote
def fast():
print("fast run")
time.sleep(1)
return "ok"
class Interrupt(threading.Thread):
def run(self):
time.sleep(2)
_thread.interrupt_main()
it = Interrupt()
it.start()
with pytest.raises(KeyboardInterrupt):
ray.get(block.remote())
# Assert we can still get new items after the interrupt.
assert ray.get(fast.remote()) == "ok"
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:

View file

@ -22,6 +22,7 @@ import ray.cloudpickle as cloudpickle
from ray.cloudpickle.compat import pickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.exceptions import GetTimeoutError
from ray.util.client.client_pickler import convert_to_arg
from ray.util.client.client_pickler import dumps_from_client
from ray.util.client.client_pickler import loads_from_server
@ -44,6 +45,11 @@ logger = logging.getLogger(__name__)
INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30
# The max amount of time an operation can run blocking in the server. This
# allows for Ctrl-C of the client to work without explicitly cancelling server
# operations.
MAX_BLOCKING_OPERATION_TIME_S = 2
def backoff(timeout: int) -> int:
timeout = timeout + 5
@ -171,7 +177,29 @@ class Worker:
"list of IDs or just an ID: %s" % type(vals))
if timeout is None:
timeout = 0
out = [self._get(x, timeout) for x in to_get]
deadline = None
else:
deadline = time.monotonic() + timeout
out = []
for obj_ref in to_get:
res = None
# Implement non-blocking get with a short-polling loop. This allows
# cancellation of gets via Ctrl-C, since we never block for long.
while True:
try:
if deadline:
op_timeout = min(
MAX_BLOCKING_OPERATION_TIME_S,
max(deadline - time.monotonic(), 0.001))
else:
op_timeout = MAX_BLOCKING_OPERATION_TIME_S
res = self._get(obj_ref, op_timeout)
break
except GetTimeoutError:
if deadline and time.monotonic() > deadline:
raise
logger.debug("Internal retry for get {}".format(obj_ref))
out.append(res)
if single:
out = out[0]
return out
@ -188,7 +216,6 @@ class Worker:
except pickle.UnpicklingError:
logger.exception("Failed to deserialize {}".format(data.error))
raise
logger.error(err)
raise err
return loads_from_server(data.data)
@ -221,6 +248,7 @@ class Worker:
resp = self.data_client.PutObject(req)
return ClientObjectRef(resp.id)
# TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
def wait(self,
object_refs: List[ClientObjectRef],
*,