[Serve] Get ServeHandle on the same node (#11477)

This commit is contained in:
Simon Mo 2020-10-20 10:44:23 -07:00 committed by GitHub
parent ef96793d3f
commit 2c5cb95b42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 7 deletions

View file

@ -1,5 +1,6 @@
import atexit
from functools import wraps
import random
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
@ -7,7 +8,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger)
get_random_letters, logger, get_node_id_for_actor)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.actor import ActorHandle
@ -317,23 +318,38 @@ class Client:
proportion))
@_ensure_connected
def get_handle(self,
               endpoint_name: str,
               missing_ok: bool = False) -> RayServeHandle:
    """Retrieve RayServeHandle for service endpoint to invoke it from Python.

    The handle is bound to a router located on the caller's node when one
    exists. Otherwise a warning is logged and a router is picked at
    random.

    Args:
        endpoint_name (str): A registered service endpoint.
        missing_ok (bool): If true, then Serve won't check the endpoint is
            registered. False by default.

    Returns:
        RayServeHandle

    Raises:
        KeyError: If ``missing_ok`` is false and ``endpoint_name`` is not
            among the endpoints reported by the controller.
    """
    if not missing_ok and endpoint_name not in ray.get(
            self._controller.get_all_endpoints.remote()):
        raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")

    routers = list(ray.get(self._controller.get_routers.remote()).values())
    current_node_id = ray.get_runtime_context().node_id.hex()

    # Lazily scan the routers and stop at the first one that lives on the
    # caller's node; None means no router is co-located with the caller.
    router_chosen = next(
        (router for router in routers
         if get_node_id_for_actor(router) == current_node_id), None)
    if router_chosen is None:
        logger.warning(
            f"When getting a handle for {endpoint_name}, Serve can't find "
            "a router on the same node. Serve will use a random router.")
        router_chosen = random.choice(routers)

    return RayServeHandle(
        router_chosen,
        endpoint_name,
    )

View file

@ -2,6 +2,7 @@
The test file for all standalone tests that don't
require a shared Serve instance.
"""
from random import randint
import sys
import socket
@ -13,7 +14,7 @@ from ray import serve
from ray.cluster_utils import Cluster
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.utils import (block_until_http_ready, get_all_node_ids,
format_actor_name)
format_actor_name, get_node_id_for_actor)
from ray.test_utils import wait_for_condition
from ray._private.services import new_port
@ -128,5 +129,48 @@ def test_middleware():
ray.shutdown()
@pytest.mark.skipif(
    not hasattr(socket, "SO_REUSEPORT"),
    reason=("Port sharing only works on newer verion of Linux. "
            "This test can only be ran when port sharing is supported."))
def test_cluster_handle_affinity():
    """Check that a handle obtained on a node uses a router on that node.

    Starts a two-node cluster, fetches a handle for the same endpoint from
    a task pinned to each node, and asserts that the set of node ids
    hosting the chosen routers equals the set of node ids in the cluster.
    """
    cluster = Cluster()
    try:
        # HACK: using two different ip address so the placement constraint
        # for resource check later will work.
        head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=4)
        cluster.add_node(node_ip_address="0.0.0.0", num_cpus=4)

        ray.init(head_node.address)

        # Make sure we have two nodes.
        node_ids = [n["NodeID"] for n in ray.nodes()]
        assert len(node_ids) == 2

        # Start the backend.
        client = serve.start(http_port=randint(10000, 30000), detached=True)
        client.create_backend("hi:v0", lambda _: "hi")
        client.create_endpoint("hi", backend="hi:v0")

        # Try to retrieve the handle from both head and worker node, check
        # the router's node id.
        @ray.remote
        def check_handle_router_id():
            client = serve.connect()
            handle = client.get_handle("hi")
            return get_node_id_for_actor(handle.router_handle)

        # Requesting a sliver of each per-node resource pins one task to
        # every node.
        router_node_ids = ray.get([
            check_handle_router_id.options(resources={
                node_id: 0.01
            }).remote() for node_id in ray.state.node_ids()
        ])

        assert set(router_node_ids) == set(node_ids)
    finally:
        # Clean up the nodes (otherwise Ray will segfault). Running this in
        # ``finally`` keeps a failed assertion from leaking the cluster.
        ray.shutdown()
        cluster.shutdown()
if __name__ == "__main__":
    # Run this file's tests verbosely with output capturing disabled, and
    # propagate pytest's exit status to the shell.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)

View file

@ -299,3 +299,9 @@ def get_all_node_ids():
node_ids.append(("{}-{}".format(node_id, index), node_id))
return node_ids
def get_node_id_for_actor(actor_handle):
    """Return the id of the node on which the given actor is placed.

    Args:
        actor_handle: Handle of the actor to look up.

    Returns:
        The ``NodeID`` recorded under the actor's ``Address`` entry in the
        table returned by ``ray.actors()``.
    """
    actor_table = ray.actors()
    actor_key = actor_handle._actor_id.hex()
    actor_address = actor_table[actor_key]["Address"]
    return actor_address["NodeID"]