mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Serve] Use ServeHandle in HTTP proxy (#12523)
This commit is contained in:
parent
30c22921d9
commit
cc1c2c3dc9
10 changed files with 44 additions and 62 deletions
|
@ -7,8 +7,6 @@ from uuid import UUID
|
|||
import threading
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||
|
@ -74,7 +72,6 @@ class ThreadProxiedRouter:
|
|||
request_metadata = RequestMetadata(
|
||||
get_random_letters(10), # Used for debugging.
|
||||
endpoint_name,
|
||||
TaskContext.Python,
|
||||
call_method=handle_options.method_name,
|
||||
shard_key=handle_options.shard_key,
|
||||
http_method=handle_options.http_method,
|
||||
|
|
|
@ -106,8 +106,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
|||
class RayServeWrappedReplica(object):
|
||||
def __init__(self, backend_tag, replica_tag, init_args,
|
||||
backend_config: BackendConfig, controller_name: str):
|
||||
# Set the controller name so that serve.connect() will connect to
|
||||
# the instance that this backend is running in.
|
||||
# Set the controller name so that serve.connect() in the user's
|
||||
# backend code will connect to the instance that this backend is
|
||||
# running in.
|
||||
ray.serve.api._set_internal_controller_name(controller_name)
|
||||
if is_function:
|
||||
_callable = func_or_class
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from enum import IntEnum
|
||||
|
||||
|
||||
class TaskContext(IntEnum):
|
||||
"""TaskContext constants for queue.enqueue method"""
|
||||
Web = 1
|
||||
Python = 2
|
|
@ -100,7 +100,7 @@ class RayServeHandle:
|
|||
class RayServeSyncHandle(RayServeHandle):
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
"""Issue an asynchronous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
|
@ -110,7 +110,9 @@ class RayServeSyncHandle(RayServeHandle):
|
|||
Args:
|
||||
request_data(dict, Any): If it's a dictionary, the data will be
|
||||
available in ``request.json()`` or ``request.form()``.
|
||||
Otherwise, it will be available in ``request.data``.
|
||||
If it's a Starlette Request object, it will be passed in to the
|
||||
backend directly, unmodified. Otherwise, the data will be
|
||||
available in ``request.data``.
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.args``.
|
||||
"""
|
||||
|
|
|
@ -8,12 +8,12 @@ import starlette.responses
|
|||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.serve.constants import LongPollKey
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.util import metrics
|
||||
from ray.serve.utils import _get_logger, get_random_letters
|
||||
from ray.serve.http_util import Response
|
||||
from ray.serve.utils import _get_logger
|
||||
from ray.serve.http_util import Response, build_starlette_request
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.router import Router, RequestMetadata
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.handle import DEFAULT
|
||||
|
||||
logger = _get_logger()
|
||||
|
||||
|
@ -22,10 +22,15 @@ class HTTPProxy:
|
|||
"""This class is meant to be instantiated and run by an ASGI HTTP server.
|
||||
|
||||
>>> import uvicorn
|
||||
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
|
||||
>>> uvicorn.run(HTTPProxy(controller_name))
|
||||
"""
|
||||
|
||||
def __init__(self, controller_name):
|
||||
# Set the controller name so that serve.connect() will connect to the
|
||||
# controller instance this proxy is running in.
|
||||
ray.serve.api._set_internal_controller_name(controller_name)
|
||||
self.client = ray.serve.connect()
|
||||
|
||||
controller = ray.get_actor(controller_name)
|
||||
self.route_table = {} # Should be updated via long polling.
|
||||
self.router = Router(controller)
|
||||
|
@ -113,18 +118,19 @@ class HTTPProxy:
|
|||
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
||||
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
request_metadata = RequestMetadata(
|
||||
get_random_letters(10), # Used for debugging.
|
||||
endpoint_name,
|
||||
TaskContext.Web,
|
||||
http_method=scope["method"].upper(),
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
|
||||
)
|
||||
|
||||
ref = await self.router.assign_request(request_metadata, scope,
|
||||
http_body_bytes)
|
||||
result = await ref
|
||||
handle = self.client.get_handle(
|
||||
endpoint_name, sync=False).options(
|
||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
||||
DEFAULT.VALUE),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
|
||||
DEFAULT.VALUE),
|
||||
http_method=scope["method"].upper(),
|
||||
http_headers=headers)
|
||||
|
||||
request = build_starlette_request(scope, http_body_bytes)
|
||||
object_ref = await handle.remote(request)
|
||||
result = await object_ref
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
|
|
|
@ -10,7 +10,6 @@ from ray.serve.exceptions import RayServeException
|
|||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.serve.constants import LongPollKey
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
|
||||
|
@ -23,7 +22,6 @@ REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
|
|||
class RequestMetadata:
|
||||
request_id: str
|
||||
endpoint: str
|
||||
request_context: TaskContext
|
||||
|
||||
call_method: str = "__call__"
|
||||
shard_key: Optional[str] = None
|
||||
|
@ -42,7 +40,6 @@ class RequestMetadata:
|
|||
class Query:
|
||||
args: List[Any]
|
||||
kwargs: Dict[Any, Any]
|
||||
context: TaskContext
|
||||
metadata: RequestMetadata
|
||||
|
||||
# Fields used by backend worker to perform timing measurement.
|
||||
|
@ -242,7 +239,6 @@ class Router:
|
|||
query = Query(
|
||||
args=list(request_args),
|
||||
kwargs=request_kwargs,
|
||||
context=request_meta.request_context,
|
||||
metadata=request_meta,
|
||||
)
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def _shared_serve_instance():
|
||||
os.environ["SERVE_LOG_DEBUG"] = "1" # Turns on debug log for tests
|
||||
os.environ["SERVE_LOG_DEBUG"] = "1" # Uncomment to turn on debug log
|
||||
# Overriding task_retry_delay_ms to relaunch actors more quickly
|
||||
ray.init(
|
||||
num_cpus=36,
|
||||
|
|
|
@ -5,7 +5,6 @@ import numpy as np
|
|||
|
||||
import ray
|
||||
from ray import serve
|
||||
import ray.serve.context as context
|
||||
from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
|
||||
from ray.serve.controller import TrafficPolicy
|
||||
from ray.serve.router import Router, RequestMetadata
|
||||
|
@ -67,10 +66,7 @@ async def add_servable_to_router(servable, router, controller_name, **kwargs):
|
|||
|
||||
def make_request_param(call_method="__call__"):
|
||||
return RequestMetadata(
|
||||
get_random_letters(10),
|
||||
"endpoint",
|
||||
context.TaskContext.Python,
|
||||
call_method=call_method)
|
||||
get_random_letters(10), "endpoint", call_method=call_method)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -7,7 +7,6 @@ from collections import defaultdict
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.controller import TrafficPolicy
|
||||
|
@ -212,9 +211,7 @@ async def test_replica_set(ray_instance):
|
|||
|
||||
# Send two queries. They should go through the router but blocked by signal
|
||||
# actors.
|
||||
query = Query([], {}, TaskContext.Python,
|
||||
RequestMetadata("request-id", "endpoint",
|
||||
TaskContext.Python))
|
||||
query = Query([], {}, RequestMetadata("request-id", "endpoint"))
|
||||
first_ref = await rs.assign_replica(query)
|
||||
second_ref = await rs.assign_replica(query)
|
||||
|
||||
|
|
|
@ -12,15 +12,13 @@ import os
|
|||
from ray.serve.exceptions import RayServeException
|
||||
from collections import UserDict
|
||||
|
||||
import starlette.requests
|
||||
import requests
|
||||
import numpy as np
|
||||
import pydantic
|
||||
import starlette.requests
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.http_util import build_starlette_request
|
||||
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
@ -85,23 +83,19 @@ class ServeRequest:
|
|||
|
||||
|
||||
def parse_request_item(request_item):
|
||||
if request_item.metadata.request_context == TaskContext.Web:
|
||||
asgi_scope, body_bytes = request_item.args
|
||||
return build_starlette_request(asgi_scope, body_bytes)
|
||||
else:
|
||||
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
||||
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
||||
|
||||
# If the input data from handle is web request, we don't need to wrap
|
||||
# it in ServeRequest.
|
||||
if isinstance(arg, starlette.requests.Request):
|
||||
return arg
|
||||
# If the input data from handle is web request, we don't need to wrap
|
||||
# it in ServeRequest.
|
||||
if isinstance(arg, starlette.requests.Request):
|
||||
return arg
|
||||
|
||||
return ServeRequest(
|
||||
arg,
|
||||
request_item.kwargs,
|
||||
headers=request_item.metadata.http_headers,
|
||||
method=request_item.metadata.http_method,
|
||||
)
|
||||
return ServeRequest(
|
||||
arg,
|
||||
request_item.kwargs,
|
||||
headers=request_item.metadata.http_headers,
|
||||
method=request_item.metadata.http_method,
|
||||
)
|
||||
|
||||
|
||||
def _get_logger():
|
||||
|
|
Loading…
Add table
Reference in a new issue