[serve] Small cleanup in HTTP proxy (#15028)

This commit is contained in:
Edward Oakes 2021-03-31 09:18:11 -05:00 committed by GitHub
parent b0ea947fa3
commit 12f5e5ab62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -22,19 +22,17 @@ class ServeStarletteEndpoint:
"""Wraps the given Serve endpoint in a Starlette endpoint.
Implements the ASGI protocol. Constructs a Starlette endpoint for use by
a Starlette app or Starlette Router which calls the given Serve endpoint
using the given Serve client.
a Starlette app or Starlette Router which calls the given Serve endpoint.
Usage:
route = starlette.routing.Route(
"/api",
ServeStarletteEndpoint(self.client, endpoint_tag),
ServeStarletteEndpoint(endpoint_tag),
methods=methods)
app = starlette.applications.Starlette(routes=[route])
"""
def __init__(self, client, endpoint_tag: EndpointTag):
self.client = client
def __init__(self, endpoint_tag: EndpointTag):
self.endpoint_tag = endpoint_tag
# This will be lazily populated when the first request comes in.
# TODO(edoakes): we should be able to construct the handle here, but
@ -87,12 +85,11 @@ class HTTPProxy:
>>> uvicorn.run(HTTPProxy(controller_name))
"""
def __init__(self, controller_name):
def __init__(self, controller_name: str):
# Set the controller name so that serve.connect() will connect to the
# controller instance this proxy is running in.
ray.serve.api._set_internal_replica_context(None, None,
controller_name, None)
self.client = ray.serve.connect()
controller = ray.get_actor(controller_name)
@ -110,15 +107,13 @@ class HTTPProxy:
description="The number of HTTP requests processed.",
tag_keys=("route", ))
def _update_route_table(self, route_table):
def _update_route_table(self, route_table: Dict[str, List[str]]):
logger.debug(f"HTTP Proxy: Get updated route table: {route_table}.")
self.route_table = route_table
routes = [
starlette.routing.Route(
route,
ServeStarletteEndpoint(self.client, endpoint_tag),
methods=methods)
route, ServeStarletteEndpoint(endpoint_tag), methods=methods)
for route, (endpoint_tag, methods) in route_table.items()
if not self._is_headless(route)
]
@ -162,13 +157,12 @@ class HTTPProxy:
@ray.remote
class HTTPProxyActor:
async def __init__(
self,
host,
port,
controller_name,
http_middlewares: List[
"starlette.middleware.Middleware"] = []): # noqa: F821
def __init__(self,
host: str,
port: int,
controller_name: str,
http_middlewares: List[
"starlette.middleware.Middleware"] = []): # noqa: F821
self.host = host
self.port = port