mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[serve] Re-add variable route support for old API (#15455)
This commit is contained in:
parent
79c24146bd
commit
668a784553
5 changed files with 214 additions and 59 deletions
|
@ -41,6 +41,39 @@ in :mod:`serve.start <ray.serve.start>`:
|
|||
instance group of Ray cluster to achieve high availability of Serve's HTTP
|
||||
proxies.
|
||||
|
||||
Variable HTTP Routes
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Ray Serve supports capturing path parameters. For example, in a call of the form
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
serve.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}")
|
||||
|
||||
the ``username`` parameter will be accessible in your backend code as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def my_backend(request):
|
||||
username = request.path_params["username"]
|
||||
...
|
||||
|
||||
Ray Serve uses Starlette's Router class under the hood for routing, so type
|
||||
conversion for path parameters is also supported, as well as multiple path parameters.
|
||||
For example, suppose this route is used:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
serve.create_endpoint(
|
||||
"complex", backend="f", route="/api/{user_id:int}/{number:float}")
|
||||
|
||||
Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary
|
||||
available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is
|
||||
a Python int and ``3.14`` is a Python float.
|
||||
|
||||
For full details on the supported path parameters, see Starlette's
|
||||
`path parameters documentation <https://www.starlette.io/routing/#path-parameters>`_.
|
||||
|
||||
Custom HTTP response status codes
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ class EndpointInfo:
|
|||
http_methods: List[str]
|
||||
python_methods: Optional[List[str]] = field(default_factory=list)
|
||||
route: Optional[str] = None
|
||||
legacy: Optional[bool] = True
|
||||
|
||||
|
||||
class BackendInfo(BaseModel):
|
||||
|
|
|
@ -296,7 +296,8 @@ class ServeController:
|
|||
endpoint_info = EndpointInfo(
|
||||
ALL_HTTP_METHODS,
|
||||
route=route_prefix,
|
||||
python_methods=python_methods)
|
||||
python_methods=python_methods,
|
||||
legacy=False)
|
||||
self.endpoint_state.update_endpoint(name, endpoint_info,
|
||||
TrafficPolicy({
|
||||
name: 1.0
|
||||
|
|
|
@ -22,6 +22,85 @@ from ray.serve.handle import DEFAULT
|
|||
MAX_REPLICA_FAILURE_RETRIES = 10
|
||||
|
||||
|
||||
async def _send_request_to_handle(handle, scope, receive, send):
|
||||
http_body_bytes = await receive_http_body(scope, receive, send)
|
||||
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
handle = handle.options(
|
||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
|
||||
http_method=scope["method"].upper(),
|
||||
http_headers=headers)
|
||||
|
||||
# scope["router"] and scope["endpoint"] contain references to a router
|
||||
# and endpoint object, respectively, which each in turn contain a
|
||||
# reference to the Serve client, which cannot be serialized.
|
||||
# The solution is to delete these from scope, as they will not be used.
|
||||
# TODO(edoakes): this can be removed once we deprecate the old API.
|
||||
if "router" in scope:
|
||||
del scope["router"]
|
||||
if "endpoint" in scope:
|
||||
del scope["endpoint"]
|
||||
|
||||
# NOTE(edoakes): it's important that we defer building the starlette
|
||||
# request until it reaches the backend replica to avoid unnecessary
|
||||
# serialization cost, so we use a simple dataclass here.
|
||||
request = HTTPRequestWrapper(scope, http_body_bytes)
|
||||
|
||||
retries = 0
|
||||
backoff_time_s = 0.05
|
||||
while retries < MAX_REPLICA_FAILURE_RETRIES:
|
||||
object_ref = await handle.remote(request)
|
||||
try:
|
||||
result = await object_ref
|
||||
break
|
||||
except RayActorError:
|
||||
logger.warning("Request failed due to replica failure. There are "
|
||||
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
|
||||
"remaining.")
|
||||
await asyncio.sleep(backoff_time_s)
|
||||
backoff_time_s *= 2
|
||||
retries += 1
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
await Response(
|
||||
error_message, status_code=500).send(scope, receive, send)
|
||||
elif isinstance(result, starlette.responses.Response):
|
||||
await result(scope, receive, send)
|
||||
else:
|
||||
await Response(result).send(scope, receive, send)
|
||||
|
||||
|
||||
class ServeStarletteEndpoint:
|
||||
"""Wraps the given Serve endpoint in a Starlette endpoint.
|
||||
|
||||
Implements the ASGI protocol. Constructs a Starlette endpoint for use by
|
||||
a Starlette app or Starlette Router which calls the given Serve endpoint.
|
||||
Usage:
|
||||
route = starlette.routing.Route(
|
||||
"/api",
|
||||
ServeStarletteEndpoint(endpoint_tag),
|
||||
methods=methods)
|
||||
app = starlette.applications.Starlette(routes=[route])
|
||||
"""
|
||||
|
||||
def __init__(self, endpoint_tag: EndpointTag, path_prefix: str):
|
||||
self.endpoint_tag = endpoint_tag
|
||||
self.path_prefix = path_prefix
|
||||
self.handle = serve.get_handle(
|
||||
self.endpoint_tag, sync=False, missing_ok=True)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# Modify the path and root path so that reverse lookups and redirection
|
||||
# work as expected. We do this here instead of in replicas so it can be
|
||||
# changed without restarting the replicas.
|
||||
scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
|
||||
scope["root_path"] = self.path_prefix
|
||||
|
||||
await _send_request_to_handle(self.handle, scope, receive, send)
|
||||
|
||||
|
||||
class LongestPrefixRouter:
|
||||
"""Router that performs longest prefix matches on incoming routes."""
|
||||
|
||||
|
@ -117,10 +196,21 @@ class HTTPProxy:
|
|||
# controller instance this proxy is running in.
|
||||
ray.serve.api._set_internal_replica_context(None, None,
|
||||
controller_name, None)
|
||||
self.router = LongestPrefixRouter()
|
||||
|
||||
# Used only for displaying the route table.
|
||||
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
|
||||
|
||||
# NOTE(edoakes): we currently have both a Starlette router and a
|
||||
# longest-prefix router to maintain compatibility with the old API.
|
||||
# We first match on the Starlette router (which contains routes using
|
||||
# the old API) and then fall back to the prefix router. The Starlette
|
||||
# router can be removed once we deprecate the old API.
|
||||
self.starlette_router = starlette.routing.Router(
|
||||
default=self._fallback_to_prefix_router)
|
||||
self.prefix_router = LongestPrefixRouter()
|
||||
self.long_poll_client = LongPollClient(
|
||||
ray.get_actor(controller_name), {
|
||||
LongPollNamespace.ROUTE_TABLE: self.router.update_routes,
|
||||
LongPollNamespace.ROUTE_TABLE: self._update_routes,
|
||||
},
|
||||
call_in_event_loop=asyncio.get_event_loop())
|
||||
self.request_counter = metrics.Counter(
|
||||
|
@ -128,6 +218,41 @@ class HTTPProxy:
|
|||
description="The number of HTTP requests processed.",
|
||||
tag_keys=("route", ))
|
||||
|
||||
def _split_routes(
|
||||
self, endpoints: Dict[EndpointTag, EndpointInfo]) -> Tuple[Dict[
|
||||
EndpointTag, EndpointInfo], Dict[EndpointTag, EndpointInfo]]:
|
||||
starlette_routes = {}
|
||||
prefix_routes = {}
|
||||
for endpoint, info in endpoints.items():
|
||||
if info.legacy:
|
||||
starlette_routes[endpoint] = info
|
||||
else:
|
||||
prefix_routes[endpoint] = info
|
||||
|
||||
return starlette_routes, prefix_routes
|
||||
|
||||
def _update_routes(self,
|
||||
endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
|
||||
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
|
||||
for endpoint, info in endpoints.items():
|
||||
if not info.legacy and info.route is None:
|
||||
route = f"/{endpoint}"
|
||||
else:
|
||||
route = info.route
|
||||
self.route_info[route] = (endpoint, info.http_methods)
|
||||
|
||||
starlette_routes, prefix_routes = self._split_routes(endpoints)
|
||||
self.starlette_router.routes = [
|
||||
starlette.routing.Route(
|
||||
info.route,
|
||||
ServeStarletteEndpoint(endpoint, info.route),
|
||||
methods=info.http_methods)
|
||||
for endpoint, info in starlette_routes.items()
|
||||
if info.route is not None
|
||||
]
|
||||
|
||||
self.prefix_router.update_routes(prefix_routes)
|
||||
|
||||
async def block_until_endpoint_exists(self, endpoint: EndpointTag,
|
||||
timeout_s: float):
|
||||
start = time.time()
|
||||
|
@ -135,8 +260,9 @@ class HTTPProxy:
|
|||
if time.time() - start > timeout_s:
|
||||
raise TimeoutError(
|
||||
f"Waited {timeout_s} for {endpoint} to propagate.")
|
||||
if self.router.endpoint_exists(endpoint):
|
||||
return
|
||||
for existing_endpoint, _ in self.route_info.values():
|
||||
if existing_endpoint == endpoint:
|
||||
return
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
async def _not_found(self, scope, receive, send):
|
||||
|
@ -147,6 +273,22 @@ class HTTPProxy:
|
|||
status_code=404)
|
||||
await response.send(scope, receive, send)
|
||||
|
||||
async def _fallback_to_prefix_router(self, scope, receive, send):
|
||||
route_prefix, handle = self.prefix_router.match_route(
|
||||
scope["path"], scope["method"])
|
||||
if route_prefix is None:
|
||||
return await self._not_found(scope, receive, send)
|
||||
|
||||
# Modify the path and root path so that reverse lookups and redirection
|
||||
# work as expected. We do this here instead of in replicas so it can be
|
||||
# changed without restarting the replicas.
|
||||
if route_prefix != "/":
|
||||
assert not route_prefix.endswith("/")
|
||||
scope["path"] = scope["path"].replace(route_prefix, "", 1)
|
||||
scope["root_path"] = route_prefix
|
||||
|
||||
await _send_request_to_handle(handle, scope, receive, send)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Implements the ASGI protocol.
|
||||
|
||||
|
@ -158,61 +300,10 @@ class HTTPProxy:
|
|||
self.request_counter.inc(tags={"route": scope["path"]})
|
||||
|
||||
if scope["path"] == "/-/routes":
|
||||
return await starlette.responses.JSONResponse(
|
||||
self.router.route_info)(scope, receive, send)
|
||||
return await starlette.responses.JSONResponse(self.route_info)(
|
||||
scope, receive, send)
|
||||
|
||||
route_prefix, handle = self.router.match_route(scope["path"],
|
||||
scope["method"])
|
||||
if route_prefix is None:
|
||||
return await self._not_found(scope, receive, send)
|
||||
|
||||
http_body_bytes = await receive_http_body(scope, receive, send)
|
||||
|
||||
# Modify the path and root path so that reverse lookups and redirection
|
||||
# work as expected. We do this here instead of in replicas so it can be
|
||||
# changed without restarting the replicas.
|
||||
if route_prefix != "/":
|
||||
assert not route_prefix.endswith("/")
|
||||
scope["path"] = scope["path"].replace(route_prefix, "", 1)
|
||||
scope["root_path"] = route_prefix
|
||||
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
handle = handle.options(
|
||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
||||
DEFAULT.VALUE),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
|
||||
http_method=scope["method"].upper(),
|
||||
http_headers=headers)
|
||||
|
||||
# NOTE(edoakes): it's important that we defer building the starlette
|
||||
# request until it reaches the backend replica to avoid unnecessary
|
||||
# serialization cost, so we use a simple dataclass here.
|
||||
request = HTTPRequestWrapper(scope, http_body_bytes)
|
||||
|
||||
retries = 0
|
||||
backoff_time_s = 0.05
|
||||
while retries < MAX_REPLICA_FAILURE_RETRIES:
|
||||
object_ref = await handle.remote(request)
|
||||
try:
|
||||
result = await object_ref
|
||||
break
|
||||
except RayActorError:
|
||||
logger.warning(
|
||||
"Request failed due to replica failure. There are "
|
||||
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
|
||||
"remaining.")
|
||||
await asyncio.sleep(backoff_time_s)
|
||||
backoff_time_s *= 2
|
||||
retries += 1
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
await Response(
|
||||
error_message, status_code=500).send(scope, receive, send)
|
||||
elif isinstance(result, starlette.responses.Response):
|
||||
await result(scope, receive, send)
|
||||
else:
|
||||
await Response(result).send(scope, receive, send)
|
||||
await self.starlette_router(scope, receive, send)
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
|
|
|
@ -540,6 +540,35 @@ def test_starlette_request(serve_instance):
|
|||
assert resp == long_string
|
||||
|
||||
|
||||
def test_variable_routes(serve_instance):
|
||||
def f(starlette_request):
|
||||
return starlette_request.path_params
|
||||
|
||||
serve.create_backend("f", f)
|
||||
serve.create_endpoint("basic", backend="f", route="/api/{username}")
|
||||
|
||||
# Test multiple variables and test type conversion
|
||||
serve.create_endpoint(
|
||||
"complex",
|
||||
backend="f",
|
||||
route="/api/{user_id:int}/{number:float}",
|
||||
methods=["POST"])
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/api/scaly").json() == {
|
||||
"username": "scaly"
|
||||
}
|
||||
|
||||
assert requests.post("http://127.0.0.1:8000/api/23/12.345").json() == {
|
||||
"user_id": 23,
|
||||
"number": 12.345
|
||||
}
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/-/routes").json() == {
|
||||
"/api/{username}": ["basic", ["GET"]],
|
||||
"/api/{user_id:int}/{number:float}": ["complex", ["POST"]]
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue