[Serve] ServeHandle detects ActorError and drops replicas from target group (#26685)

parent 0b60d90283
commit 545c51609f
2 changed files with 81 additions and 7 deletions
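The change builds on a core Ray behavior: once a replica's actor process has exited, ray.get on any ObjectRef returned by that actor raises RayActorError, so the handle itself can notice dead replicas without waiting for the controller. A minimal standalone sketch of that signal (the Replica actor below is illustrative, not part of this commit):

import ray
from ray.exceptions import RayActorError

ray.init()

@ray.remote
class Replica:
    def handle_request(self):
        return "ok"

replica = Replica.remote()
print(ray.get(replica.handle_request.remote()))  # "ok"

# Simulate the replica process dying.
ray.kill(replica, no_restart=True)
try:
    ray.get(replica.handle_request.remote())
except RayActorError:
    print("actor exited; the handle can drop this replica from its set")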
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
 import ray
 from ray.actor import ActorHandle
+from ray.exceptions import RayActorError, RayTaskError
 from ray.util import metrics
 
 from ray.serve._private.common import RunningReplicaInfo
 
@@ -87,6 +88,17 @@ class ReplicaSet:
             {"deployment": self.deployment_name}
         )
 
+    def _reset_replica_iterator(self):
+        """Reset the iterator used to load balance replicas.
+
+        This is expected to be called after the replica membership has
+        been updated. It shuffles the replicas randomly to avoid multiple
+        handles sending requests in the same order.
+        """
+        replicas = list(self.in_flight_queries.keys())
+        random.shuffle(replicas)
+        self.replica_iterator = itertools.cycle(replicas)
+
     def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
         added, removed, _ = compute_iterable_delta(
             self.in_flight_queries.keys(), running_replicas
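The new _reset_replica_iterator helper factors out the shuffle-then-cycle pattern that update_running_replicas previously inlined. Shuffling before cycling means each handle walks the replicas round-robin but starting from its own random order, so many handles don't hammer replicas in lockstep. The pattern in isolation, using only the stdlib pieces the method uses:

import itertools
import random

replicas = ["replica-a", "replica-b", "replica-c"]
random.shuffle(replicas)  # each handle gets its own starting order
replica_iterator = itertools.cycle(replicas)

# Successive assignments walk the shuffled order round-robin, forever.
print([next(replica_iterator) for _ in range(6)])
# e.g. ['replica-b', 'replica-c', 'replica-a', 'replica-b', 'replica-c', 'replica-a']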
@@ -97,14 +109,13 @@ class ReplicaSet:
 
         for removed_replica in removed:
             # Delete it directly because shutdown is processed by controller.
-            del self.in_flight_queries[removed_replica]
+            # Replicas might already have been deleted due to early detection
+            # of an actor error.
+            self.in_flight_queries.pop(removed_replica, None)
 
         if len(added) > 0 or len(removed) > 0:
-            # Shuffle the keys to avoid synchronization across clients.
-            replicas = list(self.in_flight_queries.keys())
-            random.shuffle(replicas)
-            self.replica_iterator = itertools.cycle(replicas)
             logger.debug(f"ReplicaSet: +{len(added)}, -{len(removed)} replicas.")
+            self._reset_replica_iterator()
             self.config_updated_event.set()
 
     def _try_assign_replica(self, query: Query) -> Optional[ray.ObjectRef]:
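Switching from del to dict.pop(key, None) makes the removal idempotent: if the early-detection path in _drain_completed_object_refs already dropped the replica, this controller-driven update no longer raises KeyError. For context, compute_iterable_delta is imported from Ray Serve's utils; a rough, purely illustrative reimplementation of the semantics the caller relies on (not the actual helper) might look like:

from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")

def compute_delta_sketch(
    old: Iterable[T], new: Iterable[T]
) -> Tuple[List[T], List[T], List[T]]:
    """Illustrative only: split membership into (added, removed, kept)."""
    old_set, new_set = set(old), set(new)
    return (
        list(new_set - old_set),  # added since last update
        list(old_set - new_set),  # removed since last update
        list(new_set & old_set),  # unchanged members
    )

added, removed, _ = compute_delta_sketch({"r1", "r2"}, {"r2", "r3"})
print(added, removed)  # ['r3'] ['r1']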
@@ -160,9 +171,38 @@
 
     def _drain_completed_object_refs(self) -> int:
         refs = self._all_query_refs
+        # NOTE(simon): even though the timeout is 0, a large number of refs can still
+        # cause some blocking delay in the event loop. Consider moving this to async?
         done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
-        for replica_in_flight_queries in self.in_flight_queries.values():
-            replica_in_flight_queries.difference_update(done)
+        replicas_to_remove = []
+        for replica_info, replica_in_flight_queries in self.in_flight_queries.items():
+            completed_queries = replica_in_flight_queries.intersection(done)
+            if len(completed_queries):
+                try:
+                    # NOTE(simon): this ray.get call should be cheap because all these
+                    # refs are ready as indicated by the previous `ray.wait` call.
+                    ray.get(list(completed_queries))
+                except RayActorError:
+                    logger.debug(
+                        f"Removing {replica_info.replica_tag} from replica set "
+                        "because the actor exited."
+                    )
+                    replicas_to_remove.append(replica_info)
+                except RayTaskError:
+                    # Ignore application error.
+                    pass
+                except Exception:
+                    logger.exception(
+                        "Handle received unexpected error when processing request."
+                    )
+
+                replica_in_flight_queries.difference_update(completed_queries)
+
+        if len(replicas_to_remove) > 0:
+            for replica_info in replicas_to_remove:
+                self.in_flight_queries.pop(replica_info, None)
+            self._reset_replica_iterator()
+
         return len(done)
 
     async def assign_replica(self, query: Query) -> ray.ObjectRef:
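The rewritten drain loop classifies every finished ref by the exception ray.get raises: RayActorError marks the replica as dead (queued for removal), RayTaskError is an application-level error the caller will see anyway, and anything else is logged as unexpected. That classification can be exercised outside Serve; the task names below are illustrative:

import ray
from ray.exceptions import RayActorError, RayTaskError

ray.init()

@ray.remote
def healthy():
    return 1

@ray.remote
def buggy():
    raise ValueError("application bug")

refs = [healthy.remote(), buggy.remote()]
# The drain loop uses timeout=0 as a non-blocking poll; a positive timeout
# is used here only so both tasks finish in this standalone example.
done, _ = ray.wait(refs, num_returns=len(refs), timeout=10)

for ref in done:
    try:
        ray.get(ref)
    except RayActorError:
        print("replica process died -> drop it from the replica set")
    except RayTaskError:
        print("application error -> replica stays in the set")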
@@ -10,6 +10,7 @@ import pytest
 import requests
 
 import ray
+import ray.actor
 import ray._private.state
 from ray import serve
 from ray._private.test_utils import wait_for_condition
@@ -650,6 +651,39 @@ def test_shutdown_remote(start_and_shutdown_ray_cli_function):
     os.unlink(shutdown_file.name)
 
 
+def test_handle_early_detect_failure(shutdown_ray):
+    """Check that the handle is notified about replica failures and takes them out of the replica set."""
+    ray.init()
+    serve.start(detached=True)
+
+    @serve.deployment(num_replicas=2, max_concurrent_queries=1)
+    def f(do_crash: bool = False):
+        if do_crash:
+            os._exit(1)
+        return os.getpid()
+
+    handle = serve.run(f.bind())
+    pids = ray.get([handle.remote() for _ in range(2)])
+    assert len(set(pids)) == 2
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 2
+
+    client = get_global_client()
+    # Kill the controller so that replica membership won't be updated
+    # through the controller health check + long polling.
+    ray.kill(client._controller, no_restart=True)
+
+    with pytest.raises(RayActorError):
+        ray.get(handle.remote(do_crash=True))
+
+    pids = ray.get([handle.remote() for _ in range(10)])
+    assert len(set(pids)) == 1
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 1
+
+    # Restart the controller, and then clean up all the replicas.
+    serve.start(detached=True)
+    serve.shutdown()
+
+
 def test_autoscaler_shutdown_node_http_everynode(
     shutdown_ray, call_ray_stop_only  # noqa: F811
 ):
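Note the test crashes the replica with os._exit(1) rather than raising: a raised exception would come back to the handle as RayTaskError and the replica would stay in the set, while killing the process surfaces as RayActorError, which is exactly the early-detection path under test. Killing the controller first guarantees the membership change can only come from the handle itself. A sketch of the raise-versus-exit distinction with a plain actor (names are illustrative):

import os
import ray
from ray.exceptions import RayActorError, RayTaskError

ray.init()

@ray.remote
class Worker:
    def fail_softly(self):
        raise RuntimeError("app-level failure")  # surfaces as RayTaskError

    def crash_hard(self):
        os._exit(1)  # kills the actor process; surfaces as RayActorError

w = Worker.remote()
try:
    ray.get(w.fail_softly.remote())
except RayTaskError:
    print("task error: the actor is still alive and usable")

try:
    ray.get(w.crash_hard.remote())
except RayActorError:
    print("actor exited: a ServeHandle would drop this replica")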