mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Serve] FastAPI Simple Class Based View (#14858)
This commit is contained in:
parent
803be1d968
commit
1fcca07856
5 changed files with 150 additions and 8 deletions
|
@ -24,7 +24,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle
|
|||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_current_node_resource_key, get_random_letters,
|
||||
logger, register_custom_serializers)
|
||||
logger, make_fastapi_class_based_view,
|
||||
register_custom_serializers)
|
||||
|
||||
import ray
|
||||
|
||||
|
@ -54,6 +55,7 @@ class ReplicaContext:
|
|||
backend_tag: BackendTag
|
||||
replica_tag: ReplicaTag
|
||||
_internal_controller_name: str
|
||||
servable_object: Callable
|
||||
|
||||
|
||||
def create_or_get_async_loop_in_thread():
|
||||
|
@ -68,10 +70,15 @@ def create_or_get_async_loop_in_thread():
|
|||
return _global_async_loop
|
||||
|
||||
|
||||
def _set_internal_replica_context(backend_tag, replica_tag, controller_name):
|
||||
def _set_internal_replica_context(
|
||||
backend_tag: BackendTag,
|
||||
replica_tag: ReplicaTag,
|
||||
controller_name: str,
|
||||
servable_object: Callable,
|
||||
):
|
||||
global _INTERNAL_REPLICA_CONTEXT
|
||||
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(backend_tag, replica_tag,
|
||||
controller_name)
|
||||
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
|
||||
backend_tag, replica_tag, controller_name, servable_object)
|
||||
|
||||
|
||||
def _ensure_connected(f: Callable) -> Callable:
|
||||
|
@ -1072,6 +1079,9 @@ def ingress(
|
|||
|
||||
if app is not None:
|
||||
cls._serve_asgi_app = app
|
||||
# Sometimes there are decorators on the methods. We want to fix
|
||||
# the fast api routes here.
|
||||
make_fastapi_class_based_view(app, cls)
|
||||
if path_prefix is not None:
|
||||
cls._serve_path_prefix = path_prefix
|
||||
|
||||
|
|
|
@ -60,11 +60,20 @@ def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
|
|||
# backend code will connect to the instance that this backend is
|
||||
# running in.
|
||||
ray.serve.api._set_internal_replica_context(
|
||||
backend_tag, replica_tag, controller_name)
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
controller_name,
|
||||
servable_object=None)
|
||||
if is_function:
|
||||
_callable = backend
|
||||
else:
|
||||
_callable = backend(*init_args)
|
||||
# Setting the context again to update the servable_object.
|
||||
ray.serve.api._set_internal_replica_context(
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
controller_name,
|
||||
servable_object=_callable)
|
||||
|
||||
assert controller_name, "Must provide a valid controller_name"
|
||||
controller_handle = ray.get_actor(controller_name)
|
||||
|
|
|
@ -91,7 +91,7 @@ class HTTPProxy:
|
|||
# Set the controller name so that serve.connect() will connect to the
|
||||
# controller instance this proxy is running in.
|
||||
ray.serve.api._set_internal_replica_context(None, None,
|
||||
controller_name)
|
||||
controller_name, None)
|
||||
self.client = ray.serve.connect()
|
||||
|
||||
controller = ray.get_actor(controller_name)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from fastapi import FastAPI
|
||||
import requests
|
||||
import pytest
|
||||
import inspect
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.utils import make_fastapi_class_based_view
|
||||
|
||||
|
||||
def test_fastapi_function(serve_instance):
|
||||
|
@ -45,6 +47,61 @@ def test_ingress_prefix(serve_instance):
|
|||
assert resp.json() == {"result": 100}
|
||||
|
||||
|
||||
def test_class_based_view(serve_instance):
|
||||
client = serve_instance
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/other")
|
||||
def hello():
|
||||
return "hello"
|
||||
|
||||
@serve.ingress(app)
|
||||
class A:
|
||||
def __init__(self):
|
||||
self.val = 1
|
||||
|
||||
@app.get("/calc/{i}")
|
||||
def b(self, i: int):
|
||||
return i + self.val
|
||||
|
||||
@app.post("/calc/{i}")
|
||||
def c(self, i: int):
|
||||
return i - self.val
|
||||
|
||||
client.deploy("f", A)
|
||||
resp = requests.get(f"http://localhost:8000/f/calc/41")
|
||||
assert resp.json() == 42
|
||||
resp = requests.post(f"http://localhost:8000/f/calc/41")
|
||||
assert resp.json() == 40
|
||||
resp = requests.get(f"http://localhost:8000/f/other")
|
||||
assert resp.json() == "hello"
|
||||
|
||||
|
||||
def test_make_fastapi_cbv_util():
|
||||
app = FastAPI()
|
||||
|
||||
class A:
|
||||
@app.get("/{i}")
|
||||
def b(self, i: int):
|
||||
pass
|
||||
|
||||
# before, "self" is treated as a query params
|
||||
assert app.routes[-1].endpoint == A.b
|
||||
assert app.routes[-1].dependant.query_params[0].name == "self"
|
||||
assert len(app.routes[-1].dependant.dependencies) == 0
|
||||
|
||||
make_fastapi_class_based_view(app, A)
|
||||
|
||||
# after, "self" is treated as a dependency instead of query params
|
||||
assert app.routes[-1].endpoint == A.b
|
||||
assert len(app.routes[-1].dependant.query_params) == 0
|
||||
assert len(app.routes[-1].dependant.dependencies) == 1
|
||||
self_dep = app.routes[-1].dependant.dependencies[0]
|
||||
assert self_dep.name == "self"
|
||||
assert inspect.isfunction(self_dep.call)
|
||||
assert "get_current_servable" in str(self_dep.call)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
|
|
@ -2,14 +2,14 @@ import asyncio
|
|||
from functools import singledispatch
|
||||
import importlib
|
||||
from itertools import groupby
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from typing import Iterable, List, Tuple, Dict, Optional
|
||||
from typing import Iterable, List, Tuple, Dict, Optional, Type
|
||||
import os
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from collections import UserDict
|
||||
|
||||
import starlette.requests
|
||||
|
@ -20,6 +20,7 @@ import pydantic
|
|||
|
||||
import ray
|
||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
@ -431,3 +432,68 @@ class ASGIHTTPSender:
|
|||
b"".join(self.buffer),
|
||||
status_code=self.status_code,
|
||||
headers=dict(self.header))
|
||||
|
||||
|
||||
def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None:
|
||||
"""Transform the `cls`'s methods and class annotations to FastAPI routes.
|
||||
|
||||
Modified from
|
||||
https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
|
||||
|
||||
Usage:
|
||||
>>> app = FastAPI()
|
||||
>>> class A:
|
||||
@app.route("/{i}")
|
||||
def func(self, i: int) -> str:
|
||||
return self.dep + i
|
||||
>>> # just running the app won't work, here.
|
||||
>>> make_fastapi_class_based_view(app, A)
|
||||
>>> # now app can be run properly
|
||||
"""
|
||||
# Delayed import to prevent ciruclar imports in workers.
|
||||
from fastapi import Depends, APIRouter
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
def get_current_servable_instance():
|
||||
from ray import serve
|
||||
return serve.get_replica_context().servable_object
|
||||
|
||||
# Find all the class method routes
|
||||
member_methods = {
|
||||
func
|
||||
for _, func in inspect.getmembers(cls, inspect.isfunction)
|
||||
}
|
||||
class_method_routes = [
|
||||
route for route in fastapi_app.routes
|
||||
if isinstance(route, APIRoute) and route.endpoint in member_methods
|
||||
]
|
||||
|
||||
# Modify these routes and mount it to a new APIRouter.
|
||||
# We need to to this (instead of modifying in place) because we want to use
|
||||
# the laster fastapi_app.include_router to re-run the dependency analysis
|
||||
# for each routes.
|
||||
new_router = APIRouter()
|
||||
for route in class_method_routes:
|
||||
fastapi_app.routes.remove(route)
|
||||
|
||||
# This block just adds a default values to the self parameters so that
|
||||
# FastAPI knows to inject the object when calling the route.
|
||||
# Before: def method(self, i): ...
|
||||
# After: def method(self=Depends(...), *, i):...
|
||||
old_endpoint = route.endpoint
|
||||
old_signature = inspect.signature(old_endpoint)
|
||||
old_parameters = list(old_signature.parameters.values())
|
||||
old_self_parameter = old_parameters[0]
|
||||
new_self_parameter = old_self_parameter.replace(
|
||||
default=Depends(get_current_servable_instance))
|
||||
new_parameters = [new_self_parameter] + [
|
||||
# Make the rest of the parameters keyword only because
|
||||
# the first argument is no longer positional.
|
||||
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY)
|
||||
for parameter in old_parameters[1:]
|
||||
]
|
||||
new_signature = old_signature.replace(parameters=new_parameters)
|
||||
setattr(route.endpoint, "__signature__", new_signature)
|
||||
# route.endpoint.__signature__ = new_signature
|
||||
new_router.routes.append(route)
|
||||
fastapi_app.include_router(new_router)
|
||||
|
|
Loading…
Add table
Reference in a new issue