mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[ray client] Fix ctrl-c for ray.get() by setting a short-server side timeout (#14425)
This commit is contained in:
parent
190ab40645
commit
2cf4c7253c
2 changed files with 63 additions and 2 deletions
|
@ -3,12 +3,45 @@ import time
|
|||
import sys
|
||||
import logging
|
||||
import threading
|
||||
import _thread
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_interrupt_ray_get(call_ray_stop_only):
|
||||
import ray
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def block():
|
||||
print("blocking run")
|
||||
time.sleep(99)
|
||||
|
||||
@ray.remote
|
||||
def fast():
|
||||
print("fast run")
|
||||
time.sleep(1)
|
||||
return "ok"
|
||||
|
||||
class Interrupt(threading.Thread):
|
||||
def run(self):
|
||||
time.sleep(2)
|
||||
_thread.interrupt_main()
|
||||
|
||||
it = Interrupt()
|
||||
it.start()
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
ray.get(block.remote())
|
||||
|
||||
# Assert we can still get new items after the interrupt.
|
||||
assert ray.get(fast.remote()) == "ok"
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
|
|
|
@ -22,6 +22,7 @@ import ray.cloudpickle as cloudpickle
|
|||
from ray.cloudpickle.compat import pickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from ray.util.client.client_pickler import convert_to_arg
|
||||
from ray.util.client.client_pickler import dumps_from_client
|
||||
from ray.util.client.client_pickler import loads_from_server
|
||||
|
@ -44,6 +45,11 @@ logger = logging.getLogger(__name__)
|
|||
INITIAL_TIMEOUT_SEC = 5
|
||||
MAX_TIMEOUT_SEC = 30
|
||||
|
||||
# The max amount of time an operation can run blocking in the server. This
|
||||
# allows for Ctrl-C of the client to work without explicitly cancelling server
|
||||
# operations.
|
||||
MAX_BLOCKING_OPERATION_TIME_S = 2
|
||||
|
||||
|
||||
def backoff(timeout: int) -> int:
|
||||
timeout = timeout + 5
|
||||
|
@ -171,7 +177,29 @@ class Worker:
|
|||
"list of IDs or just an ID: %s" % type(vals))
|
||||
if timeout is None:
|
||||
timeout = 0
|
||||
out = [self._get(x, timeout) for x in to_get]
|
||||
deadline = None
|
||||
else:
|
||||
deadline = time.monotonic() + timeout
|
||||
out = []
|
||||
for obj_ref in to_get:
|
||||
res = None
|
||||
# Implement non-blocking get with a short-polling loop. This allows
|
||||
# cancellation of gets via Ctrl-C, since we never block for long.
|
||||
while True:
|
||||
try:
|
||||
if deadline:
|
||||
op_timeout = min(
|
||||
MAX_BLOCKING_OPERATION_TIME_S,
|
||||
max(deadline - time.monotonic(), 0.001))
|
||||
else:
|
||||
op_timeout = MAX_BLOCKING_OPERATION_TIME_S
|
||||
res = self._get(obj_ref, op_timeout)
|
||||
break
|
||||
except GetTimeoutError:
|
||||
if deadline and time.monotonic() > deadline:
|
||||
raise
|
||||
logger.debug("Internal retry for get {}".format(obj_ref))
|
||||
out.append(res)
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
@ -188,7 +216,6 @@ class Worker:
|
|||
except pickle.UnpicklingError:
|
||||
logger.exception("Failed to deserialize {}".format(data.error))
|
||||
raise
|
||||
logger.error(err)
|
||||
raise err
|
||||
return loads_from_server(data.data)
|
||||
|
||||
|
@ -221,6 +248,7 @@ class Worker:
|
|||
resp = self.data_client.PutObject(req)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
# TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
|
|
Loading…
Add table
Reference in a new issue