diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py
index 838205a8e..05ce4576e 100644
--- a/python/ray/serve/api.py
+++ b/python/ray/serve/api.py
@@ -7,8 +7,6 @@ from uuid import UUID
 import threading
 from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union

-from ray.serve.context import TaskContext
-
 import ray
 from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
                                  SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
@@ -74,7 +72,6 @@ class ThreadProxiedRouter:
         request_metadata = RequestMetadata(
             get_random_letters(10),  # Used for debugging.
             endpoint_name,
-            TaskContext.Python,
             call_method=handle_options.method_name,
             shard_key=handle_options.shard_key,
             http_method=handle_options.http_method,
diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py
index 54ba942c5..93904ce94 100644
--- a/python/ray/serve/backend_worker.py
+++ b/python/ray/serve/backend_worker.py
@@ -106,8 +106,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
     class RayServeWrappedReplica(object):
         def __init__(self, backend_tag, replica_tag, init_args,
                      backend_config: BackendConfig, controller_name: str):
-            # Set the controller name so that serve.connect() will connect to
-            # the instance that this backend is running in.
+            # Set the controller name so that serve.connect() in the user's
+            # backend code will connect to the instance that this backend is
+            # running in.
             ray.serve.api._set_internal_controller_name(controller_name)
             if is_function:
                 _callable = func_or_class
diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py
deleted file mode 100644
index 5d21c40e7..000000000
--- a/python/ray/serve/context.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from enum import IntEnum
-
-
-class TaskContext(IntEnum):
-    """TaskContext constants for queue.enqueue method"""
-    Web = 1
-    Python = 2
diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py
index 381c8b833..80ea835d3 100644
--- a/python/ray/serve/handle.py
+++ b/python/ray/serve/handle.py
@@ -100,7 +100,7 @@ class RayServeHandle:
 class RayServeSyncHandle(RayServeHandle):
     def remote(self, request_data: Optional[Union[Dict, Any]] = None,
                **kwargs):
-        """Issue an asynchrounous request to the endpoint.
+        """Issue an asynchronous request to the endpoint.

         Returns a Ray ObjectRef whose results can be waited for or retrieved
         using ray.wait or ray.get (or ``await object_ref``), respectively.
@@ -110,7 +110,9 @@ class RayServeSyncHandle(RayServeHandle):
         Args:
             request_data(dict, Any): If it's a dictionary, the data will be
                 available in ``request.json()`` or ``request.form()``.
-                Otherwise, it will be available in ``request.data``.
+                If it's a Starlette Request object, it will be passed in to the
+                backend directly, unmodified. Otherwise, the data will be
+                available in ``request.data``.
             ``**kwargs``: All keyword arguments will be available in
                 ``request.args``.
        """
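For context, here is a minimal sketch of how the updated `RayServeSyncHandle.remote` docstring plays out from user code, assuming the 1.x client API (`serve.start()`, `create_backend`, `create_endpoint`); the `echo` backend and the names "echo"/"echo:v1" are made up for illustration:

```python
import ray
from ray import serve

# Backend used only for this sketch; the endpoint/backend names are made up.
def echo(request):
    # A dict passed to handle.remote() is exposed via request.json()/request.form();
    # a Starlette Request would instead reach this function unmodified, and any
    # other payload would land on request.data.
    return request.json()

ray.init()
client = serve.start()
client.create_backend("echo:v1", echo)
client.create_endpoint("echo", backend="echo:v1", route="/echo")

handle = client.get_handle("echo")  # RayServeSyncHandle
object_ref = handle.remote({"name": "serve"})
assert ray.get(object_ref) == {"name": "serve"}
```

Passing a Starlette `Request` instead of a dict would now reach `echo` unmodified, which is the pass-through case the HTTP proxy relies on below.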
""" diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 3215f3578..77448860a 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -8,12 +8,12 @@ import starlette.responses import ray from ray.exceptions import RayTaskError from ray.serve.constants import LongPollKey -from ray.serve.context import TaskContext from ray.util import metrics -from ray.serve.utils import _get_logger, get_random_letters -from ray.serve.http_util import Response +from ray.serve.utils import _get_logger +from ray.serve.http_util import Response, build_starlette_request from ray.serve.long_poll import LongPollAsyncClient -from ray.serve.router import Router, RequestMetadata +from ray.serve.router import Router +from ray.serve.handle import DEFAULT logger = _get_logger() @@ -22,10 +22,15 @@ class HTTPProxy: """This class is meant to be instantiated and run by an ASGI HTTP server. >>> import uvicorn - >>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle)) + >>> uvicorn.run(HTTPProxy(controller_name)) """ def __init__(self, controller_name): + # Set the controller name so that serve.connect() will connect to the + # controller instance this proxy is running in. + ray.serve.api._set_internal_controller_name(controller_name) + self.client = ray.serve.connect() + controller = ray.get_actor(controller_name) self.route_table = {} # Should be updated via long polling. self.router = Router(controller) @@ -113,18 +118,19 @@ class HTTPProxy: http_body_bytes = await self.receive_http_body(scope, receive, send) headers = {k.decode(): v.decode() for k, v in scope["headers"]} - request_metadata = RequestMetadata( - get_random_letters(10), # Used for debugging. - endpoint_name, - TaskContext.Web, - http_method=scope["method"].upper(), - call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"), - shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None), - ) - ref = await self.router.assign_request(request_metadata, scope, - http_body_bytes) - result = await ref + handle = self.client.get_handle( + endpoint_name, sync=False).options( + method_name=headers.get("X-SERVE-CALL-METHOD".lower(), + DEFAULT.VALUE), + shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), + DEFAULT.VALUE), + http_method=scope["method"].upper(), + http_headers=headers) + + request = build_starlette_request(scope, http_body_bytes) + object_ref = await handle.remote(request) + result = await object_ref if isinstance(result, RayTaskError): error_message = "Task Error. Traceback: {}.".format(result) diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index b74beb812..8d0f578f7 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -10,7 +10,6 @@ from ray.serve.exceptions import RayServeException import ray from ray.actor import ActorHandle from ray.serve.constants import LongPollKey -from ray.serve.context import TaskContext from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy from ray.serve.long_poll import LongPollAsyncClient from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta @@ -23,7 +22,6 @@ REPORT_QUEUE_LENGTH_PERIOD_S = 1.0 class RequestMetadata: request_id: str endpoint: str - request_context: TaskContext call_method: str = "__call__" shard_key: Optional[str] = None @@ -42,7 +40,6 @@ class RequestMetadata: class Query: args: List[Any] kwargs: Dict[Any, Any] - context: TaskContext metadata: RequestMetadata # Fields used by backend worker to perform timing measurement. 
diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py
index 87b5659e6..6c78cfabf 100644
--- a/python/ray/serve/tests/conftest.py
+++ b/python/ray/serve/tests/conftest.py
@@ -15,7 +15,7 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:

 @pytest.fixture(scope="session")
 def _shared_serve_instance():
-    os.environ["SERVE_LOG_DEBUG"] = "1"  # Turns on debug log for tests
+    os.environ["SERVE_LOG_DEBUG"] = "1"  # Uncomment to turn on debug log
     # Overriding task_retry_delay_ms to relaunch actors more quickly
     ray.init(
         num_cpus=36,
diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py
index 2bcf035c7..ee175a4d1 100644
--- a/python/ray/serve/tests/test_backend_worker.py
+++ b/python/ray/serve/tests/test_backend_worker.py
@@ -5,7 +5,6 @@ import numpy as np

 import ray
 from ray import serve
-import ray.serve.context as context
 from ray.serve.backend_worker import create_backend_replica, wrap_to_ray_error
 from ray.serve.controller import TrafficPolicy
 from ray.serve.router import Router, RequestMetadata
@@ -67,10 +66,7 @@ async def add_servable_to_router(servable, router, controller_name, **kwargs):

 def make_request_param(call_method="__call__"):
     return RequestMetadata(
-        get_random_letters(10),
-        "endpoint",
-        context.TaskContext.Python,
-        call_method=call_method)
+        get_random_letters(10), "endpoint", call_method=call_method)


 @pytest.fixture
diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py
index 678f5ea1b..69ffeb2e0 100644
--- a/python/ray/serve/tests/test_router.py
+++ b/python/ray/serve/tests/test_router.py
@@ -7,7 +7,6 @@ from collections import defaultdict
 import os
 import pytest

-from ray.serve.context import TaskContext
 import ray
 from ray.serve.controller import TrafficPolicy
@@ -212,9 +211,7 @@ async def test_replica_set(ray_instance):

     # Send two queries. They should go through the router but blocked by signal
     # actors.
-    query = Query([], {}, TaskContext.Python,
-                  RequestMetadata("request-id", "endpoint",
-                                  TaskContext.Python))
+    query = Query([], {}, RequestMetadata("request-id", "endpoint"))
     first_ref = await rs.assign_replica(query)
     second_ref = await rs.assign_replica(query)
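The test changes all reduce to dropping the `TaskContext` argument. A hedged example of the post-change constructors, mirroring `make_request_param()` and the `Query` built in `test_replica_set` (defaults are assumed for the unspecified timing fields on `Query`):

```python
from ray.serve.router import Query, RequestMetadata
from ray.serve.utils import get_random_letters

# Equivalent of make_request_param() after this change: no TaskContext argument.
meta = RequestMetadata(get_random_letters(10), "endpoint", call_method="__call__")

# Equivalent of the Query built in test_replica_set: args, kwargs, metadata.
query = Query([], {}, meta)
```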
diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py
index 45f7e27fe..a7d271592 100644
--- a/python/ray/serve/utils.py
+++ b/python/ray/serve/utils.py
@@ -12,15 +12,13 @@ import os
 from ray.serve.exceptions import RayServeException
 from collections import UserDict

+import starlette.requests
 import requests
 import numpy as np
 import pydantic
-import starlette.requests

 import ray
 from ray.serve.constants import HTTP_PROXY_TIMEOUT
-from ray.serve.context import TaskContext
-from ray.serve.http_util import build_starlette_request

 ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
@@ -85,23 +83,19 @@ class ServeRequest:


 def parse_request_item(request_item):
-    if request_item.metadata.request_context == TaskContext.Web:
-        asgi_scope, body_bytes = request_item.args
-        return build_starlette_request(asgi_scope, body_bytes)
-    else:
-        arg = request_item.args[0] if len(request_item.args) == 1 else None
+    arg = request_item.args[0] if len(request_item.args) == 1 else None

-        # If the input data from handle is web request, we don't need to wrap
-        # it in ServeRequest.
-        if isinstance(arg, starlette.requests.Request):
-            return arg
+    # If the input data from handle is web request, we don't need to wrap
+    # it in ServeRequest.
+    if isinstance(arg, starlette.requests.Request):
+        return arg

-        return ServeRequest(
-            arg,
-            request_item.kwargs,
-            headers=request_item.metadata.http_headers,
-            method=request_item.metadata.http_method,
-        )
+    return ServeRequest(
+        arg,
+        request_item.kwargs,
+        headers=request_item.metadata.http_headers,
+        method=request_item.metadata.http_method,
+    )


 def _get_logger():
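Finally, a rough usage sketch of the simplified `parse_request_item()` above, with hand-built `Query`/`RequestMetadata` stand-ins and a minimal ASGI scope (the surrounding setup is illustrative, not taken from the test suite):

```python
import starlette.requests
from ray.serve.router import Query, RequestMetadata
from ray.serve.utils import parse_request_item

meta = RequestMetadata("request-id", "endpoint")

# Plain Python payloads are wrapped in a ServeRequest carrying the metadata's
# HTTP method and headers.
wrapped = parse_request_item(Query(["some payload"], {}, meta))
print(type(wrapped).__name__)  # ServeRequest

# A Starlette Request sent through a handle is returned unmodified.
scope = {"type": "http", "method": "GET", "path": "/",
         "headers": [], "query_string": b""}
request = starlette.requests.Request(scope)
assert parse_request_item(Query([request], {}, meta)) is request
```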