[Serve] Use ServeHandle in HTTP proxy (#12523)

This commit is contained in:
architkulkarni 2020-12-28 18:33:42 -08:00 committed by GitHub
parent 30c22921d9
commit cc1c2c3dc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 44 additions and 62 deletions

View file

@ -7,8 +7,6 @@ from uuid import UUID
import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from ray.serve.context import TaskContext
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
@ -74,7 +72,6 @@ class ThreadProxiedRouter:
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
endpoint_name,
TaskContext.Python,
call_method=handle_options.method_name,
shard_key=handle_options.shard_key,
http_method=handle_options.http_method,

View file

@ -106,8 +106,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
class RayServeWrappedReplica(object):
def __init__(self, backend_tag, replica_tag, init_args,
backend_config: BackendConfig, controller_name: str):
# Set the controller name so that serve.connect() will connect to
# the instance that this backend is running in.
# Set the controller name so that serve.connect() in the user's
# backend code will connect to the instance that this backend is
# running in.
ray.serve.api._set_internal_controller_name(controller_name)
if is_function:
_callable = func_or_class

View file

@ -1,7 +0,0 @@
from enum import IntEnum
class TaskContext(IntEnum):
"""TaskContext constants for queue.enqueue method"""
Web = 1
Python = 2

View file

@ -100,7 +100,7 @@ class RayServeHandle:
class RayServeSyncHandle(RayServeHandle):
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
"""Issue an asynchronous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get (or ``await object_ref``), respectively.
@ -110,7 +110,9 @@ class RayServeSyncHandle(RayServeHandle):
Args:
request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``.
Otherwise, it will be available in ``request.data``.
If it's a Starlette Request object, it will be passed in to the
backend directly, unmodified. Otherwise, the data will be
available in ``request.data``.
``**kwargs``: All keyword arguments will be available in
``request.args``.
"""

View file

@ -8,12 +8,12 @@ import starlette.responses
import ray
from ray.exceptions import RayTaskError
from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.util import metrics
from ray.serve.utils import _get_logger, get_random_letters
from ray.serve.http_util import Response
from ray.serve.utils import _get_logger
from ray.serve.http_util import Response, build_starlette_request
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Router, RequestMetadata
from ray.serve.router import Router
from ray.serve.handle import DEFAULT
logger = _get_logger()
@ -22,10 +22,15 @@ class HTTPProxy:
"""This class is meant to be instantiated and run by an ASGI HTTP server.
>>> import uvicorn
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
>>> uvicorn.run(HTTPProxy(controller_name))
"""
def __init__(self, controller_name):
# Set the controller name so that serve.connect() will connect to the
# controller instance this proxy is running in.
ray.serve.api._set_internal_controller_name(controller_name)
self.client = ray.serve.connect()
controller = ray.get_actor(controller_name)
self.route_table = {} # Should be updated via long polling.
self.router = Router(controller)
@ -113,18 +118,19 @@ class HTTPProxy:
http_body_bytes = await self.receive_http_body(scope, receive, send)
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
endpoint_name,
TaskContext.Web,
http_method=scope["method"].upper(),
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
ref = await self.router.assign_request(request_metadata, scope,
http_body_bytes)
result = await ref
handle = self.client.get_handle(
endpoint_name, sync=False).options(
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
DEFAULT.VALUE),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
DEFAULT.VALUE),
http_method=scope["method"].upper(),
http_headers=headers)
request = build_starlette_request(scope, http_body_bytes)
object_ref = await handle.remote(request)
result = await object_ref
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)

View file

@ -10,7 +10,6 @@ from ray.serve.exceptions import RayServeException
import ray
from ray.actor import ActorHandle
from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
@ -23,7 +22,6 @@ REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
class RequestMetadata:
request_id: str
endpoint: str
request_context: TaskContext
call_method: str = "__call__"
shard_key: Optional[str] = None
@ -42,7 +40,6 @@ class RequestMetadata:
class Query:
args: List[Any]
kwargs: Dict[Any, Any]
context: TaskContext
metadata: RequestMetadata
# Fields used by backend worker to perform timing measurement.
@ -242,7 +239,6 @@ class Router:
query = Query(
args=list(request_args),
kwargs=request_kwargs,
context=request_meta.request_context,
metadata=request_meta,
)

View file

@ -15,7 +15,7 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
@pytest.fixture(scope="session")
def _shared_serve_instance():
os.environ["SERVE_LOG_DEBUG"] = "1" # Turns on debug log for tests
os.environ["SERVE_LOG_DEBUG"] = "1" # Uncomment to turn on debug log
# Overriding task_retry_delay_ms to relaunch actors more quickly
ray.init(
num_cpus=36,

View file

@ -5,7 +5,6 @@ import numpy as np
import ray
from ray import serve
import ray.serve.context as context
from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
from ray.serve.controller import TrafficPolicy
from ray.serve.router import Router, RequestMetadata
@ -67,10 +66,7 @@ async def add_servable_to_router(servable, router, controller_name, **kwargs):
def make_request_param(call_method="__call__"):
return RequestMetadata(
get_random_letters(10),
"endpoint",
context.TaskContext.Python,
call_method=call_method)
get_random_letters(10), "endpoint", call_method=call_method)
@pytest.fixture

View file

@ -7,7 +7,6 @@ from collections import defaultdict
import os
import pytest
from ray.serve.context import TaskContext
import ray
from ray.serve.controller import TrafficPolicy
@ -212,9 +211,7 @@ async def test_replica_set(ray_instance):
# Send two queries. They should go through the router but blocked by signal
# actors.
query = Query([], {}, TaskContext.Python,
RequestMetadata("request-id", "endpoint",
TaskContext.Python))
query = Query([], {}, RequestMetadata("request-id", "endpoint"))
first_ref = await rs.assign_replica(query)
second_ref = await rs.assign_replica(query)

View file

@ -12,15 +12,13 @@ import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
import starlette.requests
import requests
import numpy as np
import pydantic
import starlette.requests
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.context import TaskContext
from ray.serve.http_util import build_starlette_request
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
@ -85,23 +83,19 @@ class ServeRequest:
def parse_request_item(request_item):
if request_item.metadata.request_context == TaskContext.Web:
asgi_scope, body_bytes = request_item.args
return build_starlette_request(asgi_scope, body_bytes)
else:
arg = request_item.args[0] if len(request_item.args) == 1 else None
arg = request_item.args[0] if len(request_item.args) == 1 else None
# If the input data from handle is web request, we don't need to wrap
# it in ServeRequest.
if isinstance(arg, starlette.requests.Request):
return arg
# If the input data from handle is web request, we don't need to wrap
# it in ServeRequest.
if isinstance(arg, starlette.requests.Request):
return arg
return ServeRequest(
arg,
request_item.kwargs,
headers=request_item.metadata.http_headers,
method=request_item.metadata.http_method,
)
return ServeRequest(
arg,
request_item.kwargs,
headers=request_item.metadata.http_headers,
method=request_item.metadata.http_method,
)
def _get_logger():