From 6084eb6a9f69025941ec220d8cc14dc30e48a86c Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Tue, 2 Aug 2022 20:04:03 -0700
Subject: [PATCH] Revert "Revert "[Serve] ServeHandle detects ActorError and drop replicas from target group (#26685)" (#27283)" (#27348)

---
 python/ray/serve/_private/long_poll.py     |  7 +--
 python/ray/serve/_private/router.py        | 54 +++++++++++++++++++---
 python/ray/serve/tests/test_standalone2.py | 37 +++++++++++++++
 3 files changed, 88 insertions(+), 10 deletions(-)

diff --git a/python/ray/serve/_private/long_poll.py b/python/ray/serve/_private/long_poll.py
index a1d431147..c76febee0 100644
--- a/python/ray/serve/_private/long_poll.py
+++ b/python/ray/serve/_private/long_poll.py
@@ -148,13 +148,14 @@ class LongPollClient:
         if isinstance(updates, (ray.exceptions.RayTaskError)):
             if isinstance(updates.as_instanceof_cause(), (asyncio.TimeoutError)):
                 logger.debug("LongPollClient polling timed out. Retrying.")
+                self._schedule_to_event_loop(self._reset)
             else:
                 # Some error happened in the controller. It could be a bug or
                 # some undesired state.
                 logger.error("LongPollHost errored\n" + updates.traceback_str)
-            # We must call this in event loop so it works in Ray Client.
-            # See https://github.com/ray-project/ray/issues/20971
-            self._schedule_to_event_loop(self._poll_next)
+                # We must call this in event loop so it works in Ray Client.
+                # See https://github.com/ray-project/ray/issues/20971
+                self._schedule_to_event_loop(self._poll_next)
             return

         logger.debug(
diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py
index eb358030c..67d21707e 100644
--- a/python/ray/serve/_private/router.py
+++ b/python/ray/serve/_private/router.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional

 import ray
 from ray.actor import ActorHandle
+from ray.exceptions import RayActorError, RayTaskError
 from ray.util import metrics

 from ray.serve._private.common import RunningReplicaInfo
@@ -87,6 +88,17 @@
             {"deployment": self.deployment_name}
         )

+    def _reset_replica_iterator(self):
+        """Reset the iterator used to load balance replicas.
+
+        This is expected to be called after the replica membership has been
+        updated. It shuffles the replicas randomly to avoid multiple handles
+        sending requests in the same order.
+        """
+        replicas = list(self.in_flight_queries.keys())
+        random.shuffle(replicas)
+        self.replica_iterator = itertools.cycle(replicas)
+
     def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
         added, removed, _ = compute_iterable_delta(
             self.in_flight_queries.keys(), running_replicas
@@ -97,14 +109,13 @@ class ReplicaSet:

         for removed_replica in removed:
             # Delete it directly because shutdown is processed by controller.
-            del self.in_flight_queries[removed_replica]
+            # Replicas might have already been deleted due to early
+            # detection of actor error.
+            self.in_flight_queries.pop(removed_replica, None)

         if len(added) > 0 or len(removed) > 0:
-            # Shuffle the keys to avoid synchronization across clients.
-            replicas = list(self.in_flight_queries.keys())
-            random.shuffle(replicas)
-            self.replica_iterator = itertools.cycle(replicas)
             logger.debug(f"ReplicaSet: +{len(added)}, -{len(removed)} replicas.")
+            self._reset_replica_iterator()
             self.config_updated_event.set()

     def _try_assign_replica(self, query: Query) -> Optional[ray.ObjectRef]:
@@ -160,9 +171,38 @@
     def _drain_completed_object_refs(self) -> int:
         refs = self._all_query_refs
+        # NOTE(simon): even though the timeout is 0, a large number of refs can still
+        # cause some blocking delay in the event loop. Consider moving this to async?
         done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
-        for replica_in_flight_queries in self.in_flight_queries.values():
-            replica_in_flight_queries.difference_update(done)
+        replicas_to_remove = []
+        for replica_info, replica_in_flight_queries in self.in_flight_queries.items():
+            completed_queries = replica_in_flight_queries.intersection(done)
+            if len(completed_queries):
+                try:
+                    # NOTE(simon): this ray.get call should be cheap because all these
+                    # refs are ready as indicated by the previous `ray.wait` call.
+                    ray.get(list(completed_queries))
+                except RayActorError:
+                    logger.debug(
+                        f"Removing {replica_info.replica_tag} from replica set "
+                        "because the actor exited."
+                    )
+                    replicas_to_remove.append(replica_info)
+                except RayTaskError:
+                    # Ignore application error.
+                    pass
+                except Exception:
+                    logger.exception(
+                        "Handle received unexpected error when processing request."
+                    )
+
+            replica_in_flight_queries.difference_update(completed_queries)
+
+        if len(replicas_to_remove) > 0:
+            for replica_info in replicas_to_remove:
+                self.in_flight_queries.pop(replica_info, None)
+            self._reset_replica_iterator()
+
         return len(done)

     async def assign_replica(self, query: Query) -> ray.ObjectRef:
diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py
index 0afff8711..527af7ef7 100644
--- a/python/ray/serve/tests/test_standalone2.py
+++ b/python/ray/serve/tests/test_standalone2.py
@@ -10,6 +10,7 @@ import pytest
 import requests

 import ray
+import ray.actor
 import ray._private.state
 from ray import serve
 from ray._private.test_utils import wait_for_condition
@@ -650,6 +651,42 @@ def test_shutdown_remote(start_and_shutdown_ray_cli_function):
         os.unlink(shutdown_file.name)


+def test_handle_early_detect_failure(shutdown_ray):
+    """Check that the handle can be notified about replica failures.
+
+    It should detect that a replica raised ActorError and take it out of
+    the replica set.
+    """
+    ray.init()
+    serve.start(detached=True)
+
+    @serve.deployment(num_replicas=2, max_concurrent_queries=1)
+    def f(do_crash: bool = False):
+        if do_crash:
+            os._exit(1)
+        return os.getpid()
+
+    handle = serve.run(f.bind())
+    pids = ray.get([handle.remote() for _ in range(2)])
+    assert len(set(pids)) == 2
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 2
+
+    client = get_global_client()
+    # Kill the controller so that the replica membership won't be updated
+    # through controller health check + long polling.
+    ray.kill(client._controller, no_restart=True)
+
+    with pytest.raises(RayActorError):
+        ray.get(handle.remote(do_crash=True))
+
+    pids = ray.get([handle.remote() for _ in range(10)])
+    assert len(set(pids)) == 1
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 1
+
+    # Restart the controller, and then clean up all the replicas.
+    serve.start(detached=True)
+    serve.shutdown()
+
+
 def test_autoscaler_shutdown_node_http_everynode(
     shutdown_ray, call_ray_stop_only  # noqa: F811
 ):
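
Note on the long_poll.py hunk: on a polling timeout the client now schedules
self._reset instead of unconditionally re-polling, and the _poll_next
rescheduling moves under the controller-error branch. Below is a minimal
sketch of that control flow. The class, its snapshot-id state, and the
behavior of _reset (drop cached snapshot ids, then re-poll) are assumptions
for illustration, not LongPollClient's actual implementation, which is not
shown in this patch:

    import asyncio


    class MiniLongPollClient:
        """Illustrative stand-in for the relevant slice of LongPollClient."""

        def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
            self._loop = loop
            self._snapshot_ids: dict = {}  # key -> last seen snapshot id

        def _schedule_to_event_loop(self, callback) -> None:
            # Run callbacks on the event loop so this also works from Ray
            # Client threads (https://github.com/ray-project/ray/issues/20971).
            self._loop.call_soon_threadsafe(callback)

        def _poll_next(self) -> None:
            ...  # issue the next long-poll RPC to the controller

        def _reset(self) -> None:
            # Assumed behavior: forget snapshot ids so the next poll returns
            # a full snapshot, then poll again.
            self._snapshot_ids.clear()
            self._poll_next()

        def _handle_error(self, err: Exception) -> None:
            if isinstance(err, asyncio.TimeoutError):
                # Polling timed out: reset and retry (the new branch).
                self._schedule_to_event_loop(self._reset)
            else:
                # Controller-side error: log it, then retry the poll as before.
                self._schedule_to_event_loop(self._poll_next)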
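
The router.py hunk is the substance of this revert-of-revert: while draining
completed ObjectRefs, the handle now ray.gets them to classify failures,
removes replicas whose actors have exited, and reshuffles its iterator, so it
stops routing to dead replicas before the controller notices. A self-contained
sketch of that pattern against plain Ray actors follows; the Replica actor,
its handle_request/do_crash interface, and the 10-second wait are invented for
illustration, and this is not Serve's ReplicaSet:

    import itertools
    import os
    import random

    import ray
    from ray.exceptions import RayActorError, RayTaskError


    @ray.remote
    class Replica:
        def handle_request(self, do_crash: bool = False):
            if do_crash:
                os._exit(1)  # hard-exit the actor process, as in the test above
            return "ok"


    ray.init()

    # replica -> set of in-flight ObjectRefs, like ReplicaSet.in_flight_queries.
    replicas = [Replica.remote() for _ in range(2)]
    in_flight = {actor: set() for actor in replicas}

    # One healthy request, and one request that crashes its replica mid-flight.
    in_flight[replicas[0]].add(replicas[0].handle_request.remote())
    in_flight[replicas[1]].add(replicas[1].handle_request.remote(do_crash=True))

    # Drain completed refs and classify errors per replica. The patched code
    # uses timeout=0 and relies on being called repeatedly; we wait instead.
    all_refs = [r for refs in in_flight.values() for r in refs]
    done, _ = ray.wait(all_refs, num_returns=len(all_refs), timeout=10.0)

    to_remove = []
    for actor, refs in in_flight.items():
        completed = refs.intersection(done)
        if completed:
            try:
                # Cheap: these refs are already ready per the ray.wait above.
                ray.get(list(completed))
            except RayActorError:
                to_remove.append(actor)  # the actor exited; stop routing to it
            except RayTaskError:
                pass  # application error; the replica itself is healthy
        refs.difference_update(completed)

    for actor in to_remove:
        in_flight.pop(actor, None)

    # Mirror _reset_replica_iterator: shuffle so concurrent handles don't all
    # cycle through the replicas in the same order.
    members = list(in_flight)
    random.shuffle(members)
    replica_iterator = itertools.cycle(members)
    assert next(replica_iterator) is replicas[0]

This is also why the test pins max_concurrent_queries=1 and kills the
controller first: with the controller gone, only the handle's own drain logic
can shrink in_flight_queries from two replicas to one.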