[Serve] Use ServeHandle in HTTP proxy (#12523)

This commit is contained in:
architkulkarni 2020-12-28 18:33:42 -08:00 committed by GitHub
parent 30c22921d9
commit cc1c2c3dc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 44 additions and 62 deletions

View file

@ -7,8 +7,6 @@ from uuid import UUID
import threading import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from ray.serve.context import TaskContext
import ray import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT) SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
@ -74,7 +72,6 @@ class ThreadProxiedRouter:
request_metadata = RequestMetadata( request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging. get_random_letters(10), # Used for debugging.
endpoint_name, endpoint_name,
TaskContext.Python,
call_method=handle_options.method_name, call_method=handle_options.method_name,
shard_key=handle_options.shard_key, shard_key=handle_options.shard_key,
http_method=handle_options.http_method, http_method=handle_options.http_method,

View file

@ -106,8 +106,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
class RayServeWrappedReplica(object): class RayServeWrappedReplica(object):
def __init__(self, backend_tag, replica_tag, init_args, def __init__(self, backend_tag, replica_tag, init_args,
backend_config: BackendConfig, controller_name: str): backend_config: BackendConfig, controller_name: str):
# Set the controller name so that serve.connect() will connect to # Set the controller name so that serve.connect() in the user's
# the instance that this backend is running in. # backend code will connect to the instance that this backend is
# running in.
ray.serve.api._set_internal_controller_name(controller_name) ray.serve.api._set_internal_controller_name(controller_name)
if is_function: if is_function:
_callable = func_or_class _callable = func_or_class

View file

@ -1,7 +0,0 @@
from enum import IntEnum
# Tags a Serve request with the kind of caller that produced it. In the
# surrounding code the tag is stored on RequestMetadata.request_context and
# Query.context, and parse_request_item branches on it to decide how to
# interpret the query's positional arguments.
class TaskContext(IntEnum):
    """TaskContext constants for queue.enqueue method"""
    # Request built by the HTTP proxy; its args are the pair
    # (asgi_scope, body_bytes), which the replica turns into a Starlette
    # request.
    Web = 1
    # Request issued from Python code through a handle (the value used by
    # ThreadProxiedRouter and by the router tests).
    Python = 2

View file

@ -100,7 +100,7 @@ class RayServeHandle:
class RayServeSyncHandle(RayServeHandle): class RayServeSyncHandle(RayServeHandle):
def remote(self, request_data: Optional[Union[Dict, Any]] = None, def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs): **kwargs):
"""Issue an asynchrounous request to the endpoint. """Issue an asynchronous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get (or ``await object_ref``), respectively. using ray.wait or ray.get (or ``await object_ref``), respectively.
@ -110,7 +110,9 @@ class RayServeSyncHandle(RayServeHandle):
Args: Args:
request_data(dict, Any): If it's a dictionary, the data will be request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``. available in ``request.json()`` or ``request.form()``.
Otherwise, it will be available in ``request.data``. If it's a Starlette Request object, it will be passed in to the
backend directly, unmodified. Otherwise, the data will be
available in ``request.data``.
``**kwargs``: All keyword arguments will be available in ``**kwargs``: All keyword arguments will be available in
``request.args``. ``request.args``.
""" """

View file

@ -8,12 +8,12 @@ import starlette.responses
import ray import ray
from ray.exceptions import RayTaskError from ray.exceptions import RayTaskError
from ray.serve.constants import LongPollKey from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.util import metrics from ray.util import metrics
from ray.serve.utils import _get_logger, get_random_letters from ray.serve.utils import _get_logger
from ray.serve.http_util import Response from ray.serve.http_util import Response, build_starlette_request
from ray.serve.long_poll import LongPollAsyncClient from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Router, RequestMetadata from ray.serve.router import Router
from ray.serve.handle import DEFAULT
logger = _get_logger() logger = _get_logger()
@ -22,10 +22,15 @@ class HTTPProxy:
"""This class is meant to be instantiated and run by an ASGI HTTP server. """This class is meant to be instantiated and run by an ASGI HTTP server.
>>> import uvicorn >>> import uvicorn
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle)) >>> uvicorn.run(HTTPProxy(controller_name))
""" """
def __init__(self, controller_name): def __init__(self, controller_name):
# Set the controller name so that serve.connect() will connect to the
# controller instance this proxy is running in.
ray.serve.api._set_internal_controller_name(controller_name)
self.client = ray.serve.connect()
controller = ray.get_actor(controller_name) controller = ray.get_actor(controller_name)
self.route_table = {} # Should be updated via long polling. self.route_table = {} # Should be updated via long polling.
self.router = Router(controller) self.router = Router(controller)
@ -113,18 +118,19 @@ class HTTPProxy:
http_body_bytes = await self.receive_http_body(scope, receive, send) http_body_bytes = await self.receive_http_body(scope, receive, send)
headers = {k.decode(): v.decode() for k, v in scope["headers"]} headers = {k.decode(): v.decode() for k, v in scope["headers"]}
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
endpoint_name,
TaskContext.Web,
http_method=scope["method"].upper(),
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
ref = await self.router.assign_request(request_metadata, scope, handle = self.client.get_handle(
http_body_bytes) endpoint_name, sync=False).options(
result = await ref method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
DEFAULT.VALUE),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
DEFAULT.VALUE),
http_method=scope["method"].upper(),
http_headers=headers)
request = build_starlette_request(scope, http_body_bytes)
object_ref = await handle.remote(request)
result = await object_ref
if isinstance(result, RayTaskError): if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result) error_message = "Task Error. Traceback: {}.".format(result)

View file

@ -10,7 +10,6 @@ from ray.serve.exceptions import RayServeException
import ray import ray
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.serve.constants import LongPollKey from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
from ray.serve.long_poll import LongPollAsyncClient from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
@ -23,7 +22,6 @@ REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
class RequestMetadata: class RequestMetadata:
request_id: str request_id: str
endpoint: str endpoint: str
request_context: TaskContext
call_method: str = "__call__" call_method: str = "__call__"
shard_key: Optional[str] = None shard_key: Optional[str] = None
@ -42,7 +40,6 @@ class RequestMetadata:
class Query: class Query:
args: List[Any] args: List[Any]
kwargs: Dict[Any, Any] kwargs: Dict[Any, Any]
context: TaskContext
metadata: RequestMetadata metadata: RequestMetadata
# Fields used by backend worker to perform timing measurement. # Fields used by backend worker to perform timing measurement.
@ -242,7 +239,6 @@ class Router:
query = Query( query = Query(
args=list(request_args), args=list(request_args),
kwargs=request_kwargs, kwargs=request_kwargs,
context=request_meta.request_context,
metadata=request_meta, metadata=request_meta,
) )

View file

@ -15,7 +15,7 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def _shared_serve_instance(): def _shared_serve_instance():
os.environ["SERVE_LOG_DEBUG"] = "1" # Turns on debug log for tests os.environ["SERVE_LOG_DEBUG"] = "1" # Uncomment to turn on debug log
# Overriding task_retry_delay_ms to relaunch actors more quickly # Overriding task_retry_delay_ms to relaunch actors more quickly
ray.init( ray.init(
num_cpus=36, num_cpus=36,

View file

@ -5,7 +5,6 @@ import numpy as np
import ray import ray
from ray import serve from ray import serve
import ray.serve.context as context
from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
from ray.serve.controller import TrafficPolicy from ray.serve.controller import TrafficPolicy
from ray.serve.router import Router, RequestMetadata from ray.serve.router import Router, RequestMetadata
@ -67,10 +66,7 @@ async def add_servable_to_router(servable, router, controller_name, **kwargs):
def make_request_param(call_method="__call__"): def make_request_param(call_method="__call__"):
return RequestMetadata( return RequestMetadata(
get_random_letters(10), get_random_letters(10), "endpoint", call_method=call_method)
"endpoint",
context.TaskContext.Python,
call_method=call_method)
@pytest.fixture @pytest.fixture

View file

@ -7,7 +7,6 @@ from collections import defaultdict
import os import os
import pytest import pytest
from ray.serve.context import TaskContext
import ray import ray
from ray.serve.controller import TrafficPolicy from ray.serve.controller import TrafficPolicy
@ -212,9 +211,7 @@ async def test_replica_set(ray_instance):
# Send two queries. They should go through the router but blocked by signal # Send two queries. They should go through the router but blocked by signal
# actors. # actors.
query = Query([], {}, TaskContext.Python, query = Query([], {}, RequestMetadata("request-id", "endpoint"))
RequestMetadata("request-id", "endpoint",
TaskContext.Python))
first_ref = await rs.assign_replica(query) first_ref = await rs.assign_replica(query)
second_ref = await rs.assign_replica(query) second_ref = await rs.assign_replica(query)

View file

@ -12,15 +12,13 @@ import os
from ray.serve.exceptions import RayServeException from ray.serve.exceptions import RayServeException
from collections import UserDict from collections import UserDict
import starlette.requests
import requests import requests
import numpy as np import numpy as np
import pydantic import pydantic
import starlette.requests
import ray import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.context import TaskContext
from ray.serve.http_util import build_starlette_request
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
@ -85,10 +83,6 @@ class ServeRequest:
def parse_request_item(request_item): def parse_request_item(request_item):
if request_item.metadata.request_context == TaskContext.Web:
asgi_scope, body_bytes = request_item.args
return build_starlette_request(asgi_scope, body_bytes)
else:
arg = request_item.args[0] if len(request_item.args) == 1 else None arg = request_item.args[0] if len(request_item.args) == 1 else None
# If the input data from handle is web request, we don't need to wrap # If the input data from handle is web request, we don't need to wrap