mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[serve] Re-add variable route support for old API (#15455)
This commit is contained in:
parent
79c24146bd
commit
668a784553
5 changed files with 214 additions and 59 deletions
|
@ -41,6 +41,39 @@ in :mod:`serve.start <ray.serve.start>`:
|
||||||
instance group of Ray cluster to achieve high availability of Serve's HTTP
|
instance group of Ray cluster to achieve high availability of Serve's HTTP
|
||||||
proxies.
|
proxies.
|
||||||
|
|
||||||
|
Variable HTTP Routes
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Ray Serve supports capturing path parameters. For example, in a call of the form
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
serve.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}")
|
||||||
|
|
||||||
|
the ``username`` parameter will be accessible in your backend code as follows:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def my_backend(request):
|
||||||
|
username = request.path_params["username"]
|
||||||
|
...
|
||||||
|
|
||||||
|
Ray Serve uses Starlette's Router class under the hood for routing, so type
|
||||||
|
conversion for path parameters is also supported, as well as multiple path parameters.
|
||||||
|
For example, suppose this route is used:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
serve.create_endpoint(
|
||||||
|
"complex", backend="f", route="/api/{user_id:int}/{number:float}")
|
||||||
|
|
||||||
|
Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary
|
||||||
|
available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is
|
||||||
|
a Python int and ``3.14`` is a Python float.
|
||||||
|
|
||||||
|
For full details on the supported path parameters, see Starlette's
|
||||||
|
`path parameters documentation <https://www.starlette.io/routing/#path-parameters>`_.
|
||||||
|
|
||||||
Custom HTTP response status codes
|
Custom HTTP response status codes
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ class EndpointInfo:
|
||||||
http_methods: List[str]
|
http_methods: List[str]
|
||||||
python_methods: Optional[List[str]] = field(default_factory=list)
|
python_methods: Optional[List[str]] = field(default_factory=list)
|
||||||
route: Optional[str] = None
|
route: Optional[str] = None
|
||||||
|
legacy: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
class BackendInfo(BaseModel):
|
class BackendInfo(BaseModel):
|
||||||
|
|
|
@ -296,7 +296,8 @@ class ServeController:
|
||||||
endpoint_info = EndpointInfo(
|
endpoint_info = EndpointInfo(
|
||||||
ALL_HTTP_METHODS,
|
ALL_HTTP_METHODS,
|
||||||
route=route_prefix,
|
route=route_prefix,
|
||||||
python_methods=python_methods)
|
python_methods=python_methods,
|
||||||
|
legacy=False)
|
||||||
self.endpoint_state.update_endpoint(name, endpoint_info,
|
self.endpoint_state.update_endpoint(name, endpoint_info,
|
||||||
TrafficPolicy({
|
TrafficPolicy({
|
||||||
name: 1.0
|
name: 1.0
|
||||||
|
|
|
@ -22,6 +22,85 @@ from ray.serve.handle import DEFAULT
|
||||||
MAX_REPLICA_FAILURE_RETRIES = 10
|
MAX_REPLICA_FAILURE_RETRIES = 10
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_request_to_handle(handle, scope, receive, send):
|
||||||
|
http_body_bytes = await receive_http_body(scope, receive, send)
|
||||||
|
|
||||||
|
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||||
|
handle = handle.options(
|
||||||
|
method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
|
||||||
|
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
|
||||||
|
http_method=scope["method"].upper(),
|
||||||
|
http_headers=headers)
|
||||||
|
|
||||||
|
# scope["router"] and scope["endpoint"] contain references to a router
|
||||||
|
# and endpoint object, respectively, which each in turn contain a
|
||||||
|
# reference to the Serve client, which cannot be serialized.
|
||||||
|
# The solution is to delete these from scope, as they will not be used.
|
||||||
|
# TODO(edoakes): this can be removed once we deprecate the old API.
|
||||||
|
if "router" in scope:
|
||||||
|
del scope["router"]
|
||||||
|
if "endpoint" in scope:
|
||||||
|
del scope["endpoint"]
|
||||||
|
|
||||||
|
# NOTE(edoakes): it's important that we defer building the starlette
|
||||||
|
# request until it reaches the backend replica to avoid unnecessary
|
||||||
|
# serialization cost, so we use a simple dataclass here.
|
||||||
|
request = HTTPRequestWrapper(scope, http_body_bytes)
|
||||||
|
|
||||||
|
retries = 0
|
||||||
|
backoff_time_s = 0.05
|
||||||
|
while retries < MAX_REPLICA_FAILURE_RETRIES:
|
||||||
|
object_ref = await handle.remote(request)
|
||||||
|
try:
|
||||||
|
result = await object_ref
|
||||||
|
break
|
||||||
|
except RayActorError:
|
||||||
|
logger.warning("Request failed due to replica failure. There are "
|
||||||
|
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
|
||||||
|
"remaining.")
|
||||||
|
await asyncio.sleep(backoff_time_s)
|
||||||
|
backoff_time_s *= 2
|
||||||
|
retries += 1
|
||||||
|
|
||||||
|
if isinstance(result, RayTaskError):
|
||||||
|
error_message = "Task Error. Traceback: {}.".format(result)
|
||||||
|
await Response(
|
||||||
|
error_message, status_code=500).send(scope, receive, send)
|
||||||
|
elif isinstance(result, starlette.responses.Response):
|
||||||
|
await result(scope, receive, send)
|
||||||
|
else:
|
||||||
|
await Response(result).send(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
class ServeStarletteEndpoint:
|
||||||
|
"""Wraps the given Serve endpoint in a Starlette endpoint.
|
||||||
|
|
||||||
|
Implements the ASGI protocol. Constructs a Starlette endpoint for use by
|
||||||
|
a Starlette app or Starlette Router which calls the given Serve endpoint.
|
||||||
|
Usage:
|
||||||
|
route = starlette.routing.Route(
|
||||||
|
"/api",
|
||||||
|
ServeStarletteEndpoint(endpoint_tag),
|
||||||
|
methods=methods)
|
||||||
|
app = starlette.applications.Starlette(routes=[route])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, endpoint_tag: EndpointTag, path_prefix: str):
|
||||||
|
self.endpoint_tag = endpoint_tag
|
||||||
|
self.path_prefix = path_prefix
|
||||||
|
self.handle = serve.get_handle(
|
||||||
|
self.endpoint_tag, sync=False, missing_ok=True)
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
# Modify the path and root path so that reverse lookups and redirection
|
||||||
|
# work as expected. We do this here instead of in replicas so it can be
|
||||||
|
# changed without restarting the replicas.
|
||||||
|
scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
|
||||||
|
scope["root_path"] = self.path_prefix
|
||||||
|
|
||||||
|
await _send_request_to_handle(self.handle, scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class LongestPrefixRouter:
|
class LongestPrefixRouter:
|
||||||
"""Router that performs longest prefix matches on incoming routes."""
|
"""Router that performs longest prefix matches on incoming routes."""
|
||||||
|
|
||||||
|
@ -117,10 +196,21 @@ class HTTPProxy:
|
||||||
# controller instance this proxy is running in.
|
# controller instance this proxy is running in.
|
||||||
ray.serve.api._set_internal_replica_context(None, None,
|
ray.serve.api._set_internal_replica_context(None, None,
|
||||||
controller_name, None)
|
controller_name, None)
|
||||||
self.router = LongestPrefixRouter()
|
|
||||||
|
# Used only for displaying the route table.
|
||||||
|
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
|
||||||
|
|
||||||
|
# NOTE(edoakes): we currently have both a Starlette router and a
|
||||||
|
# longest-prefix router to maintain compatibility with the old API.
|
||||||
|
# We first match on the Starlette router (which contains routes using
|
||||||
|
# the old API) and then fall back to the prefix router. The Starlette
|
||||||
|
# router can be removed once we deprecate the old API.
|
||||||
|
self.starlette_router = starlette.routing.Router(
|
||||||
|
default=self._fallback_to_prefix_router)
|
||||||
|
self.prefix_router = LongestPrefixRouter()
|
||||||
self.long_poll_client = LongPollClient(
|
self.long_poll_client = LongPollClient(
|
||||||
ray.get_actor(controller_name), {
|
ray.get_actor(controller_name), {
|
||||||
LongPollNamespace.ROUTE_TABLE: self.router.update_routes,
|
LongPollNamespace.ROUTE_TABLE: self._update_routes,
|
||||||
},
|
},
|
||||||
call_in_event_loop=asyncio.get_event_loop())
|
call_in_event_loop=asyncio.get_event_loop())
|
||||||
self.request_counter = metrics.Counter(
|
self.request_counter = metrics.Counter(
|
||||||
|
@ -128,6 +218,41 @@ class HTTPProxy:
|
||||||
description="The number of HTTP requests processed.",
|
description="The number of HTTP requests processed.",
|
||||||
tag_keys=("route", ))
|
tag_keys=("route", ))
|
||||||
|
|
||||||
|
def _split_routes(
|
||||||
|
self, endpoints: Dict[EndpointTag, EndpointInfo]) -> Tuple[Dict[
|
||||||
|
EndpointTag, EndpointInfo], Dict[EndpointTag, EndpointInfo]]:
|
||||||
|
starlette_routes = {}
|
||||||
|
prefix_routes = {}
|
||||||
|
for endpoint, info in endpoints.items():
|
||||||
|
if info.legacy:
|
||||||
|
starlette_routes[endpoint] = info
|
||||||
|
else:
|
||||||
|
prefix_routes[endpoint] = info
|
||||||
|
|
||||||
|
return starlette_routes, prefix_routes
|
||||||
|
|
||||||
|
def _update_routes(self,
|
||||||
|
endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
|
||||||
|
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
|
||||||
|
for endpoint, info in endpoints.items():
|
||||||
|
if not info.legacy and info.route is None:
|
||||||
|
route = f"/{endpoint}"
|
||||||
|
else:
|
||||||
|
route = info.route
|
||||||
|
self.route_info[route] = (endpoint, info.http_methods)
|
||||||
|
|
||||||
|
starlette_routes, prefix_routes = self._split_routes(endpoints)
|
||||||
|
self.starlette_router.routes = [
|
||||||
|
starlette.routing.Route(
|
||||||
|
info.route,
|
||||||
|
ServeStarletteEndpoint(endpoint, info.route),
|
||||||
|
methods=info.http_methods)
|
||||||
|
for endpoint, info in starlette_routes.items()
|
||||||
|
if info.route is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
self.prefix_router.update_routes(prefix_routes)
|
||||||
|
|
||||||
async def block_until_endpoint_exists(self, endpoint: EndpointTag,
|
async def block_until_endpoint_exists(self, endpoint: EndpointTag,
|
||||||
timeout_s: float):
|
timeout_s: float):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
@ -135,8 +260,9 @@ class HTTPProxy:
|
||||||
if time.time() - start > timeout_s:
|
if time.time() - start > timeout_s:
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Waited {timeout_s} for {endpoint} to propagate.")
|
f"Waited {timeout_s} for {endpoint} to propagate.")
|
||||||
if self.router.endpoint_exists(endpoint):
|
for existing_endpoint, _ in self.route_info.values():
|
||||||
return
|
if existing_endpoint == endpoint:
|
||||||
|
return
|
||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
async def _not_found(self, scope, receive, send):
|
async def _not_found(self, scope, receive, send):
|
||||||
|
@ -147,6 +273,22 @@ class HTTPProxy:
|
||||||
status_code=404)
|
status_code=404)
|
||||||
await response.send(scope, receive, send)
|
await response.send(scope, receive, send)
|
||||||
|
|
||||||
|
async def _fallback_to_prefix_router(self, scope, receive, send):
|
||||||
|
route_prefix, handle = self.prefix_router.match_route(
|
||||||
|
scope["path"], scope["method"])
|
||||||
|
if route_prefix is None:
|
||||||
|
return await self._not_found(scope, receive, send)
|
||||||
|
|
||||||
|
# Modify the path and root path so that reverse lookups and redirection
|
||||||
|
# work as expected. We do this here instead of in replicas so it can be
|
||||||
|
# changed without restarting the replicas.
|
||||||
|
if route_prefix != "/":
|
||||||
|
assert not route_prefix.endswith("/")
|
||||||
|
scope["path"] = scope["path"].replace(route_prefix, "", 1)
|
||||||
|
scope["root_path"] = route_prefix
|
||||||
|
|
||||||
|
await _send_request_to_handle(handle, scope, receive, send)
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope, receive, send):
|
||||||
"""Implements the ASGI protocol.
|
"""Implements the ASGI protocol.
|
||||||
|
|
||||||
|
@ -158,61 +300,10 @@ class HTTPProxy:
|
||||||
self.request_counter.inc(tags={"route": scope["path"]})
|
self.request_counter.inc(tags={"route": scope["path"]})
|
||||||
|
|
||||||
if scope["path"] == "/-/routes":
|
if scope["path"] == "/-/routes":
|
||||||
return await starlette.responses.JSONResponse(
|
return await starlette.responses.JSONResponse(self.route_info)(
|
||||||
self.router.route_info)(scope, receive, send)
|
scope, receive, send)
|
||||||
|
|
||||||
route_prefix, handle = self.router.match_route(scope["path"],
|
await self.starlette_router(scope, receive, send)
|
||||||
scope["method"])
|
|
||||||
if route_prefix is None:
|
|
||||||
return await self._not_found(scope, receive, send)
|
|
||||||
|
|
||||||
http_body_bytes = await receive_http_body(scope, receive, send)
|
|
||||||
|
|
||||||
# Modify the path and root path so that reverse lookups and redirection
|
|
||||||
# work as expected. We do this here instead of in replicas so it can be
|
|
||||||
# changed without restarting the replicas.
|
|
||||||
if route_prefix != "/":
|
|
||||||
assert not route_prefix.endswith("/")
|
|
||||||
scope["path"] = scope["path"].replace(route_prefix, "", 1)
|
|
||||||
scope["root_path"] = route_prefix
|
|
||||||
|
|
||||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
|
||||||
handle = handle.options(
|
|
||||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
|
||||||
DEFAULT.VALUE),
|
|
||||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
|
|
||||||
http_method=scope["method"].upper(),
|
|
||||||
http_headers=headers)
|
|
||||||
|
|
||||||
# NOTE(edoakes): it's important that we defer building the starlette
|
|
||||||
# request until it reaches the backend replica to avoid unnecessary
|
|
||||||
# serialization cost, so we use a simple dataclass here.
|
|
||||||
request = HTTPRequestWrapper(scope, http_body_bytes)
|
|
||||||
|
|
||||||
retries = 0
|
|
||||||
backoff_time_s = 0.05
|
|
||||||
while retries < MAX_REPLICA_FAILURE_RETRIES:
|
|
||||||
object_ref = await handle.remote(request)
|
|
||||||
try:
|
|
||||||
result = await object_ref
|
|
||||||
break
|
|
||||||
except RayActorError:
|
|
||||||
logger.warning(
|
|
||||||
"Request failed due to replica failure. There are "
|
|
||||||
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
|
|
||||||
"remaining.")
|
|
||||||
await asyncio.sleep(backoff_time_s)
|
|
||||||
backoff_time_s *= 2
|
|
||||||
retries += 1
|
|
||||||
|
|
||||||
if isinstance(result, RayTaskError):
|
|
||||||
error_message = "Task Error. Traceback: {}.".format(result)
|
|
||||||
await Response(
|
|
||||||
error_message, status_code=500).send(scope, receive, send)
|
|
||||||
elif isinstance(result, starlette.responses.Response):
|
|
||||||
await result(scope, receive, send)
|
|
||||||
else:
|
|
||||||
await Response(result).send(scope, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
@ray.remote(num_cpus=0)
|
@ray.remote(num_cpus=0)
|
||||||
|
|
|
@ -540,6 +540,35 @@ def test_starlette_request(serve_instance):
|
||||||
assert resp == long_string
|
assert resp == long_string
|
||||||
|
|
||||||
|
|
||||||
|
def test_variable_routes(serve_instance):
|
||||||
|
def f(starlette_request):
|
||||||
|
return starlette_request.path_params
|
||||||
|
|
||||||
|
serve.create_backend("f", f)
|
||||||
|
serve.create_endpoint("basic", backend="f", route="/api/{username}")
|
||||||
|
|
||||||
|
# Test multiple variables and test type conversion
|
||||||
|
serve.create_endpoint(
|
||||||
|
"complex",
|
||||||
|
backend="f",
|
||||||
|
route="/api/{user_id:int}/{number:float}",
|
||||||
|
methods=["POST"])
|
||||||
|
|
||||||
|
assert requests.get("http://127.0.0.1:8000/api/scaly").json() == {
|
||||||
|
"username": "scaly"
|
||||||
|
}
|
||||||
|
|
||||||
|
assert requests.post("http://127.0.0.1:8000/api/23/12.345").json() == {
|
||||||
|
"user_id": 23,
|
||||||
|
"number": 12.345
|
||||||
|
}
|
||||||
|
|
||||||
|
assert requests.get("http://127.0.0.1:8000/-/routes").json() == {
|
||||||
|
"/api/{username}": ["basic", ["GET"]],
|
||||||
|
"/api/{user_id:int}/{number:float}": ["complex", ["POST"]]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||||
|
|
Loading…
Add table
Reference in a new issue