From eed34092f26dcedadfcf469ec041bdda70ed1e24 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 7 Apr 2021 13:28:53 -0500 Subject: [PATCH] [serve] Defer building starlette request to backend replica (#15169) --- python/ray/serve/benchmarks/microbenchmark.py | 12 ++++++++--- python/ray/serve/http_proxy.py | 17 +++++++++++++--- python/ray/serve/http_util.py | 20 +++++++++---------- python/ray/serve/utils.py | 17 +++++++++------- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/python/ray/serve/benchmarks/microbenchmark.py b/python/ray/serve/benchmarks/microbenchmark.py index 574327a06..596ae6f3a 100644 --- a/python/ray/serve/benchmarks/microbenchmark.py +++ b/python/ray/serve/benchmarks/microbenchmark.py @@ -5,6 +5,7 @@ import aiohttp import asyncio import time +import requests import ray from ray import serve @@ -35,8 +36,9 @@ async def timeit(name, fn, multiplier=1): async def fetch(session, data): - async with session.get("http://127.0.0.1:8000/api", data=data) as response: - await response.text() + async with session.get("http://localhost:8000/api", data=data) as response: + response = await response.text() + assert response == "ok", response @ray.remote @@ -74,9 +76,11 @@ async def trial(intermediate_handles, num_replicas, max_batch_size, return await self.handle.remote() ForwardActor.deploy() + routes = requests.get("http://localhost:8000/-/routes").json() + assert "/api" in routes, routes @serve.deployment( - deployment_name, + name=deployment_name, num_replicas=num_replicas, max_concurrent_queries=max_concurrent_queries) class Backend: @@ -91,6 +95,8 @@ async def trial(intermediate_handles, num_replicas, max_batch_size, return b"ok" Backend.deploy() + routes = requests.get("http://localhost:8000/-/routes").json() + assert f"/{deployment_name}" in routes, routes if data_size == "small": data = None diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index bb0441f2d..fb0f04b23 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -14,7 +14,7 @@ from ray.serve.constants import WILDCARD_PATH_SUFFIX from ray.serve.long_poll import LongPollNamespace from ray.util import metrics from ray.serve.utils import logger -from ray.serve.http_util import Response, build_starlette_request +from ray.serve.http_util import HTTPRequestWrapper, Response from ray.serve.long_poll import LongPollClient from ray.serve.handle import DEFAULT @@ -52,12 +52,18 @@ class ServeStarletteEndpoint: headers = {k.decode(): v.decode() for k, v in scope["headers"]} + # scope["router"] and scope["endpoint"] contain references to a router + # and endpoint object, respectively, which each in turn contain a + # reference to the Serve client, which cannot be serialized. + # The solution is to delete these from scope, as they will not be used. + del scope["router"] + del scope["endpoint"] + # Modify the path and root path so that reverse lookups and redirection # work as expected. We do this here instead of in replicas so it can be # changed without restarting the replicas. scope["path"] = scope["path"].replace(self.path_prefix, "", 1) scope["root_path"] = self.path_prefix - starlette_request = build_starlette_request(scope, http_body_bytes) handle = self.handle.options( method_name=headers.get("X-SERVE-CALL-METHOD".lower(), DEFAULT.VALUE), @@ -65,10 +71,15 @@ class ServeStarletteEndpoint: http_method=scope["method"].upper(), http_headers=headers) + # NOTE(edoakes): it's important that we defer building the starlette + # request until it reaches the backend replica to avoid unnecessary + # serialization cost, so we use a simple dataclass here. + request = HTTPRequestWrapper(scope, http_body_bytes) + retries = 0 backoff_time_s = 0.05 while retries < MAX_ACTOR_FAILURE_RETRIES: - object_ref = await handle.remote(starlette_request) + object_ref = await handle.remote(request) try: result = await object_ref break diff --git a/python/ray/serve/http_util.py b/python/ray/serve/http_util.py index 62456b47e..a9d8ab860 100644 --- a/python/ray/serve/http_util.py +++ b/python/ray/serve/http_util.py @@ -1,12 +1,19 @@ import asyncio +from dataclasses import dataclass import json -from typing import Callable, Tuple +from typing import Any, Callable, Dict, Tuple import starlette.requests from ray.serve.exceptions import RayServeException +@dataclass +class HTTPRequestWrapper: + scope: Dict[Any, Any] + body: bytes + + def build_starlette_request(scope, serialized_body: bytes): """Build and return a Starlette Request from ASGI payload. @@ -35,16 +42,7 @@ def build_starlette_request(scope, serialized_body: bytes): "more_body": False } - # scope["router"] and scope["endpoint"] contain references to a router and - # endpoint object, respectively, which each in turn contain a reference to - # the Serve client, which cannot be serialized. - # The solution is to delete these from scope, as they will not be used. - # Per ASGI recommendation, copy scope before passing to child. - child_scope = scope.copy() - del child_scope["router"] - del child_scope["endpoint"] - - return starlette.requests.Request(child_scope, mock_receive) + return starlette.requests.Request(scope, mock_receive) class Response: diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 4edcf937e..72f10ef04 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -19,6 +19,7 @@ import pydantic import ray from ray.serve.constants import HTTP_PROXY_TIMEOUT from ray.serve.exceptions import RayServeException +from ray.serve.http_util import build_starlette_request, HTTPRequestWrapper ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 @@ -89,13 +90,15 @@ def parse_request_item(request_item): # it in ServeRequest. if isinstance(arg, starlette.requests.Request): return arg - - return ServeRequest( - arg, - request_item.kwargs, - headers=request_item.metadata.http_headers, - method=request_item.metadata.http_method, - ) + elif isinstance(arg, HTTPRequestWrapper): + return build_starlette_request(arg.scope, arg.body) + else: + return ServeRequest( + arg, + request_item.kwargs, + headers=request_item.metadata.http_headers, + method=request_item.metadata.http_method, + ) def _get_logger():