mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[serve] Defer building starlette request to backend replica (#15169)
This commit is contained in:
parent
195f818224
commit
eed34092f2
4 changed files with 42 additions and 24 deletions
|
@ -5,6 +5,7 @@
|
|||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
|
@ -35,8 +36,9 @@ async def timeit(name, fn, multiplier=1):
|
|||
|
||||
|
||||
async def fetch(session, data):
|
||||
async with session.get("http://127.0.0.1:8000/api", data=data) as response:
|
||||
await response.text()
|
||||
async with session.get("http://localhost:8000/api", data=data) as response:
|
||||
response = await response.text()
|
||||
assert response == "ok", response
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
@ -74,9 +76,11 @@ async def trial(intermediate_handles, num_replicas, max_batch_size,
|
|||
return await self.handle.remote()
|
||||
|
||||
ForwardActor.deploy()
|
||||
routes = requests.get("http://localhost:8000/-/routes").json()
|
||||
assert "/api" in routes, routes
|
||||
|
||||
@serve.deployment(
|
||||
deployment_name,
|
||||
name=deployment_name,
|
||||
num_replicas=num_replicas,
|
||||
max_concurrent_queries=max_concurrent_queries)
|
||||
class Backend:
|
||||
|
@ -91,6 +95,8 @@ async def trial(intermediate_handles, num_replicas, max_batch_size,
|
|||
return b"ok"
|
||||
|
||||
Backend.deploy()
|
||||
routes = requests.get("http://localhost:8000/-/routes").json()
|
||||
assert f"/{deployment_name}" in routes, routes
|
||||
|
||||
if data_size == "small":
|
||||
data = None
|
||||
|
|
|
@ -14,7 +14,7 @@ from ray.serve.constants import WILDCARD_PATH_SUFFIX
|
|||
from ray.serve.long_poll import LongPollNamespace
|
||||
from ray.util import metrics
|
||||
from ray.serve.utils import logger
|
||||
from ray.serve.http_util import Response, build_starlette_request
|
||||
from ray.serve.http_util import HTTPRequestWrapper, Response
|
||||
from ray.serve.long_poll import LongPollClient
|
||||
from ray.serve.handle import DEFAULT
|
||||
|
||||
|
@ -52,12 +52,18 @@ class ServeStarletteEndpoint:
|
|||
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
|
||||
# scope["router"] and scope["endpoint"] contain references to a router
|
||||
# and endpoint object, respectively, which each in turn contain a
|
||||
# reference to the Serve client, which cannot be serialized.
|
||||
# The solution is to delete these from scope, as they will not be used.
|
||||
del scope["router"]
|
||||
del scope["endpoint"]
|
||||
|
||||
# Modify the path and root path so that reverse lookups and redirection
|
||||
# work as expected. We do this here instead of in replicas so it can be
|
||||
# changed without restarting the replicas.
|
||||
scope["path"] = scope["path"].replace(self.path_prefix, "", 1)
|
||||
scope["root_path"] = self.path_prefix
|
||||
starlette_request = build_starlette_request(scope, http_body_bytes)
|
||||
handle = self.handle.options(
|
||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
||||
DEFAULT.VALUE),
|
||||
|
@ -65,10 +71,15 @@ class ServeStarletteEndpoint:
|
|||
http_method=scope["method"].upper(),
|
||||
http_headers=headers)
|
||||
|
||||
# NOTE(edoakes): it's important that we defer building the starlette
|
||||
# request until it reaches the backend replica to avoid unnecessary
|
||||
# serialization cost, so we use a simple dataclass here.
|
||||
request = HTTPRequestWrapper(scope, http_body_bytes)
|
||||
|
||||
retries = 0
|
||||
backoff_time_s = 0.05
|
||||
while retries < MAX_ACTOR_FAILURE_RETRIES:
|
||||
object_ref = await handle.remote(starlette_request)
|
||||
object_ref = await handle.remote(request)
|
||||
try:
|
||||
result = await object_ref
|
||||
break
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from typing import Callable, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import starlette.requests
|
||||
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
|
||||
# Plain data container for an HTTP request: the ASGI scope plus the raw body.
# The HTTP endpoint constructs it as HTTPRequestWrapper(scope, http_body_bytes)
# and passes it to the handle in place of a Starlette request; the receiving
# side turns it back into a request with
# build_starlette_request(arg.scope, arg.body).
@dataclass
|
||||
class HTTPRequestWrapper:
|
||||
# ASGI scope mapping for the request (the endpoint reads keys such as
# "headers", "path", "root_path" and "method" from it).
scope: Dict[Any, Any]
|
||||
# Raw HTTP request body bytes.
body: bytes
|
||||
|
||||
|
||||
def build_starlette_request(scope, serialized_body: bytes):
|
||||
"""Build and return a Starlette Request from ASGI payload.
|
||||
|
||||
|
@ -35,16 +42,7 @@ def build_starlette_request(scope, serialized_body: bytes):
|
|||
"more_body": False
|
||||
}
|
||||
|
||||
# scope["router"] and scope["endpoint"] contain references to a router and
|
||||
# endpoint object, respectively, which each in turn contain a reference to
|
||||
# the Serve client, which cannot be serialized.
|
||||
# The solution is to delete these from scope, as they will not be used.
|
||||
# Per ASGI recommendation, copy scope before passing to child.
|
||||
child_scope = scope.copy()
|
||||
del child_scope["router"]
|
||||
del child_scope["endpoint"]
|
||||
|
||||
return starlette.requests.Request(child_scope, mock_receive)
|
||||
return starlette.requests.Request(scope, mock_receive)
|
||||
|
||||
|
||||
class Response:
|
||||
|
|
|
@ -19,6 +19,7 @@ import pydantic
|
|||
import ray
|
||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.http_util import build_starlette_request, HTTPRequestWrapper
|
||||
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
@ -89,13 +90,15 @@ def parse_request_item(request_item):
|
|||
# it in ServeRequest.
|
||||
if isinstance(arg, starlette.requests.Request):
|
||||
return arg
|
||||
|
||||
return ServeRequest(
|
||||
arg,
|
||||
request_item.kwargs,
|
||||
headers=request_item.metadata.http_headers,
|
||||
method=request_item.metadata.http_method,
|
||||
)
|
||||
elif isinstance(arg, HTTPRequestWrapper):
|
||||
return build_starlette_request(arg.scope, arg.body)
|
||||
else:
|
||||
return ServeRequest(
|
||||
arg,
|
||||
request_item.kwargs,
|
||||
headers=request_item.metadata.http_headers,
|
||||
method=request_item.metadata.http_method,
|
||||
)
|
||||
|
||||
|
||||
def _get_logger():
|
||||
|
|
Loading…
Add table
Reference in a new issue