[Doc][Serve] Add minimal docs for model wrappers and http adapters (#23536)

This commit is contained in:
Simon Mo 2022-03-29 11:33:14 -07:00 committed by GitHub
parent afd287eb93
commit cb1919b8d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 119 additions and 57 deletions

View file

@ -47,6 +47,7 @@ sphinx-external-toc==0.2.3
sphinxcontrib.yt==0.2.2
sphinx-sitemap==2.2.0
sphinx-thebe==0.1.1
autodoc_pydantic==1.6.1
# MyST
myst-parser==0.15.2

View file

@ -39,6 +39,7 @@ extensions = [
"sphinx.ext.coverage",
"sphinx_external_toc",
"sphinx_thebe",
"sphinxcontrib.autodoc_pydantic",
]
myst_enable_extensions = [

View file

@ -101,13 +101,14 @@ Predictors
:members:
:show-inheritance:
.. _air-serve-integration:
Serving
~~~~~~~
.. automodule:: ray.serve.model_wrappers
:members:
.. autoclass:: ray.serve.model_wrappers.ModelWrapperDeployment
.. autoclass:: ray.serve.model_wrappers.ModelWrapper
Outputs

View file

@ -152,6 +152,60 @@ To try it out, save a code snippet in a local python file (i.e. main.py) and in
ray start --head
python main.py
.. _serve-http-adapters:
HTTP Adapters
^^^^^^^^^^^^^
Ray Serve provides a suite of adapters to convert HTTP requests to ML inputs like `numpy` arrays.
You can just use it with :ref:`Ray AI Runtime (AIR) model wrapper<air-serve-integration>` feature
to one click deploy pre-trained models.
Alternatively, you can directly import them and put them into your FastAPI app.
For example, we provide a simple adapter for n-dimensional array.
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``input_schema`` field.
.. code-block:: python
from ray import serve
from ray.serve.http_adapters import json_to_ndarray
from ray.serve.model_wrappers import ModelWrapperDeployment
ModelWrapperDeployment.options(name="my_model").deploy(
my_ray_air_predictor,
my_ray_air_checkpoint,
input_schema=json_to_ndarray
)
You can also bring the adapter to your own FastAPI app using
`Depends <https://fastapi.tiangolo.com/tutorial/dependencies/#import-depends>`_.
The input schema will automatically be part of the generated OpenAPI schema with FastAPI.
.. code-block:: python
from fastapi import FastAPI, Depends
from ray.serve.http_adapters import json_to_ndarray
app = FastAPI()
@app.post("/endpoint")
async def endpoint(np_array = Depends(json_to_ndarray)):
...
It has the following schema for input:
.. _serve-ndarray-schema:
.. autopydantic_model:: ray.serve.http_adapters.NdArray
Here is a list of adapters and please feel free to `contribute more <https://github.com/ray-project/ray/issues/new/choose>`_!
.. automodule:: ray.serve.http_adapters
:members: json_to_ndarray, image_to_ndarray
Configuring HTTP Server Locations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View file

@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
import numpy as np
from ray.serve.utils import require_packages
from ray.ml.predictor import DataBatchType
import starlette.requests
@ -41,8 +40,8 @@ class NdArray(BaseModel):
)
def array_to_databatch(payload: NdArray) -> DataBatchType:
"""Accepts an NdArray from an HTTP body and converts it to a DataBatchType."""
def json_to_ndarray(payload: NdArray) -> np.ndarray:
"""Accepts an NdArray JSON from an HTTP body and converts it to a numpy array."""
arr = np.array(payload.array)
if payload.shape:
arr = arr.reshape(*payload.shape)
@ -60,9 +59,9 @@ def starlette_request(
@require_packages(["PIL"])
def image_to_databatch(img: bytes = File(...)) -> DataBatchType:
def image_to_ndarray(img: bytes = File(...)) -> np.ndarray:
"""Accepts a PIL-readable file from an HTTP form and converts
it to a DataBatchType.
it to a numpy array.
"""
from PIL import Image

View file

@ -39,36 +39,42 @@ def _load_predictor_cls(
class ModelWrapper(SimpleSchemaIngress):
"""Serve any Ray AIR predictor from an AIR checkpoint.
Args:
predictor_cls(str, Type[Predictor]): The class or path for predictor class.
The type must be a subclass of :class:`ray.ml.predicotr.Predictor`.
checkpoint(Checkpoint, dict): The checkpoint object or a dictionary describe
the object.
- The checkpoint object must be a subclass of
:class:`ray.ml.checkpoint.Checkpoint`.
- The dictionary should be in the form of
``{"checkpoint_cls": "import.path.MyCheckpoint",
"uri": "uri_to_load_from"}``.
Serve will then call ``MyCheckpoint.from_uri("uri_to_load_from")`` to
instantiate the object.
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
function. By default, Serve will use the
:ref:`NdArray <serve-ndarray-schema>` schema and convert to numpy array.
You can pass in any FastAPI dependency resolver that returns
an array. When you pass in a string, Serve will import it.
Please refer to :ref:`Serve HTTP adatpers <serve-http-adapters>`
documentation to learn more.
batching_params(dict, None, False): override the default parameters to
:func:`ray.serve.batch`. Pass ``False`` to disable batching.
"""
def __init__(
self,
predictor_cls: Union[str, Type[Predictor]],
checkpoint: Union[Checkpoint, Dict],
input_schema: Union[
str, InputSchemaFn
] = "ray.serve.http_adapters.array_to_databatch",
] = "ray.serve.http_adapters.json_to_ndarray",
batching_params: Optional[Union[Dict[str, int], bool]] = None,
):
"""Serve any Ray ML predictor from checkpoint.
Args:
predictor_cls(str, Type[Predictor]): The class or path for predictor class.
The type must be a subclass of ray.ml `Predictor`.
checkpoint(Checkpoint, dict): The checkpoint object or a dictionary describe
the object.
- The checkpoint object must be a subclass of ray.ml `Checkpoint`.
- The dictionary should be in the form of
{"checkpoint_cls": "import.path.MyCheckpoint",
"uri": "uri_to_load_from"}.
Serve will then call `MyCheckpoint.from_uri("uri_to_load_from")` to
instantiate the object.
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
function. By default, Serve will use the `NdArray` schema and convert to
numpy array. You can pass in any FastAPI dependency resolver that returns
an array. When you pass in a string, Serve will import it.
Please refer to Serve HTTP adatper documentation to learn more.
batching_params(dict, None, False): override the default parameters to
serve.batch. Pass `False` to disable batching.
"""
predictor_cls = _load_predictor_cls(predictor_cls)
checkpoint = _load_checkpoint(checkpoint)
@ -96,3 +102,8 @@ class ModelWrapper(SimpleSchemaIngress):
async def predict(self, inp):
"""Perform inference directly without HTTP."""
return await self.batched_predict(inp)
@serve.deployment
class ModelWrapperDeployment(ModelWrapper):
"""Ray Serve Deployment of the ModelWrapper class."""

View file

@ -3,7 +3,7 @@ import io
import numpy as np
import pytest
from PIL import Image
from ray.serve.http_adapters import NdArray, array_to_databatch, image_to_databatch
from ray.serve.http_adapters import NdArray, json_to_ndarray, image_to_ndarray
from ray.serve.utils import require_packages
@ -16,28 +16,28 @@ def test_require_packages():
func()
def test_array_to_databatch():
def test_json_to_ndarray():
np.testing.assert_equal(
array_to_databatch(NdArray(array=[1, 2], shape=None, dtype=None)),
json_to_ndarray(NdArray(array=[1, 2], shape=None, dtype=None)),
np.array([1, 2]),
)
np.testing.assert_equal(
array_to_databatch(NdArray(array=[[1], [2]], shape=None, dtype=None)),
json_to_ndarray(NdArray(array=[[1], [2]], shape=None, dtype=None)),
np.array([[1], [2]]),
)
np.testing.assert_equal(
array_to_databatch(NdArray(array=[[1], [2]], shape=[1, 2], dtype=None)),
json_to_ndarray(NdArray(array=[[1], [2]], shape=[1, 2], dtype=None)),
np.array([[1, 2]]),
)
np.testing.assert_equal(
array_to_databatch(NdArray(array=[[1.9], [2.1]], shape=[1, 2], dtype="int")),
json_to_ndarray(NdArray(array=[[1.9], [2.1]], shape=[1, 2], dtype="int")),
np.array([[1.9, 2.1]]).astype("int"),
)
def test_image_to_databatch():
def test_image_to_ndarray():
buffer = io.BytesIO()
arr = (np.random.rand(100, 100, 3) * 255).astype("uint8")
image = Image.fromarray(arr).convert("RGB")
image.save(buffer, format="png")
np.testing.assert_almost_equal(image_to_databatch(buffer.getvalue()), arr)
np.testing.assert_almost_equal(image_to_ndarray(buffer.getvalue()), arr)

View file

@ -9,11 +9,11 @@ from requests.adapters import HTTPAdapter, Retry
from ray._private.test_utils import wait_for_condition
from ray.ml.checkpoint import Checkpoint
from ray.ml.predictor import DataBatchType, Predictor
from ray.serve.model_wrappers import ModelWrapper
from ray.serve.model_wrappers import ModelWrapperDeployment
from ray.serve.pipeline.api import build
from ray.experimental.dag.input_node import InputNode
from ray.serve.api import RayServeDAGHandle
from ray.serve.http_adapters import array_to_databatch
from ray.serve.http_adapters import json_to_ndarray
import ray
from ray import serve
@ -24,7 +24,12 @@ class AdderPredictor(Predictor):
@classmethod
def from_checkpoint(cls, checkpoint: "AdderCheckpoint") -> "Predictor":
return cls(checkpoint.increment)
if checkpoint._data_dict:
return cls(checkpoint._data_dict["increment"])
elif checkpoint._local_path: # uri case
with open(checkpoint._local_path) as f:
return cls(json.load(f))
raise Exception("Unreachable")
def predict(self, data: DataBatchType) -> DataBatchType:
return [
@ -34,17 +39,7 @@ class AdderPredictor(Predictor):
class AdderCheckpoint(Checkpoint):
def __init__(self, increment: int):
self.increment = increment
@classmethod
def from_dict(cls, data: dict) -> "Checkpoint":
return cls(data["increment"])
@classmethod
def from_uri(cls, uri: str) -> "Checkpoint":
with open(uri) as f:
return cls(json.load(f))
pass
def adder_schema(query_param_arg: int) -> DataBatchType:
@ -57,7 +52,7 @@ def send_request(**requests_kargs):
def test_simple_adder(serve_instance):
serve.deployment(name="Adder")(ModelWrapper).deploy(
ModelWrapperDeployment.options(name="Adder").deploy(
predictor_cls=AdderPredictor,
checkpoint=AdderCheckpoint.from_dict({"increment": 2}),
)
@ -66,7 +61,7 @@ def test_simple_adder(serve_instance):
def test_batching(serve_instance):
serve.deployment(name="Adder")(ModelWrapper).deploy(
ModelWrapperDeployment.options(name="Adder").deploy(
predictor_cls=AdderPredictor,
checkpoint=AdderCheckpoint.from_dict({"increment": 2}),
batching_params=dict(max_batch_size=2, batch_wait_timeout_s=1000),
@ -87,7 +82,7 @@ class Ingress:
self.dag = dag
@app.post("/")
async def predict(self, data=Depends(array_to_databatch)):
async def predict(self, data=Depends(json_to_ndarray)):
return await self.dag.remote(data)
@ -100,7 +95,7 @@ def test_model_wrappers_in_pipeline(serve_instance):
checkpoint_cls = "ray.serve.tests.test_model_wrappers.AdderCheckpoint"
with InputNode() as dag_input:
m1 = ray.remote(ModelWrapper).bind(
m1 = ModelWrapperDeployment.bind(
predictor_cls=predictor_cls, # TODO: can't be the raw class right now?
checkpoint={ # TODO: can't be the raw object right now?
"checkpoint_cls": checkpoint_cls,
@ -140,7 +135,7 @@ def test_yaml_compatibility(serve_instance):
"deployments": [
{
"name": "Adder",
"import_path": "ray.serve.model_wrappers.ModelWrapper",
"import_path": "ray.serve.model_wrappers.ModelWrapperDeployment",
"init_kwargs": {
"predictor_cls": predictor_cls,
"checkpoint": {