mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Serve] Rename input_schema
to http_adapter
and clarify it in doc (#24353)
This commit is contained in:
parent
ff0ced1a64
commit
3378e1924e
7 changed files with 73 additions and 50 deletions
|
@ -298,7 +298,7 @@ Serve provides a default DAGDriver implementation that accepts HTTP request and
|
|||
You can configure how does the DAGDriver convert HTTP request types. By default, we directly send in a [```starlette.requests.Request```](https://www.starlette.io/requests/) object to represent the whole request. You can also specifies built-in adapters. In this example, we will use a `json_request` adapter that parses HTTP body with JSON parser.
|
||||
|
||||
```{tip}
|
||||
There are several useful adapters like ndarray JSON, image object, etc. You can checkout {ref}`the list of adapters here <serve-http-adapters>`. You can also easily plug in your own ```input_schema```.
|
||||
There are several useful adapters like ndarray JSON, image object, etc. You can checkout {ref}`the list of adapters here <serve-http-adapters>`. You can also easily plug in your own adapter by passing in in the ```http_adapter``` field.
|
||||
```
|
||||
|
||||
+++
|
||||
|
@ -316,7 +316,7 @@ with InputNode() as dag_input:
|
|||
|
||||
# Each serve dag has a driver deployment as ingress that can be user provided.
|
||||
serve_dag = DAGDriver.options(route_prefix="/my-dag", num_replicas=2).bind(
|
||||
dag, input_schema=json_request
|
||||
dag, http_adapter=json_request
|
||||
)
|
||||
|
||||
```
|
||||
|
@ -498,7 +498,7 @@ with InputNode() as dag_input:
|
|||
|
||||
# Each serve dag has a driver deployment as ingress that can be user provided.
|
||||
serve_dag = DAGDriver.options(route_prefix="/my-dag", num_replicas=2).bind(
|
||||
dag, input_schema=json_request
|
||||
dag, http_adapter=json_request
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -160,13 +160,36 @@ HTTP Adapters
|
|||
^^^^^^^^^^^^^
|
||||
|
||||
HTTP adapters are functions that convert raw HTTP request to Python types that you know and recognize.
|
||||
You can use it in three different scenarios:
|
||||
Its input arguments should be type annotated. At minimal, it should accept a ``starlette.requests.Request`` type.
|
||||
But it can also accept any type that's recognized by the FastAPI's dependency injection framework.
|
||||
|
||||
For example, here is an adapter that extra the json content from request.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def json_resolver(request: starlette.requests.Request):
|
||||
return await request.json()
|
||||
|
||||
Here is an adapter that accept two HTTP query parameters.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def parse_query_args(field_a: int, field_b: str):
|
||||
return YourDataClass(field_a, field_b)
|
||||
|
||||
You can specify different type signatures to facilitate HTTP fields extraction
|
||||
include
|
||||
`query parameters <https://fastapi.tiangolo.com/tutorial/query-params/>`_,
|
||||
`body parameters <https://fastapi.tiangolo.com/tutorial/body/>`_,
|
||||
and `many other data types <https://fastapi.tiangolo.com/tutorial/extra-data-types/>`_.
|
||||
For more detail, you can take a look at `FastAPI documentation <https://fastapi.tiangolo.com/>`_.
|
||||
|
||||
You can use adapters in different scenarios within Serve:
|
||||
|
||||
- Ray AIR ``ModelWrapper``
|
||||
- Serve Deployment Graph ``DAGDriver``
|
||||
- Embedded in Bring Your Own ``FastAPI`` Application
|
||||
|
||||
|
||||
Let's go over them one by one.
|
||||
|
||||
Ray AIR ``ModelWrapper``
|
||||
|
@ -178,7 +201,7 @@ to one click deploy pre-trained models.
|
|||
|
||||
For example, we provide a simple adapter for n-dimensional array.
|
||||
|
||||
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``input_schema`` field.
|
||||
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``http_adapter`` field.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -189,13 +212,13 @@ With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``
|
|||
ModelWrapperDeployment.options(name="my_model").deploy(
|
||||
my_ray_air_predictor,
|
||||
my_ray_air_checkpoint,
|
||||
input_schema=json_to_ndarray
|
||||
http_adapter=json_to_ndarray
|
||||
)
|
||||
|
||||
Serve Deployment Graph ``DAGDriver``
|
||||
""""""""""""""""""""""""""""""""""""
|
||||
In :ref:`Serve Deployment Graph <serve-deployment-graph>`, you can configure
|
||||
``ray.serve.drivers.DAGDriver`` to accept an http adapter via it's ``input_schema`` field.
|
||||
``ray.serve.drivers.DAGDriver`` to accept an http adapter via it's ``http_adapter`` field.
|
||||
|
||||
For example, the json request adapters parse JSON in HTTP body:
|
||||
|
||||
|
@ -207,7 +230,7 @@ For example, the json request adapters parse JSON in HTTP body:
|
|||
|
||||
with InputNode() as input_node:
|
||||
...
|
||||
dag = DAGDriver.bind(other_node, input_schema=json_request)
|
||||
dag = DAGDriver.bind(other_node, http_adapter=json_request)
|
||||
|
||||
|
||||
Embedded in Bring Your Own ``FastAPI`` Application
|
||||
|
|
|
@ -10,47 +10,47 @@ from ray.serve.deployment_graph import RayServeDAGHandle
|
|||
from ray.serve.http_util import ASGIHTTPSender
|
||||
from ray import serve
|
||||
|
||||
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.starlette_request"
|
||||
InputSchemaFn = Callable[[Any], Any]
|
||||
DEFAULT_HTTP_ADAPTER = "ray.serve.http_adapters.starlette_request"
|
||||
HTTPAdapterFn = Callable[[Any], Any]
|
||||
|
||||
|
||||
def load_input_schema(
|
||||
input_schema: Optional[Union[str, InputSchemaFn]]
|
||||
) -> InputSchemaFn:
|
||||
if input_schema is None:
|
||||
input_schema = DEFAULT_INPUT_SCHEMA
|
||||
def load_http_adapter(
|
||||
http_adapter: Optional[Union[str, HTTPAdapterFn]]
|
||||
) -> HTTPAdapterFn:
|
||||
if http_adapter is None:
|
||||
http_adapter = DEFAULT_HTTP_ADAPTER
|
||||
|
||||
if isinstance(input_schema, str):
|
||||
input_schema = import_attr(input_schema)
|
||||
if isinstance(http_adapter, str):
|
||||
http_adapter = import_attr(http_adapter)
|
||||
|
||||
if not inspect.isfunction(input_schema):
|
||||
if not inspect.isfunction(http_adapter):
|
||||
raise ValueError("input schema must be a callable function.")
|
||||
|
||||
if any(
|
||||
param.annotation == inspect.Parameter.empty
|
||||
for param in inspect.signature(input_schema).parameters.values()
|
||||
for param in inspect.signature(http_adapter).parameters.values()
|
||||
):
|
||||
raise ValueError("input schema function's signature should be type annotated.")
|
||||
return input_schema
|
||||
return http_adapter
|
||||
|
||||
|
||||
class SimpleSchemaIngress:
|
||||
def __init__(self, input_schema: Optional[Union[str, InputSchemaFn]] = None):
|
||||
"""Create a FastAPI endpoint annotated with input_schema dependency.
|
||||
def __init__(self, http_adapter: Optional[Union[str, HTTPAdapterFn]] = None):
|
||||
"""Create a FastAPI endpoint annotated with http_adapter dependency.
|
||||
|
||||
Args:
|
||||
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
|
||||
http_adapter(str, HTTPAdapterFn, None): The FastAPI input conversion
|
||||
function. By default, Serve will directly pass in the request object
|
||||
starlette.requests.Request. You can pass in any FastAPI dependency
|
||||
resolver. When you pass in a string, Serve will import it.
|
||||
Please refer to Serve HTTP adatper documentation to learn more.
|
||||
"""
|
||||
input_schema = load_input_schema(input_schema)
|
||||
http_adapter = load_http_adapter(http_adapter)
|
||||
self.app = FastAPI()
|
||||
|
||||
@self.app.get("/")
|
||||
@self.app.post("/")
|
||||
async def handle_request(inp=Depends(input_schema)):
|
||||
async def handle_request(inp=Depends(http_adapter)):
|
||||
resp = await self.predict(inp)
|
||||
return resp
|
||||
|
||||
|
@ -72,10 +72,10 @@ class DAGDriver(SimpleSchemaIngress):
|
|||
self,
|
||||
dag_handle: RayServeDAGHandle,
|
||||
*,
|
||||
input_schema: Optional[Union[str, Callable]] = None,
|
||||
http_adapter: Optional[Union[str, Callable]] = None,
|
||||
):
|
||||
self.dag_handle = dag_handle
|
||||
super().__init__(input_schema)
|
||||
super().__init__(http_adapter)
|
||||
|
||||
async def predict(self, *args, **kwargs):
|
||||
"""Perform inference directly without HTTP."""
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, Optional, Type, Union
|
|||
from ray._private.utils import import_attr
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.predictor import Predictor
|
||||
from ray.serve.drivers import InputSchemaFn, SimpleSchemaIngress
|
||||
from ray.serve.drivers import HTTPAdapterFn, SimpleSchemaIngress
|
||||
import ray
|
||||
from ray import serve
|
||||
|
||||
|
@ -55,7 +55,7 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
Serve will then call ``MyCheckpoint.from_uri("uri_to_load_from")`` to
|
||||
instantiate the object.
|
||||
|
||||
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
|
||||
http_adapter(str, HTTPAdapterFn, None): The FastAPI input conversion
|
||||
function. By default, Serve will use the
|
||||
:ref:`NdArray <serve-ndarray-schema>` schema and convert to numpy array.
|
||||
You can pass in any FastAPI dependency resolver that returns
|
||||
|
@ -70,8 +70,8 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
self,
|
||||
predictor_cls: Union[str, Type[Predictor]],
|
||||
checkpoint: Union[Checkpoint, Dict],
|
||||
input_schema: Union[
|
||||
str, InputSchemaFn
|
||||
http_adapter: Union[
|
||||
str, HTTPAdapterFn
|
||||
] = "ray.serve.http_adapters.json_to_ndarray",
|
||||
batching_params: Optional[Union[Dict[str, int], bool]] = None,
|
||||
):
|
||||
|
@ -97,7 +97,7 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
|
||||
self.batched_predict = batched_predict
|
||||
|
||||
super().__init__(input_schema)
|
||||
super().__init__(http_adapter)
|
||||
|
||||
async def predict(self, inp):
|
||||
"""Perform inference directly without HTTP."""
|
||||
|
|
|
@ -142,7 +142,7 @@ def test_yaml_compatibility(serve_instance):
|
|||
"checkpoint_cls": checkpoint_cls,
|
||||
"uri": path,
|
||||
},
|
||||
"input_schema": schema_func,
|
||||
"http_adapter": schema_func,
|
||||
"batching_params": {"max_batch_size": 1},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -140,7 +140,7 @@ async def json_resolver(request: starlette.requests.Request):
|
|||
def test_single_func_deployment_dag(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(dag, http_adapter=json_resolver)
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote([1, 2])) == 4
|
||||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
|
||||
|
@ -163,7 +163,7 @@ def test_chained_function(serve_instance, use_build):
|
|||
with pytest.raises(ValueError, match="Please provide a driver class"):
|
||||
_ = serve.run(serve_dag)
|
||||
|
||||
handle = serve.run(DAGDriver.bind(serve_dag, input_schema=json_resolver))
|
||||
handle = serve.run(DAGDriver.bind(serve_dag, http_adapter=json_resolver))
|
||||
assert ray.get(handle.predict.remote(2)) == 6 # 2 + 2*2
|
||||
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
|
||||
|
||||
|
@ -173,7 +173,7 @@ def test_simple_class_with_class_method(serve_instance, use_build):
|
|||
with InputNode() as dag_input:
|
||||
model = Model.bind(2, ratio=0.3)
|
||||
dag = model.forward.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(dag, http_adapter=json_resolver)
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 0.6
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
|
||||
|
@ -187,7 +187,7 @@ def test_func_class_with_class_method(serve_instance, use_build):
|
|||
m1_output = m1.forward.bind(dag_input[0])
|
||||
m2_output = m2.forward.bind(dag_input[1])
|
||||
combine_output = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote([1, 2, 3])) == 8
|
||||
|
@ -201,7 +201,7 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_b
|
|||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2=m2)
|
||||
combine_output = combine.__call__.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 5
|
||||
|
@ -214,7 +214,7 @@ def test_shared_deployment_handle(serve_instance, use_build):
|
|||
m = Model.bind(2)
|
||||
combine = Combine.bind(m, m2=m)
|
||||
combine_output = combine.__call__.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(combine_output, http_adapter=json_resolver)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 4
|
||||
|
@ -228,7 +228,7 @@ def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use
|
|||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
output = combine.__call__.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(output, input_schema=json_resolver)
|
||||
serve_dag = DAGDriver.bind(output, http_adapter=json_resolver)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 5
|
||||
|
@ -269,7 +269,7 @@ def test_single_node_driver_sucess(serve_instance, use_build):
|
|||
with InputNode() as input_node:
|
||||
out = m1.forward.bind(input_node)
|
||||
out = m2.forward.bind(out)
|
||||
driver = DAGDriver.bind(out, input_schema=json_resolver)
|
||||
driver = DAGDriver.bind(out, http_adapter=json_resolver)
|
||||
handle = serve.run(driver)
|
||||
assert ray.get(handle.predict.remote(39)) == 42
|
||||
assert requests.post("http://127.0.0.1:8000/", json=39).json() == 42
|
||||
|
@ -303,7 +303,7 @@ class TakeHandle:
|
|||
def test_passing_handle(serve_instance, use_build):
|
||||
child = Adder.bind(1)
|
||||
parent = TakeHandle.bind(child)
|
||||
driver = DAGDriver.bind(parent, input_schema=json_resolver)
|
||||
driver = DAGDriver.bind(parent, http_adapter=json_resolver)
|
||||
handle = serve.run(driver)
|
||||
assert ray.get(handle.predict.remote(1)) == 2
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
|
||||
|
|
|
@ -5,7 +5,7 @@ import requests
|
|||
import starlette.requests
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_input_schema
|
||||
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_http_adapter
|
||||
from ray.serve.http_adapters import json_request
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray import serve
|
||||
|
@ -18,15 +18,15 @@ def my_resolver(a: int):
|
|||
|
||||
def test_loading_check():
|
||||
with pytest.raises(ValueError, match="callable"):
|
||||
load_input_schema(["not function"])
|
||||
load_http_adapter(["not function"])
|
||||
with pytest.raises(ValueError, match="type annotated"):
|
||||
|
||||
def func(a):
|
||||
return a
|
||||
|
||||
load_input_schema(func)
|
||||
load_http_adapter(func)
|
||||
|
||||
loaded_my_resolver = load_input_schema(
|
||||
loaded_my_resolver = load_http_adapter(
|
||||
"ray.serve.tests.test_pipeline_driver.my_resolver"
|
||||
)
|
||||
assert (loaded_my_resolver == my_resolver) or (
|
||||
|
@ -42,7 +42,7 @@ def test_unit_schema_injection():
|
|||
async def resolver(my_custom_param: int):
|
||||
return my_custom_param
|
||||
|
||||
server = Impl(input_schema=resolver)
|
||||
server = Impl(http_adapter=resolver)
|
||||
client = TestClient(server.app)
|
||||
|
||||
response = client.post("/")
|
||||
|
@ -92,7 +92,7 @@ def test_dag_driver_custom_schema(serve_instance):
|
|||
with InputNode() as inp:
|
||||
dag = echo.bind(inp)
|
||||
|
||||
handle = serve.run(DAGDriver.bind(dag, input_schema=resolver))
|
||||
handle = serve.run(DAGDriver.bind(dag, http_adapter=resolver))
|
||||
assert ray.get(handle.predict.remote(42)) == 42
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
|
||||
|
@ -110,7 +110,7 @@ def test_dag_driver_partial_input(serve_instance):
|
|||
with InputNode() as inp:
|
||||
dag = DAGDriver.bind(
|
||||
combine.bind(echo.bind(inp[0]), echo.bind(inp[1]), echo.bind(inp[2])),
|
||||
input_schema=json_request,
|
||||
http_adapter=json_request,
|
||||
)
|
||||
handle = serve.run(dag)
|
||||
assert ray.get(handle.predict.remote([1, 2, [3, 4]])) == [1, 2, [3, 4]]
|
||||
|
|
Loading…
Add table
Reference in a new issue