[Serve] Rename input_schema to http_adapter and clarify it in doc (#24353)

This commit is contained in:
Simon Mo 2022-04-29 16:14:04 -07:00 committed by GitHub
parent ff0ced1a64
commit 3378e1924e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 73 additions and 50 deletions

View file

@ -298,7 +298,7 @@ Serve provides a default DAGDriver implementation that accepts HTTP request and
You can configure how the DAGDriver converts HTTP request types. By default, we directly send in a [```starlette.requests.Request```](https://www.starlette.io/requests/) object to represent the whole request. You can also specify built-in adapters. In this example, we will use a `json_request` adapter that parses the HTTP body with a JSON parser.
```{tip}
There are several useful adapters like ndarray JSON, image object, etc. You can checkout {ref}`the list of adapters here <serve-http-adapters>`. You can also easily plug in your own ```input_schema```.
There are several useful adapters like ndarray JSON, image object, etc. You can check out {ref}`the list of adapters here <serve-http-adapters>`. You can also easily plug in your own adapter by passing it in the ```http_adapter``` field.
```
+++
@ -316,7 +316,7 @@ with InputNode() as dag_input:
# Each serve dag has a driver deployment as ingress that can be user provided.
serve_dag = DAGDriver.options(route_prefix="/my-dag", num_replicas=2).bind(
dag, input_schema=json_request
dag, http_adapter=json_request
)
```
@ -498,7 +498,7 @@ with InputNode() as dag_input:
# Each serve dag has a driver deployment as ingress that can be user provided.
serve_dag = DAGDriver.options(route_prefix="/my-dag", num_replicas=2).bind(
dag, input_schema=json_request
dag, http_adapter=json_request
)

View file

@ -160,13 +160,36 @@ HTTP Adapters
^^^^^^^^^^^^^
HTTP adapters are functions that convert a raw HTTP request to Python types that you know and recognize.
You can use it in three different scenarios:
An adapter's input arguments should be type annotated. At a minimum, it should accept a ``starlette.requests.Request`` type.
But it can also accept any type that's recognized by FastAPI's dependency injection framework.
For example, here is an adapter that extracts the JSON content from the request.
.. code-block:: python
async def json_resolver(request: starlette.requests.Request):
return await request.json()
Here is an adapter that accepts two HTTP query parameters.
.. code-block:: python
def parse_query_args(field_a: int, field_b: str):
return YourDataClass(field_a, field_b)
You can specify different type signatures to facilitate the extraction of HTTP fields,
including
`query parameters <https://fastapi.tiangolo.com/tutorial/query-params/>`_,
`body parameters <https://fastapi.tiangolo.com/tutorial/body/>`_,
and `many other data types <https://fastapi.tiangolo.com/tutorial/extra-data-types/>`_.
For more detail, you can take a look at `FastAPI documentation <https://fastapi.tiangolo.com/>`_.
You can use adapters in different scenarios within Serve:
- Ray AIR ``ModelWrapper``
- Serve Deployment Graph ``DAGDriver``
- Embedded in Bring Your Own ``FastAPI`` Application
Let's go over them one by one.
Ray AIR ``ModelWrapper``
@ -178,7 +201,7 @@ to one click deploy pre-trained models.
For example, we provide a simple adapter for n-dimensional array.
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``input_schema`` field.
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``http_adapter`` field.
.. code-block:: python
@ -189,13 +212,13 @@ With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``
ModelWrapperDeployment.options(name="my_model").deploy(
my_ray_air_predictor,
my_ray_air_checkpoint,
input_schema=json_to_ndarray
http_adapter=json_to_ndarray
)
Serve Deployment Graph ``DAGDriver``
""""""""""""""""""""""""""""""""""""
In :ref:`Serve Deployment Graph <serve-deployment-graph>`, you can configure
``ray.serve.drivers.DAGDriver`` to accept an http adapter via it's ``input_schema`` field.
``ray.serve.drivers.DAGDriver`` to accept an http adapter via its ``http_adapter`` field.
For example, the json request adapters parse JSON in HTTP body:
@ -207,7 +230,7 @@ For example, the json request adapters parse JSON in HTTP body:
with InputNode() as input_node:
...
dag = DAGDriver.bind(other_node, input_schema=json_request)
dag = DAGDriver.bind(other_node, http_adapter=json_request)
Embedded in Bring Your Own ``FastAPI`` Application

View file

@ -10,47 +10,47 @@ from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve.http_util import ASGIHTTPSender
from ray import serve
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.starlette_request"
InputSchemaFn = Callable[[Any], Any]
DEFAULT_HTTP_ADAPTER = "ray.serve.http_adapters.starlette_request"
HTTPAdapterFn = Callable[[Any], Any]
def load_input_schema(
input_schema: Optional[Union[str, InputSchemaFn]]
) -> InputSchemaFn:
if input_schema is None:
input_schema = DEFAULT_INPUT_SCHEMA
def load_http_adapter(
http_adapter: Optional[Union[str, HTTPAdapterFn]]
) -> HTTPAdapterFn:
if http_adapter is None:
http_adapter = DEFAULT_HTTP_ADAPTER
if isinstance(input_schema, str):
input_schema = import_attr(input_schema)
if isinstance(http_adapter, str):
http_adapter = import_attr(http_adapter)
if not inspect.isfunction(input_schema):
if not inspect.isfunction(http_adapter):
raise ValueError("input schema must be a callable function.")
if any(
param.annotation == inspect.Parameter.empty
for param in inspect.signature(input_schema).parameters.values()
for param in inspect.signature(http_adapter).parameters.values()
):
raise ValueError("input schema function's signature should be type annotated.")
return input_schema
return http_adapter
class SimpleSchemaIngress:
def __init__(self, input_schema: Optional[Union[str, InputSchemaFn]] = None):
"""Create a FastAPI endpoint annotated with input_schema dependency.
def __init__(self, http_adapter: Optional[Union[str, HTTPAdapterFn]] = None):
"""Create a FastAPI endpoint annotated with http_adapter dependency.
Args:
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
http_adapter(str, HTTPAdapterFn, None): The FastAPI input conversion
function. By default, Serve will directly pass in the request object
starlette.requests.Request. You can pass in any FastAPI dependency
resolver. When you pass in a string, Serve will import it.
Please refer to Serve HTTP adapter documentation to learn more.
"""
input_schema = load_input_schema(input_schema)
http_adapter = load_http_adapter(http_adapter)
self.app = FastAPI()
@self.app.get("/")
@self.app.post("/")
async def handle_request(inp=Depends(input_schema)):
async def handle_request(inp=Depends(http_adapter)):
resp = await self.predict(inp)
return resp
@ -72,10 +72,10 @@ class DAGDriver(SimpleSchemaIngress):
self,
dag_handle: RayServeDAGHandle,
*,
input_schema: Optional[Union[str, Callable]] = None,
http_adapter: Optional[Union[str, Callable]] = None,
):
self.dag_handle = dag_handle
super().__init__(input_schema)
super().__init__(http_adapter)
async def predict(self, *args, **kwargs):
"""Perform inference directly without HTTP."""

View file

@ -3,7 +3,7 @@ from typing import Dict, Optional, Type, Union
from ray._private.utils import import_attr
from ray.ml.checkpoint import Checkpoint
from ray.ml.predictor import Predictor
from ray.serve.drivers import InputSchemaFn, SimpleSchemaIngress
from ray.serve.drivers import HTTPAdapterFn, SimpleSchemaIngress
import ray
from ray import serve
@ -55,7 +55,7 @@ class ModelWrapper(SimpleSchemaIngress):
Serve will then call ``MyCheckpoint.from_uri("uri_to_load_from")`` to
instantiate the object.
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
http_adapter(str, HTTPAdapterFn, None): The FastAPI input conversion
function. By default, Serve will use the
:ref:`NdArray <serve-ndarray-schema>` schema and convert to numpy array.
You can pass in any FastAPI dependency resolver that returns
@ -70,8 +70,8 @@ class ModelWrapper(SimpleSchemaIngress):
self,
predictor_cls: Union[str, Type[Predictor]],
checkpoint: Union[Checkpoint, Dict],
input_schema: Union[
str, InputSchemaFn
http_adapter: Union[
str, HTTPAdapterFn
] = "ray.serve.http_adapters.json_to_ndarray",
batching_params: Optional[Union[Dict[str, int], bool]] = None,
):
@ -97,7 +97,7 @@ class ModelWrapper(SimpleSchemaIngress):
self.batched_predict = batched_predict
super().__init__(input_schema)
super().__init__(http_adapter)
async def predict(self, inp):
"""Perform inference directly without HTTP."""

View file

@ -142,7 +142,7 @@ def test_yaml_compatibility(serve_instance):
"checkpoint_cls": checkpoint_cls,
"uri": path,
},
"input_schema": schema_func,
"http_adapter": schema_func,
"batching_params": {"max_batch_size": 1},
},
}

View file

@ -140,7 +140,7 @@ async def json_resolver(request: starlette.requests.Request):
def test_single_func_deployment_dag(serve_instance, use_build):
with InputNode() as dag_input:
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
serve_dag = DAGDriver.bind(dag, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote([1, 2])) == 4
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
@ -163,7 +163,7 @@ def test_chained_function(serve_instance, use_build):
with pytest.raises(ValueError, match="Please provide a driver class"):
_ = serve.run(serve_dag)
handle = serve.run(DAGDriver.bind(serve_dag, input_schema=json_resolver))
handle = serve.run(DAGDriver.bind(serve_dag, http_adapter=json_resolver))
assert ray.get(handle.predict.remote(2)) == 6 # 2 + 2*2
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
@ -173,7 +173,7 @@ def test_simple_class_with_class_method(serve_instance, use_build):
with InputNode() as dag_input:
model = Model.bind(2, ratio=0.3)
dag = model.forward.bind(dag_input)
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
serve_dag = DAGDriver.bind(dag, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 0.6
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
@ -187,7 +187,7 @@ def test_func_class_with_class_method(serve_instance, use_build):
m1_output = m1.forward.bind(dag_input[0])
m2_output = m2.forward.bind(dag_input[1])
combine_output = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote([1, 2, 3])) == 8
@ -201,7 +201,7 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_b
m2 = Model.bind(3)
combine = Combine.bind(m1, m2=m2)
combine_output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 5
@ -214,7 +214,7 @@ def test_shared_deployment_handle(serve_instance, use_build):
m = Model.bind(2)
combine = Combine.bind(m, m2=m)
combine_output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 4
@ -228,7 +228,7 @@ def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use
m2 = Model.bind(3)
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(output, input_schema=json_resolver)
serve_dag = DAGDriver.bind(output, http_adapter=json_resolver)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 5
@ -269,7 +269,7 @@ def test_single_node_driver_sucess(serve_instance, use_build):
with InputNode() as input_node:
out = m1.forward.bind(input_node)
out = m2.forward.bind(out)
driver = DAGDriver.bind(out, input_schema=json_resolver)
driver = DAGDriver.bind(out, http_adapter=json_resolver)
handle = serve.run(driver)
assert ray.get(handle.predict.remote(39)) == 42
assert requests.post("http://127.0.0.1:8000/", json=39).json() == 42
@ -303,7 +303,7 @@ class TakeHandle:
def test_passing_handle(serve_instance, use_build):
child = Adder.bind(1)
parent = TakeHandle.bind(child)
driver = DAGDriver.bind(parent, input_schema=json_resolver)
driver = DAGDriver.bind(parent, http_adapter=json_resolver)
handle = serve.run(driver)
assert ray.get(handle.predict.remote(1)) == 2
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2

View file

@ -5,7 +5,7 @@ import requests
import starlette.requests
from starlette.testclient import TestClient
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_input_schema
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_http_adapter
from ray.serve.http_adapters import json_request
from ray.experimental.dag.input_node import InputNode
from ray import serve
@ -18,15 +18,15 @@ def my_resolver(a: int):
def test_loading_check():
with pytest.raises(ValueError, match="callable"):
load_input_schema(["not function"])
load_http_adapter(["not function"])
with pytest.raises(ValueError, match="type annotated"):
def func(a):
return a
load_input_schema(func)
load_http_adapter(func)
loaded_my_resolver = load_input_schema(
loaded_my_resolver = load_http_adapter(
"ray.serve.tests.test_pipeline_driver.my_resolver"
)
assert (loaded_my_resolver == my_resolver) or (
@ -42,7 +42,7 @@ def test_unit_schema_injection():
async def resolver(my_custom_param: int):
return my_custom_param
server = Impl(input_schema=resolver)
server = Impl(http_adapter=resolver)
client = TestClient(server.app)
response = client.post("/")
@ -92,7 +92,7 @@ def test_dag_driver_custom_schema(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)
handle = serve.run(DAGDriver.bind(dag, input_schema=resolver))
handle = serve.run(DAGDriver.bind(dag, http_adapter=resolver))
assert ray.get(handle.predict.remote(42)) == 42
resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
@ -110,7 +110,7 @@ def test_dag_driver_partial_input(serve_instance):
with InputNode() as inp:
dag = DAGDriver.bind(
combine.bind(echo.bind(inp[0]), echo.bind(inp[1]), echo.bind(inp[2])),
input_schema=json_request,
http_adapter=json_request,
)
handle = serve.run(dag)
assert ray.get(handle.predict.remote([1, 2, [3, 4]])) == [1, 2, [3, 4]]