From 668a784553c3f679ecd2a146770884e66a326d54 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 22 Apr 2021 14:07:50 -0500 Subject: [PATCH] [serve] Re-add variable route support for old API (#15455) --- doc/source/serve/http-servehandle.rst | 33 ++++ python/ray/serve/common.py | 1 + python/ray/serve/controller.py | 3 +- python/ray/serve/http_proxy.py | 207 ++++++++++++++++++-------- python/ray/serve/tests/test_api.py | 29 ++++ 5 files changed, 214 insertions(+), 59 deletions(-) diff --git a/doc/source/serve/http-servehandle.rst b/doc/source/serve/http-servehandle.rst index cc4a14229..6fbf2f219 100644 --- a/doc/source/serve/http-servehandle.rst +++ b/doc/source/serve/http-servehandle.rst @@ -41,6 +41,39 @@ in :mod:`serve.start `: instance group of Ray cluster to achieve high availability of Serve's HTTP proxies. +Variable HTTP Routes +^^^^^^^^^^^^^^^^^^^^ + +Ray Serve supports capturing path parameters. For example, in a call of the form + +.. code-block:: python + + serve.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}") + +the ``username`` parameter will be accessible in your backend code as follows: + +.. code-block:: python + + def my_backend(request): + username = request.path_params["username"] + ... + +Ray Serve uses Starlette's Router class under the hood for routing, so type +conversion for path parameters is also supported, as well as multiple path parameters. +For example, suppose this route is used: + +.. code-block:: python + + serve.create_endpoint( + "complex", backend="f", route="/api/{user_id:int}/{number:float}") + +Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary +available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is +a Python int and ``3.14`` is a Python float. + +For full details on the supported path parameters, see Starlette's +`path parameters documentation `_. + Custom HTTP response status codes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index 0af19ea5f..dafdf7a05 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -20,6 +20,7 @@ class EndpointInfo: http_methods: List[str] python_methods: Optional[List[str]] = field(default_factory=list) route: Optional[str] = None + legacy: Optional[bool] = True class BackendInfo(BaseModel): diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e487fdbf4..8cbf699e0 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -296,7 +296,8 @@ class ServeController: endpoint_info = EndpointInfo( ALL_HTTP_METHODS, route=route_prefix, - python_methods=python_methods) + python_methods=python_methods, + legacy=False) self.endpoint_state.update_endpoint(name, endpoint_info, TrafficPolicy({ name: 1.0 diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 1f826b748..7654ba227 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -22,6 +22,85 @@ from ray.serve.handle import DEFAULT MAX_REPLICA_FAILURE_RETRIES = 10 +async def _send_request_to_handle(handle, scope, receive, send): + http_body_bytes = await receive_http_body(scope, receive, send) + + headers = {k.decode(): v.decode() for k, v in scope["headers"]} + handle = handle.options( + method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE), + shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE), + http_method=scope["method"].upper(), + http_headers=headers) + + # scope["router"] and scope["endpoint"] contain references to a router + # and endpoint object, respectively, which each in turn contain a + # reference to the Serve client, which cannot be serialized. + # The solution is to delete these from scope, as they will not be used. + # TODO(edoakes): this can be removed once we deprecate the old API. + if "router" in scope: + del scope["router"] + if "endpoint" in scope: + del scope["endpoint"] + + # NOTE(edoakes): it's important that we defer building the starlette + # request until it reaches the backend replica to avoid unnecessary + # serialization cost, so we use a simple dataclass here. + request = HTTPRequestWrapper(scope, http_body_bytes) + + retries = 0 + backoff_time_s = 0.05 + while retries < MAX_REPLICA_FAILURE_RETRIES: + object_ref = await handle.remote(request) + try: + result = await object_ref + break + except RayActorError: + logger.warning("Request failed due to replica failure. There are " + f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries " + "remaining.") + await asyncio.sleep(backoff_time_s) + backoff_time_s *= 2 + retries += 1 + + if isinstance(result, RayTaskError): + error_message = "Task Error. Traceback: {}.".format(result) + await Response( + error_message, status_code=500).send(scope, receive, send) + elif isinstance(result, starlette.responses.Response): + await result(scope, receive, send) + else: + await Response(result).send(scope, receive, send) + + +class ServeStarletteEndpoint: + """Wraps the given Serve endpoint in a Starlette endpoint. + + Implements the ASGI protocol. Constructs a Starlette endpoint for use by + a Starlette app or Starlette Router which calls the given Serve endpoint. + Usage: + route = starlette.routing.Route( + "/api", + ServeStarletteEndpoint(endpoint_tag), + methods=methods) + app = starlette.applications.Starlette(routes=[route]) + """ + + def __init__(self, endpoint_tag: EndpointTag, path_prefix: str): + self.endpoint_tag = endpoint_tag + self.path_prefix = path_prefix + self.handle = serve.get_handle( + self.endpoint_tag, sync=False, missing_ok=True) + + async def __call__(self, scope, receive, send): + # Modify the path and root path so that reverse lookups and redirection + # work as expected. We do this here instead of in replicas so it can be + # changed without restarting the replicas. + scope["path"] = scope["path"].replace(self.path_prefix, "", 1) + scope["root_path"] = self.path_prefix + + await _send_request_to_handle(self.handle, scope, receive, send) + + class LongestPrefixRouter: """Router that performs longest prefix matches on incoming routes.""" @@ -117,10 +196,21 @@ class HTTPProxy: # controller instance this proxy is running in. ray.serve.api._set_internal_replica_context(None, None, controller_name, None) - self.router = LongestPrefixRouter() + + # Used only for displaying the route table. + self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict() + + # NOTE(edoakes): we currently have both a Starlette router and a + # longest-prefix router to maintain compatibility with the old API. + # We first match on the Starlette router (which contains routes using + # the old API) and then fall back to the prefix router. The Starlette + # router can be removed once we deprecate the old API. + self.starlette_router = starlette.routing.Router( + default=self._fallback_to_prefix_router) + self.prefix_router = LongestPrefixRouter() self.long_poll_client = LongPollClient( ray.get_actor(controller_name), { - LongPollNamespace.ROUTE_TABLE: self.router.update_routes, + LongPollNamespace.ROUTE_TABLE: self._update_routes, }, call_in_event_loop=asyncio.get_event_loop()) self.request_counter = metrics.Counter( @@ -128,6 +218,41 @@ class HTTPProxy: description="The number of HTTP requests processed.", tag_keys=("route", )) + def _split_routes( + self, endpoints: Dict[EndpointTag, EndpointInfo]) -> Tuple[Dict[ + EndpointTag, EndpointInfo], Dict[EndpointTag, EndpointInfo]]: + starlette_routes = {} + prefix_routes = {} + for endpoint, info in endpoints.items(): + if info.legacy: + starlette_routes[endpoint] = info + else: + prefix_routes[endpoint] = info + + return starlette_routes, prefix_routes + + def _update_routes(self, + endpoints: Dict[EndpointTag, EndpointInfo]) -> None: + self.route_info: Dict[str, Tuple[EndpointTag, List[str]]] = dict() + for endpoint, info in endpoints.items(): + if not info.legacy and info.route is None: + route = f"/{endpoint}" + else: + route = info.route + self.route_info[route] = (endpoint, info.http_methods) + + starlette_routes, prefix_routes = self._split_routes(endpoints) + self.starlette_router.routes = [ + starlette.routing.Route( + info.route, + ServeStarletteEndpoint(endpoint, info.route), + methods=info.http_methods) + for endpoint, info in starlette_routes.items() + if info.route is not None + ] + + self.prefix_router.update_routes(prefix_routes) + async def block_until_endpoint_exists(self, endpoint: EndpointTag, timeout_s: float): start = time.time() @@ -135,8 +260,9 @@ class HTTPProxy: if time.time() - start > timeout_s: raise TimeoutError( f"Waited {timeout_s} for {endpoint} to propagate.") - if self.router.endpoint_exists(endpoint): - return + for existing_endpoint, _ in self.route_info.values(): + if existing_endpoint == endpoint: + return await asyncio.sleep(0.2) async def _not_found(self, scope, receive, send): @@ -147,6 +273,22 @@ class HTTPProxy: status_code=404) await response.send(scope, receive, send) + async def _fallback_to_prefix_router(self, scope, receive, send): + route_prefix, handle = self.prefix_router.match_route( + scope["path"], scope["method"]) + if route_prefix is None: + return await self._not_found(scope, receive, send) + + # Modify the path and root path so that reverse lookups and redirection + # work as expected. We do this here instead of in replicas so it can be + # changed without restarting the replicas. + if route_prefix != "/": + assert not route_prefix.endswith("/") + scope["path"] = scope["path"].replace(route_prefix, "", 1) + scope["root_path"] = route_prefix + + await _send_request_to_handle(handle, scope, receive, send) + async def __call__(self, scope, receive, send): """Implements the ASGI protocol. @@ -158,61 +300,10 @@ class HTTPProxy: self.request_counter.inc(tags={"route": scope["path"]}) if scope["path"] == "/-/routes": - return await starlette.responses.JSONResponse( - self.router.route_info)(scope, receive, send) + return await starlette.responses.JSONResponse(self.route_info)( + scope, receive, send) - route_prefix, handle = self.router.match_route(scope["path"], - scope["method"]) - if route_prefix is None: - return await self._not_found(scope, receive, send) - - http_body_bytes = await receive_http_body(scope, receive, send) - - # Modify the path and root path so that reverse lookups and redirection - # work as expected. We do this here instead of in replicas so it can be - # changed without restarting the replicas. - if route_prefix != "/": - assert not route_prefix.endswith("/") - scope["path"] = scope["path"].replace(route_prefix, "", 1) - scope["root_path"] = route_prefix - - headers = {k.decode(): v.decode() for k, v in scope["headers"]} - handle = handle.options( - method_name=headers.get("X-SERVE-CALL-METHOD".lower(), - DEFAULT.VALUE), - shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE), - http_method=scope["method"].upper(), - http_headers=headers) - - # NOTE(edoakes): it's important that we defer building the starlette - # request until it reaches the backend replica to avoid unnecessary - # serialization cost, so we use a simple dataclass here. - request = HTTPRequestWrapper(scope, http_body_bytes) - - retries = 0 - backoff_time_s = 0.05 - while retries < MAX_REPLICA_FAILURE_RETRIES: - object_ref = await handle.remote(request) - try: - result = await object_ref - break - except RayActorError: - logger.warning( - "Request failed due to replica failure. There are " - f"{MAX_REPLICA_FAILURE_RETRIES - retries} retries " - "remaining.") - await asyncio.sleep(backoff_time_s) - backoff_time_s *= 2 - retries += 1 - - if isinstance(result, RayTaskError): - error_message = "Task Error. Traceback: {}.".format(result) - await Response( - error_message, status_code=500).send(scope, receive, send) - elif isinstance(result, starlette.responses.Response): - await result(scope, receive, send) - else: - await Response(result).send(scope, receive, send) + await self.starlette_router(scope, receive, send) @ray.remote(num_cpus=0) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 0708f1517..fff8af106 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -540,6 +540,35 @@ def test_starlette_request(serve_instance): assert resp == long_string +def test_variable_routes(serve_instance): + def f(starlette_request): + return starlette_request.path_params + + serve.create_backend("f", f) + serve.create_endpoint("basic", backend="f", route="/api/{username}") + + # Test multiple variables and test type conversion + serve.create_endpoint( + "complex", + backend="f", + route="/api/{user_id:int}/{number:float}", + methods=["POST"]) + + assert requests.get("http://127.0.0.1:8000/api/scaly").json() == { + "username": "scaly" + } + + assert requests.post("http://127.0.0.1:8000/api/23/12.345").json() == { + "user_id": 23, + "number": 12.345 + } + + assert requests.get("http://127.0.0.1:8000/-/routes").json() == { + "/api/{username}": ["basic", ["GET"]], + "/api/{user_id:int}/{number:float}": ["complex", ["POST"]] + } + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__]))