[Serve] FastAPI Simple Class Based View (#14858)

This commit is contained in:
Simon Mo 2021-03-25 10:21:36 -07:00 committed by GitHub
parent 803be1d968
commit 1fcca07856
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 8 deletions

View file

@ -24,7 +24,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.router import RequestMetadata, Router
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_current_node_resource_key, get_random_letters,
logger, register_custom_serializers)
logger, make_fastapi_class_based_view,
register_custom_serializers)
import ray
@ -54,6 +55,7 @@ class ReplicaContext:
backend_tag: BackendTag
replica_tag: ReplicaTag
_internal_controller_name: str
servable_object: Callable
def create_or_get_async_loop_in_thread():
@ -68,10 +70,15 @@ def create_or_get_async_loop_in_thread():
return _global_async_loop
def _set_internal_replica_context(backend_tag, replica_tag, controller_name):
def _set_internal_replica_context(
backend_tag: BackendTag,
replica_tag: ReplicaTag,
controller_name: str,
servable_object: Callable,
):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(backend_tag, replica_tag,
controller_name)
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
backend_tag, replica_tag, controller_name, servable_object)
def _ensure_connected(f: Callable) -> Callable:
@ -1072,6 +1079,9 @@ def ingress(
if app is not None:
cls._serve_asgi_app = app
# Sometimes there are decorators on the methods. We want to fix
# the fast api routes here.
make_fastapi_class_based_view(app, cls)
if path_prefix is not None:
cls._serve_path_prefix = path_prefix

View file

@ -60,11 +60,20 @@ def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
# backend code will connect to the instance that this backend is
# running in.
ray.serve.api._set_internal_replica_context(
backend_tag, replica_tag, controller_name)
backend_tag,
replica_tag,
controller_name,
servable_object=None)
if is_function:
_callable = backend
else:
_callable = backend(*init_args)
# Setting the context again to update the servable_object.
ray.serve.api._set_internal_replica_context(
backend_tag,
replica_tag,
controller_name,
servable_object=_callable)
assert controller_name, "Must provide a valid controller_name"
controller_handle = ray.get_actor(controller_name)

View file

@ -91,7 +91,7 @@ class HTTPProxy:
# Set the controller name so that serve.connect() will connect to the
# controller instance this proxy is running in.
ray.serve.api._set_internal_replica_context(None, None,
controller_name)
controller_name, None)
self.client = ray.serve.connect()
controller = ray.get_actor(controller_name)

View file

@ -1,8 +1,10 @@
from fastapi import FastAPI
import requests
import pytest
import inspect
from ray import serve
from ray.serve.utils import make_fastapi_class_based_view
def test_fastapi_function(serve_instance):
@ -45,6 +47,61 @@ def test_ingress_prefix(serve_instance):
assert resp.json() == {"result": 100}
def test_class_based_view(serve_instance):
client = serve_instance
app = FastAPI()
@app.get("/other")
def hello():
return "hello"
@serve.ingress(app)
class A:
def __init__(self):
self.val = 1
@app.get("/calc/{i}")
def b(self, i: int):
return i + self.val
@app.post("/calc/{i}")
def c(self, i: int):
return i - self.val
client.deploy("f", A)
resp = requests.get(f"http://localhost:8000/f/calc/41")
assert resp.json() == 42
resp = requests.post(f"http://localhost:8000/f/calc/41")
assert resp.json() == 40
resp = requests.get(f"http://localhost:8000/f/other")
assert resp.json() == "hello"
def test_make_fastapi_cbv_util():
app = FastAPI()
class A:
@app.get("/{i}")
def b(self, i: int):
pass
# before, "self" is treated as a query params
assert app.routes[-1].endpoint == A.b
assert app.routes[-1].dependant.query_params[0].name == "self"
assert len(app.routes[-1].dependant.dependencies) == 0
make_fastapi_class_based_view(app, A)
# after, "self" is treated as a dependency instead of query params
assert app.routes[-1].endpoint == A.b
assert len(app.routes[-1].dependant.query_params) == 0
assert len(app.routes[-1].dependant.dependencies) == 1
self_dep = app.routes[-1].dependant.dependencies[0]
assert self_dep.name == "self"
assert inspect.isfunction(self_dep.call)
assert "get_current_servable" in str(self_dep.call)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

View file

@ -2,14 +2,14 @@ import asyncio
from functools import singledispatch
import importlib
from itertools import groupby
import inspect
import json
import logging
import random
import string
import time
from typing import Iterable, List, Tuple, Dict, Optional
from typing import Iterable, List, Tuple, Dict, Optional, Type
import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
import starlette.requests
@ -20,6 +20,7 @@ import pydantic
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.exceptions import RayServeException
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
@ -431,3 +432,68 @@ class ASGIHTTPSender:
b"".join(self.buffer),
status_code=self.status_code,
headers=dict(self.header))
def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None:
"""Transform the `cls`'s methods and class annotations to FastAPI routes.
Modified from
https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
Usage:
>>> app = FastAPI()
>>> class A:
@app.route("/{i}")
def func(self, i: int) -> str:
return self.dep + i
>>> # just running the app won't work, here.
>>> make_fastapi_class_based_view(app, A)
>>> # now app can be run properly
"""
# Delayed import to prevent ciruclar imports in workers.
from fastapi import Depends, APIRouter
from fastapi.routing import APIRoute
def get_current_servable_instance():
from ray import serve
return serve.get_replica_context().servable_object
# Find all the class method routes
member_methods = {
func
for _, func in inspect.getmembers(cls, inspect.isfunction)
}
class_method_routes = [
route for route in fastapi_app.routes
if isinstance(route, APIRoute) and route.endpoint in member_methods
]
# Modify these routes and mount it to a new APIRouter.
# We need to to this (instead of modifying in place) because we want to use
# the laster fastapi_app.include_router to re-run the dependency analysis
# for each routes.
new_router = APIRouter()
for route in class_method_routes:
fastapi_app.routes.remove(route)
# This block just adds a default values to the self parameters so that
# FastAPI knows to inject the object when calling the route.
# Before: def method(self, i): ...
# After: def method(self=Depends(...), *, i):...
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
old_parameters = list(old_signature.parameters.values())
old_self_parameter = old_parameters[0]
new_self_parameter = old_self_parameter.replace(
default=Depends(get_current_servable_instance))
new_parameters = [new_self_parameter] + [
# Make the rest of the parameters keyword only because
# the first argument is no longer positional.
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY)
for parameter in old_parameters[1:]
]
new_signature = old_signature.replace(parameters=new_parameters)
setattr(route.endpoint, "__signature__", new_signature)
# route.endpoint.__signature__ = new_signature
new_router.routes.append(route)
fastapi_app.include_router(new_router)