[serve] Re-add variable route support for old API (#15455)

This commit is contained in:
Edward Oakes 2021-04-22 14:07:50 -05:00 committed by GitHub
parent 79c24146bd
commit 668a784553
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 214 additions and 59 deletions

View file

@ -41,6 +41,39 @@ in :mod:`serve.start <ray.serve.start>`:
instance group of Ray cluster to achieve high availability of Serve's HTTP
proxies.
Variable HTTP Routes
^^^^^^^^^^^^^^^^^^^^
Ray Serve supports capturing path parameters. For example, in a call of the form
.. code-block:: python
serve.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}")
the ``username`` parameter will be accessible in your backend code as follows:
.. code-block:: python
def my_backend(request):
username = request.path_params["username"]
...
Ray Serve uses Starlette's Router class under the hood for routing, so type
conversion for path parameters is also supported, as well as multiple path parameters.
For example, suppose this route is used:
.. code-block:: python
serve.create_endpoint(
"complex", backend="f", route="/api/{user_id:int}/{number:float}")
Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary
available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is
a Python int and ``3.14`` is a Python float.
For full details on the supported path parameters, see Starlette's
`path parameters documentation <https://www.starlette.io/routing/#path-parameters>`_.
Custom HTTP response status codes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View file

@ -20,6 +20,7 @@ class EndpointInfo:
http_methods: List[str]
python_methods: Optional[List[str]] = field(default_factory=list)
route: Optional[str] = None
legacy: Optional[bool] = True
class BackendInfo(BaseModel):

View file

@ -296,7 +296,8 @@ class ServeController:
endpoint_info = EndpointInfo(
ALL_HTTP_METHODS,
route=route_prefix,
python_methods=python_methods)
python_methods=python_methods,
legacy=False)
self.endpoint_state.update_endpoint(name, endpoint_info,
TrafficPolicy({
name: 1.0

View file

@ -22,6 +22,85 @@ from ray.serve.handle import DEFAULT
MAX_REPLICA_FAILURE_RETRIES = 10
async def _send_request_to_handle(handle, scope, receive, send):
http_body_bytes = await receive_http_body(scope, receive, send)
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
handle = handle.options(
method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
http_method=scope["method"].upper(),
http_headers=headers)
# scope["router"] and scope["endpoint"] contain references to a router
# and endpoint object, respectively, which each in turn contain a
# reference to the Serve client, which cannot be serialized.
# The solution is to delete these from scope, as they will not be used.
# TODO(edoakes): this can be removed once we deprecate the old API.
if "router" in scope:
del scope["router"]
if "endpoint" in scope:
del scope["endpoint"]
# NOTE(edoakes): it's important that we defer building the starlette
# request until it reaches the backend replica to avoid unnecessary
# serialization cost, so we use a simple dataclass here.
request = HTTPRequestWrapper(scope, http_body_bytes)
retries = 0
backoff_time_s = 0.05
while retries < MAX_REPLICA_FAILURE_RETRIES:
object_ref = await handle.remote(request)
try:
result = await object_ref
break
except RayActorError:
logger.warning("Request failed due to replica failure. There are "
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
"remaining.")
await asyncio.sleep(backoff_time_s)
backoff_time_s *= 2
retries += 1
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)
await Response(
error_message, status_code=500).send(scope, receive, send)
elif isinstance(result, starlette.responses.Response):
await result(scope, receive, send)
else:
await Response(result).send(scope, receive, send)
class ServeStarletteEndpoint:
"""Wraps the given Serve endpoint in a Starlette endpoint.
Implements the ASGI protocol. Constructs a Starlette endpoint for use by
a Starlette app or Starlette Router which calls the given Serve endpoint.
Usage:
route = starlette.routing.Route(
"/api",
ServeStarletteEndpoint(endpoint_tag),
methods=methods)
app = starlette.applications.Starlette(routes=[route])
"""
def __init__(self, endpoint_tag: EndpointTag, path_prefix: str):
self.endpoint_tag = endpoint_tag
self.path_prefix = path_prefix
self.handle = serve.get_handle(
self.endpoint_tag, sync=False, missing_ok=True)
async def __call__(self, scope, receive, send):
# Modify the path and root path so that reverse lookups and redirection
# work as expected. We do this here instead of in replicas so it can be
# changed without restarting the replicas.
scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
scope["root_path"] = self.path_prefix
await _send_request_to_handle(self.handle, scope, receive, send)
class LongestPrefixRouter:
"""Router that performs longest prefix matches on incoming routes."""
@ -117,10 +196,21 @@ class HTTPProxy:
# controller instance this proxy is running in.
ray.serve.api._set_internal_replica_context(None, None,
controller_name, None)
self.router = LongestPrefixRouter()
# Used only for displaying the route table.
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
# NOTE(edoakes): we currently have both a Starlette router and a
# longest-prefix router to maintain compatibility with the old API.
# We first match on the Starlette router (which contains routes using
# the old API) and then fall back to the prefix router. The Starlette
# router can be removed once we deprecate the old API.
self.starlette_router = starlette.routing.Router(
default=self._fallback_to_prefix_router)
self.prefix_router = LongestPrefixRouter()
self.long_poll_client = LongPollClient(
ray.get_actor(controller_name), {
LongPollNamespace.ROUTE_TABLE: self.router.update_routes,
LongPollNamespace.ROUTE_TABLE: self._update_routes,
},
call_in_event_loop=asyncio.get_event_loop())
self.request_counter = metrics.Counter(
@ -128,6 +218,41 @@ class HTTPProxy:
description="The number of HTTP requests processed.",
tag_keys=("route", ))
def _split_routes(
self, endpoints: Dict[EndpointTag, EndpointInfo]) -> Tuple[Dict[
EndpointTag, EndpointInfo], Dict[EndpointTag, EndpointInfo]]:
starlette_routes = {}
prefix_routes = {}
for endpoint, info in endpoints.items():
if info.legacy:
starlette_routes[endpoint] = info
else:
prefix_routes[endpoint] = info
return starlette_routes, prefix_routes
def _update_routes(self,
endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict()
for endpoint, info in endpoints.items():
if not info.legacy and info.route is None:
route = f"/{endpoint}"
else:
route = info.route
self.route_info[route] = (endpoint, info.http_methods)
starlette_routes, prefix_routes = self._split_routes(endpoints)
self.starlette_router.routes = [
starlette.routing.Route(
info.route,
ServeStarletteEndpoint(endpoint, info.route),
methods=info.http_methods)
for endpoint, info in starlette_routes.items()
if info.route is not None
]
self.prefix_router.update_routes(prefix_routes)
async def block_until_endpoint_exists(self, endpoint: EndpointTag,
timeout_s: float):
start = time.time()
@ -135,7 +260,8 @@ class HTTPProxy:
if time.time() - start > timeout_s:
raise TimeoutError(
f"Waited {timeout_s} for {endpoint} to propagate.")
if self.router.endpoint_exists(endpoint):
for existing_endpoint, _ in self.route_info.values():
if existing_endpoint == endpoint:
return
await asyncio.sleep(0.2)
@ -147,6 +273,22 @@ class HTTPProxy:
status_code=404)
await response.send(scope, receive, send)
async def _fallback_to_prefix_router(self, scope, receive, send):
route_prefix, handle = self.prefix_router.match_route(
scope["path"], scope["method"])
if route_prefix is None:
return await self._not_found(scope, receive, send)
# Modify the path and root path so that reverse lookups and redirection
# work as expected. We do this here instead of in replicas so it can be
# changed without restarting the replicas.
if route_prefix != "/":
assert not route_prefix.endswith("/")
scope["path"] = scope["path"].replace(route_prefix, "", 1)
scope["root_path"] = route_prefix
await _send_request_to_handle(handle, scope, receive, send)
async def __call__(self, scope, receive, send):
"""Implements the ASGI protocol.
@ -158,61 +300,10 @@ class HTTPProxy:
self.request_counter.inc(tags={"route": scope["path"]})
if scope["path"] == "/-/routes":
return await starlette.responses.JSONResponse(
self.router.route_info)(scope, receive, send)
return await starlette.responses.JSONResponse(self.route_info)(
scope, receive, send)
route_prefix, handle = self.router.match_route(scope["path"],
scope["method"])
if route_prefix is None:
return await self._not_found(scope, receive, send)
http_body_bytes = await receive_http_body(scope, receive, send)
# Modify the path and root path so that reverse lookups and redirection
# work as expected. We do this here instead of in replicas so it can be
# changed without restarting the replicas.
if route_prefix != "/":
assert not route_prefix.endswith("/")
scope["path"] = scope["path"].replace(route_prefix, "", 1)
scope["root_path"] = route_prefix
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
handle = handle.options(
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
DEFAULT.VALUE),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
http_method=scope["method"].upper(),
http_headers=headers)
# NOTE(edoakes): it's important that we defer building the starlette
# request until it reaches the backend replica to avoid unnecessary
# serialization cost, so we use a simple dataclass here.
request = HTTPRequestWrapper(scope, http_body_bytes)
retries = 0
backoff_time_s = 0.05
while retries < MAX_REPLICA_FAILURE_RETRIES:
object_ref = await handle.remote(request)
try:
result = await object_ref
break
except RayActorError:
logger.warning(
"Request failed due to replica failure. There are "
f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries "
"remaining.")
await asyncio.sleep(backoff_time_s)
backoff_time_s *= 2
retries += 1
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)
await Response(
error_message, status_code=500).send(scope, receive, send)
elif isinstance(result, starlette.responses.Response):
await result(scope, receive, send)
else:
await Response(result).send(scope, receive, send)
await self.starlette_router(scope, receive, send)
@ray.remote(num_cpus=0)

View file

@ -540,6 +540,35 @@ def test_starlette_request(serve_instance):
assert resp == long_string
def test_variable_routes(serve_instance):
def f(starlette_request):
return starlette_request.path_params
serve.create_backend("f", f)
serve.create_endpoint("basic", backend="f", route="/api/{username}")
# Test multiple variables and test type conversion
serve.create_endpoint(
"complex",
backend="f",
route="/api/{user_id:int}/{number:float}",
methods=["POST"])
assert requests.get("http://127.0.0.1:8000/api/scaly").json() == {
"username": "scaly"
}
assert requests.post("http://127.0.0.1:8000/api/23/12.345").json() == {
"user_id": 23,
"number": 12.345
}
assert requests.get("http://127.0.0.1:8000/-/routes").json() == {
"/api/{username}": ["basic", ["GET"]],
"/api/{user_id:int}/{number:float}": ["complex", ["POST"]]
}
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))