mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Serve] Get ServeHandle on the same node (#11477)
This commit is contained in:
parent
ef96793d3f
commit
2c5cb95b42
3 changed files with 73 additions and 7 deletions
|
@ -1,5 +1,6 @@
|
|||
import atexit
|
||||
from functools import wraps
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
|
@ -7,7 +8,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
|||
from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters, logger)
|
||||
get_random_letters, logger, get_node_id_for_actor)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.actor import ActorHandle
|
||||
|
@ -317,23 +318,38 @@ class Client:
|
|||
proportion))
|
||||
|
||||
@_ensure_connected
def get_handle(self,
               endpoint_name: str,
               missing_ok: Optional[bool] = False) -> RayServeHandle:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    The handle is bound to a router running on the same node as the caller
    when one exists; otherwise a router is picked at random and a warning
    is logged.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.

    Returns:
        RayServeHandle

    Raises:
        KeyError: If ``missing_ok`` is false and ``endpoint_name`` is not
            among the endpoints reported by the controller.
    """
    if not missing_ok and endpoint_name not in ray.get(
            self._controller.get_all_endpoints.remote()):
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    # The controller returns a mapping whose values are the router actor
    # handles; only the handles are needed here.
    routers = list(ray.get(self._controller.get_routers.remote()).values())
    current_node_id = ray.get_runtime_context().node_id.hex()

    # Prefer the first router whose node id matches the caller's node id.
    router_chosen = next(
        (router for router in routers
         if get_node_id_for_actor(router) == current_node_id), None)
    if router_chosen is None:
        logger.warning(
            f"When getting a handle for {endpoint_name}, Serve can't find "
            "a router on the same node. Serve will use a random router.")
        router_chosen = random.choice(routers)

    return RayServeHandle(
        router_chosen,
        endpoint_name,
    )
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
The test file for all standalone tests that don't
|
||||
require a shared Serve instance.
|
||||
"""
|
||||
from random import randint
|
||||
import sys
|
||||
import socket
|
||||
|
||||
|
@ -13,7 +14,7 @@ from ray import serve
|
|||
from ray.cluster_utils import Cluster
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.utils import (block_until_http_ready, get_all_node_ids,
|
||||
format_actor_name)
|
||||
format_actor_name, get_node_id_for_actor)
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray._private.services import new_port
|
||||
|
||||
|
@ -128,5 +129,48 @@ def test_middleware():
|
|||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not hasattr(socket, "SO_REUSEPORT"),
    reason=("Port sharing only works on newer verion of Linux. "
            "This test can only be ran when port sharing is supported."))
def test_cluster_handle_affinity():
    """Check that a handle obtained on a node uses a router on that node.

    Starts a two-node local cluster, deploys one endpoint, then fetches a
    handle from a task pinned to each node and compares the node ids of the
    routers behind those handles with the node ids of the cluster.
    """
    cluster = Cluster()
    try:
        # HACK: using two different ip address so the placement constraint
        # for resource check later will work.
        head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=4)
        cluster.add_node(node_ip_address="0.0.0.0", num_cpus=4)

        ray.init(head_node.address)

        # Make sure we have two nodes.
        node_ids = [n["NodeID"] for n in ray.nodes()]
        assert len(node_ids) == 2

        # Start the backend.
        client = serve.start(http_port=randint(10000, 30000), detached=True)
        client.create_backend("hi:v0", lambda _: "hi")
        client.create_endpoint("hi", backend="hi:v0")

        # Try to retrieve the handle from both head and worker node, check
        # the router's node id.
        @ray.remote
        def check_handle_router_id():
            remote_client = serve.connect()
            handle = remote_client.get_handle("hi")
            return get_node_id_for_actor(handle.router_handle)

        # NOTE(review): ray.state.node_ids() is presumably the list of
        # per-node resource names, so requesting a sliver of each one pins
        # one task to each node -- confirm against ray.state.
        router_node_ids = ray.get([
            check_handle_router_id.options(resources={
                node_id: 0.01
            }).remote() for node_id in ray.state.node_ids()
        ])

        assert set(router_node_ids) == set(node_ids)
    finally:
        # Clean up the nodes (otherwise Ray will segfault). Done in a
        # ``finally`` so a failed assertion does not leave the cluster
        # running.
        ray.shutdown()
        cluster.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
|
|
@ -299,3 +299,9 @@ def get_all_node_ids():
|
|||
node_ids.append(("{}-{}".format(node_id, index), node_id))
|
||||
|
||||
return node_ids
|
||||
|
||||
|
||||
def get_node_id_for_actor(actor_handle):
    """Given an actor handle, return the node id it's placed on.

    Args:
        actor_handle: Handle of the actor to look up.

    Returns:
        The ``"NodeID"`` field of the actor's ``"Address"`` entry in the
        table returned by ``ray.actors()``.
    """
    actor_table = ray.actors()
    # The table is keyed by the hex form of the actor id.
    actor_info = actor_table[actor_handle._actor_id.hex()]
    address = actor_info["Address"]
    return address["NodeID"]
|
||||
|
|
Loading…
Add table
Reference in a new issue