[Serve] ServeHandle detects ActorError and drops replicas from target group (#26685)

Simon Mo authored on 2022-07-29 09:50:17 -07:00, committed by GitHub
parent 0b60d90283
commit 545c51609f
2 changed files with 81 additions and 7 deletions
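At a high level, the change makes the handle's ReplicaSet resolve completed object refs per replica and, when a ref raises RayActorError, drop that replica from its local replica set without waiting for the controller's membership update. A minimal standalone sketch of that detection pattern follows; the function name and dict layout are illustrative, not the actual router API:

import ray
from ray.exceptions import RayActorError, RayTaskError


def drain_and_drop_dead_replicas(in_flight_queries):
    """Illustrative only: in_flight_queries maps replica -> set of in-flight ObjectRefs."""
    refs = [ref for ref_set in in_flight_queries.values() for ref in ref_set]
    if not refs:
        return 0, []

    # Non-blocking poll (timeout=0) for refs that have already finished.
    done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
    done = set(done)

    dead_replicas = []
    for replica, replica_refs in in_flight_queries.items():
        completed = replica_refs.intersection(done)
        if not completed:
            continue
        try:
            # Cheap because every ref in `completed` is already ready.
            ray.get(list(completed))
        except RayActorError:
            # The replica actor exited; stop routing to it.
            dead_replicas.append(replica)
        except RayTaskError:
            # Application-level error; the replica itself is healthy.
            pass
        replica_refs.difference_update(completed)

    for replica in dead_replicas:
        in_flight_queries.pop(replica, None)
    return len(done), dead_replicas

The actual change wires this logic into ReplicaSet._drain_completed_object_refs, shown in the diff below.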


@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
 import ray
 from ray.actor import ActorHandle
+from ray.exceptions import RayActorError, RayTaskError
 from ray.util import metrics
 from ray.serve._private.common import RunningReplicaInfo
@@ -87,6 +88,17 @@ class ReplicaSet:
             {"deployment": self.deployment_name}
         )

+    def _reset_replica_iterator(self):
+        """Reset the iterator used to load balance replicas.
+
+        This is expected to be called after the replica membership has been
+        updated. It shuffles the replicas randomly to avoid multiple handles
+        sending requests in the same order.
+        """
+        replicas = list(self.in_flight_queries.keys())
+        random.shuffle(replicas)
+        self.replica_iterator = itertools.cycle(replicas)
+
     def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
         added, removed, _ = compute_iterable_delta(
             self.in_flight_queries.keys(), running_replicas
@@ -97,14 +109,13 @@ class ReplicaSet:
         for removed_replica in removed:
-            # Delete it directly because shutdown is processed by controller.
-            del self.in_flight_queries[removed_replica]
+            # Replicas might already have been deleted due to early detection
+            # of an actor error.
+            self.in_flight_queries.pop(removed_replica, None)

         if len(added) > 0 or len(removed) > 0:
-            # Shuffle the keys to avoid synchronization across clients.
-            replicas = list(self.in_flight_queries.keys())
-            random.shuffle(replicas)
-            self.replica_iterator = itertools.cycle(replicas)
             logger.debug(f"ReplicaSet: +{len(added)}, -{len(removed)} replicas.")
+            self._reset_replica_iterator()
             self.config_updated_event.set()

     def _try_assign_replica(self, query: Query) -> Optional[ray.ObjectRef]:
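Because the handle can now drop a dead replica on its own, the controller-driven update above may later try to delete a key that is already gone, which is why `del` is replaced with `dict.pop(key, None)`. A toy illustration with made-up replica names:

# Toy model of the update_running_replicas change above.
in_flight_queries = {"replica-a": set(), "replica-b": set()}

# 1. Handle-side early detection already dropped the crashed replica.
in_flight_queries.pop("replica-b", None)

# 2. The controller's long-poll update later reports the same replica as removed.
removed = ["replica-b"]
for removed_replica in removed:
    # pop(..., None) tolerates the missing key;
    # `del in_flight_queries[removed_replica]` would raise KeyError here.
    in_flight_queries.pop(removed_replica, None)

print(sorted(in_flight_queries))  # ['replica-a']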
@@ -160,9 +171,38 @@ class ReplicaSet:
     def _drain_completed_object_refs(self) -> int:
         refs = self._all_query_refs
+        # NOTE(simon): even though the timeout is 0, a large number of refs can still
+        # cause some blocking delay in the event loop. Consider moving this to async?
         done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
-        for replica_in_flight_queries in self.in_flight_queries.values():
-            replica_in_flight_queries.difference_update(done)
+        replicas_to_remove = []
+        for replica_info, replica_in_flight_queries in self.in_flight_queries.items():
+            completed_queries = replica_in_flight_queries.intersection(done)
+            if len(completed_queries):
+                try:
+                    # NOTE(simon): this ray.get call should be cheap because all these
+                    # refs are ready as indicated by previous `ray.wait` call.
+                    ray.get(list(completed_queries))
+                except RayActorError:
+                    logger.debug(
+                        f"Removing {replica_info.replica_tag} from replica set "
+                        "because the actor exited."
+                    )
+                    replicas_to_remove.append(replica_info)
+                except RayTaskError:
+                    # Ignore application error.
+                    pass
+                except Exception:
+                    logger.exception(
+                        "Handle received unexpected error when processing request."
+                    )
+
+                replica_in_flight_queries.difference_update(completed_queries)
+
+        if len(replicas_to_remove) > 0:
+            for replica_info in replicas_to_remove:
+                self.in_flight_queries.pop(replica_info, None)
+            self._reset_replica_iterator()
+
         return len(done)

     async def assign_replica(self, query: Query) -> ray.ObjectRef:
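The new `_reset_replica_iterator` helper shuffles the replicas before building the round-robin cycle so that independent handles do not all walk the replicas in the same order. A small demonstration of that pattern with made-up replica names:

import itertools
import random

replicas = ["replica-a", "replica-b", "replica-c"]


def make_replica_iterator(replicas):
    # Each handle shuffles its own copy, so different handles start their
    # round-robin at different replicas instead of stampeding the same one.
    order = list(replicas)
    random.shuffle(order)
    return itertools.cycle(order)


it = make_replica_iterator(replicas)
print([next(it) for _ in range(6)])
# e.g. ['replica-c', 'replica-a', 'replica-b', 'replica-c', 'replica-a', 'replica-b']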


@@ -10,6 +10,7 @@ import pytest
 import requests

 import ray
 import ray.actor
+import ray._private.state

 from ray import serve
 from ray._private.test_utils import wait_for_condition
@@ -650,6 +651,39 @@ def test_shutdown_remote(start_and_shutdown_ray_cli_function):
         os.unlink(shutdown_file.name)


+def test_handle_early_detect_failure(shutdown_ray):
+    """Check that the handle is notified about replica failures and takes them
+    out of the replica set."""
+    ray.init()
+    serve.start(detached=True)
+
+    @serve.deployment(num_replicas=2, max_concurrent_queries=1)
+    def f(do_crash: bool = False):
+        if do_crash:
+            os._exit(1)
+        return os.getpid()
+
+    handle = serve.run(f.bind())
+    pids = ray.get([handle.remote() for _ in range(2)])
+    assert len(set(pids)) == 2
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 2
+
+    client = get_global_client()
+    # Kill the controller so that the replica membership won't be updated
+    # through the controller health check + long polling.
+    ray.kill(client._controller, no_restart=True)
+
+    with pytest.raises(RayActorError):
+        ray.get(handle.remote(do_crash=True))
+
+    pids = ray.get([handle.remote() for _ in range(10)])
+    assert len(set(pids)) == 1
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 1
+
+    # Restart the controller, and then clean up all the replicas.
+    serve.start(detached=True)
+    serve.shutdown()
+
+
 def test_autoscaler_shutdown_node_http_everynode(
     shutdown_ray, call_ray_stop_only  # noqa: F811
 ):