mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[Serve] Use ServeHandle in HTTP proxy (#12523)
This commit is contained in:
parent
30c22921d9
commit
cc1c2c3dc9
10 changed files with 44 additions and 62 deletions
|
@ -7,8 +7,6 @@ from uuid import UUID
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from ray.serve.context import TaskContext
|
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||||
|
@ -74,7 +72,6 @@ class ThreadProxiedRouter:
|
||||||
request_metadata = RequestMetadata(
|
request_metadata = RequestMetadata(
|
||||||
get_random_letters(10), # Used for debugging.
|
get_random_letters(10), # Used for debugging.
|
||||||
endpoint_name,
|
endpoint_name,
|
||||||
TaskContext.Python,
|
|
||||||
call_method=handle_options.method_name,
|
call_method=handle_options.method_name,
|
||||||
shard_key=handle_options.shard_key,
|
shard_key=handle_options.shard_key,
|
||||||
http_method=handle_options.http_method,
|
http_method=handle_options.http_method,
|
||||||
|
|
|
@ -106,8 +106,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||||
class RayServeWrappedReplica(object):
|
class RayServeWrappedReplica(object):
|
||||||
def __init__(self, backend_tag, replica_tag, init_args,
|
def __init__(self, backend_tag, replica_tag, init_args,
|
||||||
backend_config: BackendConfig, controller_name: str):
|
backend_config: BackendConfig, controller_name: str):
|
||||||
# Set the controller name so that serve.connect() will connect to
|
# Set the controller name so that serve.connect() in the user's
|
||||||
# the instance that this backend is running in.
|
# backend code will connect to the instance that this backend is
|
||||||
|
# running in.
|
||||||
ray.serve.api._set_internal_controller_name(controller_name)
|
ray.serve.api._set_internal_controller_name(controller_name)
|
||||||
if is_function:
|
if is_function:
|
||||||
_callable = func_or_class
|
_callable = func_or_class
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
from enum import IntEnum
|
|
||||||
|
|
||||||
|
|
||||||
class TaskContext(IntEnum):
|
|
||||||
"""TaskContext constants for queue.enqueue method"""
|
|
||||||
Web = 1
|
|
||||||
Python = 2
|
|
|
@ -100,7 +100,7 @@ class RayServeHandle:
|
||||||
class RayServeSyncHandle(RayServeHandle):
|
class RayServeSyncHandle(RayServeHandle):
|
||||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Issue an asynchrounous request to the endpoint.
|
"""Issue an asynchronous request to the endpoint.
|
||||||
|
|
||||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||||
|
@ -110,7 +110,9 @@ class RayServeSyncHandle(RayServeHandle):
|
||||||
Args:
|
Args:
|
||||||
request_data(dict, Any): If it's a dictionary, the data will be
|
request_data(dict, Any): If it's a dictionary, the data will be
|
||||||
available in ``request.json()`` or ``request.form()``.
|
available in ``request.json()`` or ``request.form()``.
|
||||||
Otherwise, it will be available in ``request.data``.
|
If it's a Starlette Request object, it will be passed in to the
|
||||||
|
backend directly, unmodified. Otherwise, the data will be
|
||||||
|
available in ``request.data``.
|
||||||
``**kwargs``: All keyword arguments will be available in
|
``**kwargs``: All keyword arguments will be available in
|
||||||
``request.args``.
|
``request.args``.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -8,12 +8,12 @@ import starlette.responses
|
||||||
import ray
|
import ray
|
||||||
from ray.exceptions import RayTaskError
|
from ray.exceptions import RayTaskError
|
||||||
from ray.serve.constants import LongPollKey
|
from ray.serve.constants import LongPollKey
|
||||||
from ray.serve.context import TaskContext
|
|
||||||
from ray.util import metrics
|
from ray.util import metrics
|
||||||
from ray.serve.utils import _get_logger, get_random_letters
|
from ray.serve.utils import _get_logger
|
||||||
from ray.serve.http_util import Response
|
from ray.serve.http_util import Response, build_starlette_request
|
||||||
from ray.serve.long_poll import LongPollAsyncClient
|
from ray.serve.long_poll import LongPollAsyncClient
|
||||||
from ray.serve.router import Router, RequestMetadata
|
from ray.serve.router import Router
|
||||||
|
from ray.serve.handle import DEFAULT
|
||||||
|
|
||||||
logger = _get_logger()
|
logger = _get_logger()
|
||||||
|
|
||||||
|
@ -22,10 +22,15 @@ class HTTPProxy:
|
||||||
"""This class is meant to be instantiated and run by an ASGI HTTP server.
|
"""This class is meant to be instantiated and run by an ASGI HTTP server.
|
||||||
|
|
||||||
>>> import uvicorn
|
>>> import uvicorn
|
||||||
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
|
>>> uvicorn.run(HTTPProxy(controller_name))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, controller_name):
|
def __init__(self, controller_name):
|
||||||
|
# Set the controller name so that serve.connect() will connect to the
|
||||||
|
# controller instance this proxy is running in.
|
||||||
|
ray.serve.api._set_internal_controller_name(controller_name)
|
||||||
|
self.client = ray.serve.connect()
|
||||||
|
|
||||||
controller = ray.get_actor(controller_name)
|
controller = ray.get_actor(controller_name)
|
||||||
self.route_table = {} # Should be updated via long polling.
|
self.route_table = {} # Should be updated via long polling.
|
||||||
self.router = Router(controller)
|
self.router = Router(controller)
|
||||||
|
@ -113,18 +118,19 @@ class HTTPProxy:
|
||||||
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
||||||
|
|
||||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||||
request_metadata = RequestMetadata(
|
|
||||||
get_random_letters(10), # Used for debugging.
|
|
||||||
endpoint_name,
|
|
||||||
TaskContext.Web,
|
|
||||||
http_method=scope["method"].upper(),
|
|
||||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
|
|
||||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
|
|
||||||
)
|
|
||||||
|
|
||||||
ref = await self.router.assign_request(request_metadata, scope,
|
handle = self.client.get_handle(
|
||||||
http_body_bytes)
|
endpoint_name, sync=False).options(
|
||||||
result = await ref
|
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
||||||
|
DEFAULT.VALUE),
|
||||||
|
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
|
||||||
|
DEFAULT.VALUE),
|
||||||
|
http_method=scope["method"].upper(),
|
||||||
|
http_headers=headers)
|
||||||
|
|
||||||
|
request = build_starlette_request(scope, http_body_bytes)
|
||||||
|
object_ref = await handle.remote(request)
|
||||||
|
result = await object_ref
|
||||||
|
|
||||||
if isinstance(result, RayTaskError):
|
if isinstance(result, RayTaskError):
|
||||||
error_message = "Task Error. Traceback: {}.".format(result)
|
error_message = "Task Error. Traceback: {}.".format(result)
|
||||||
|
|
|
@ -10,7 +10,6 @@ from ray.serve.exceptions import RayServeException
|
||||||
import ray
|
import ray
|
||||||
from ray.actor import ActorHandle
|
from ray.actor import ActorHandle
|
||||||
from ray.serve.constants import LongPollKey
|
from ray.serve.constants import LongPollKey
|
||||||
from ray.serve.context import TaskContext
|
|
||||||
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
|
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
|
||||||
from ray.serve.long_poll import LongPollAsyncClient
|
from ray.serve.long_poll import LongPollAsyncClient
|
||||||
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
|
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
|
||||||
|
@ -23,7 +22,6 @@ REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
|
||||||
class RequestMetadata:
|
class RequestMetadata:
|
||||||
request_id: str
|
request_id: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
request_context: TaskContext
|
|
||||||
|
|
||||||
call_method: str = "__call__"
|
call_method: str = "__call__"
|
||||||
shard_key: Optional[str] = None
|
shard_key: Optional[str] = None
|
||||||
|
@ -42,7 +40,6 @@ class RequestMetadata:
|
||||||
class Query:
|
class Query:
|
||||||
args: List[Any]
|
args: List[Any]
|
||||||
kwargs: Dict[Any, Any]
|
kwargs: Dict[Any, Any]
|
||||||
context: TaskContext
|
|
||||||
metadata: RequestMetadata
|
metadata: RequestMetadata
|
||||||
|
|
||||||
# Fields used by backend worker to perform timing measurement.
|
# Fields used by backend worker to perform timing measurement.
|
||||||
|
@ -242,7 +239,6 @@ class Router:
|
||||||
query = Query(
|
query = Query(
|
||||||
args=list(request_args),
|
args=list(request_args),
|
||||||
kwargs=request_kwargs,
|
kwargs=request_kwargs,
|
||||||
context=request_meta.request_context,
|
|
||||||
metadata=request_meta,
|
metadata=request_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def _shared_serve_instance():
|
def _shared_serve_instance():
|
||||||
os.environ["SERVE_LOG_DEBUG"] = "1" # Turns on debug log for tests
|
os.environ["SERVE_LOG_DEBUG"] = "1" # Turns on debug log; comment out to disable
|
||||||
# Overriding task_retry_delay_ms to relaunch actors more quickly
|
# Overriding task_retry_delay_ms to relaunch actors more quickly
|
||||||
ray.init(
|
ray.init(
|
||||||
num_cpus=36,
|
num_cpus=36,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import numpy as np
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import serve
|
from ray import serve
|
||||||
import ray.serve.context as context
|
|
||||||
from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
|
from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
|
||||||
from ray.serve.controller import TrafficPolicy
|
from ray.serve.controller import TrafficPolicy
|
||||||
from ray.serve.router import Router, RequestMetadata
|
from ray.serve.router import Router, RequestMetadata
|
||||||
|
@ -67,10 +66,7 @@ async def add_servable_to_router(servable, router, controller_name, **kwargs):
|
||||||
|
|
||||||
def make_request_param(call_method="__call__"):
|
def make_request_param(call_method="__call__"):
|
||||||
return RequestMetadata(
|
return RequestMetadata(
|
||||||
get_random_letters(10),
|
get_random_letters(10), "endpoint", call_method=call_method)
|
||||||
"endpoint",
|
|
||||||
context.TaskContext.Python,
|
|
||||||
call_method=call_method)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -7,7 +7,6 @@ from collections import defaultdict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from ray.serve.context import TaskContext
|
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.serve.controller import TrafficPolicy
|
from ray.serve.controller import TrafficPolicy
|
||||||
|
@ -212,9 +211,7 @@ async def test_replica_set(ray_instance):
|
||||||
|
|
||||||
# Send two queries. They should go through the router but be blocked by
|
# Send two queries. They should go through the router but be blocked by
|
||||||
# signal actors.
|
# signal actors.
|
||||||
query = Query([], {}, TaskContext.Python,
|
query = Query([], {}, RequestMetadata("request-id", "endpoint"))
|
||||||
RequestMetadata("request-id", "endpoint",
|
|
||||||
TaskContext.Python))
|
|
||||||
first_ref = await rs.assign_replica(query)
|
first_ref = await rs.assign_replica(query)
|
||||||
second_ref = await rs.assign_replica(query)
|
second_ref = await rs.assign_replica(query)
|
||||||
|
|
||||||
|
|
|
@ -12,15 +12,13 @@ import os
|
||||||
from ray.serve.exceptions import RayServeException
|
from ray.serve.exceptions import RayServeException
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
|
|
||||||
|
import starlette.requests
|
||||||
import requests
|
import requests
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pydantic
|
import pydantic
|
||||||
import starlette.requests
|
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||||
from ray.serve.context import TaskContext
|
|
||||||
from ray.serve.http_util import build_starlette_request
|
|
||||||
|
|
||||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||||
|
|
||||||
|
@ -85,23 +83,19 @@ class ServeRequest:
|
||||||
|
|
||||||
|
|
||||||
def parse_request_item(request_item):
|
def parse_request_item(request_item):
|
||||||
if request_item.metadata.request_context == TaskContext.Web:
|
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
||||||
asgi_scope, body_bytes = request_item.args
|
|
||||||
return build_starlette_request(asgi_scope, body_bytes)
|
|
||||||
else:
|
|
||||||
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
|
||||||
|
|
||||||
# If the input data from the handle is a web request, we don't need to
|
# If the input data from the handle is a web request, we don't need to
|
||||||
# wrap it in ServeRequest.
|
# wrap it in ServeRequest.
|
||||||
if isinstance(arg, starlette.requests.Request):
|
if isinstance(arg, starlette.requests.Request):
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
return ServeRequest(
|
return ServeRequest(
|
||||||
arg,
|
arg,
|
||||||
request_item.kwargs,
|
request_item.kwargs,
|
||||||
headers=request_item.metadata.http_headers,
|
headers=request_item.metadata.http_headers,
|
||||||
method=request_item.metadata.http_method,
|
method=request_item.metadata.http_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_logger():
|
def _get_logger():
|
||||||
|
|
Loading…
Add table
Reference in a new issue