mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Doc][Serve] Add minimal docs for model wrappers and http adapters (#23536)
This commit is contained in:
parent
afd287eb93
commit
cb1919b8d0
8 changed files with 119 additions and 57 deletions
|
@ -47,6 +47,7 @@ sphinx-external-toc==0.2.3
|
|||
sphinxcontrib.yt==0.2.2
|
||||
sphinx-sitemap==2.2.0
|
||||
sphinx-thebe==0.1.1
|
||||
autodoc_pydantic==1.6.1
|
||||
|
||||
# MyST
|
||||
myst-parser==0.15.2
|
||||
|
|
|
@ -39,6 +39,7 @@ extensions = [
|
|||
"sphinx.ext.coverage",
|
||||
"sphinx_external_toc",
|
||||
"sphinx_thebe",
|
||||
"sphinxcontrib.autodoc_pydantic",
|
||||
]
|
||||
|
||||
myst_enable_extensions = [
|
||||
|
|
|
@ -101,13 +101,14 @@ Predictors
|
|||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
.. _air-serve-integration:
|
||||
|
||||
Serving
|
||||
~~~~~~~
|
||||
|
||||
.. automodule:: ray.serve.model_wrappers
|
||||
:members:
|
||||
.. autoclass:: ray.serve.model_wrappers.ModelWrapperDeployment
|
||||
|
||||
.. autoclass:: ray.serve.model_wrappers.ModelWrapper
|
||||
|
||||
|
||||
Outputs
|
||||
|
|
|
@ -152,6 +152,60 @@ To try it out, save a code snippet in a local python file (i.e. main.py) and in
|
|||
ray start --head
|
||||
python main.py
|
||||
|
||||
.. _serve-http-adapters:
|
||||
|
||||
HTTP Adapters
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Ray Serve provides a suite of adapters to convert HTTP requests to ML inputs like `numpy` arrays.
|
||||
You can use them with the :ref:`Ray AI Runtime (AIR) model wrapper<air-serve-integration>` feature
|
||||
to deploy pre-trained models in one click.
|
||||
Alternatively, you can directly import them and put them into your FastAPI app.
|
||||
|
||||
For example, we provide a simple adapter for n-dimensional arrays.
|
||||
|
||||
With :ref:`model wrappers<air-serve-integration>`, you can specify it via the ``input_schema`` field.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.http_adapters import json_to_ndarray
|
||||
from ray.serve.model_wrappers import ModelWrapperDeployment
|
||||
|
||||
ModelWrapperDeployment.options(name="my_model").deploy(
|
||||
my_ray_air_predictor,
|
||||
my_ray_air_checkpoint,
|
||||
input_schema=json_to_ndarray
|
||||
)
|
||||
|
||||
You can also bring the adapter to your own FastAPI app using
|
||||
`Depends <https://fastapi.tiangolo.com/tutorial/dependencies/#import-depends>`_.
|
||||
The input schema will automatically be part of the generated OpenAPI schema with FastAPI.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from fastapi import FastAPI, Depends
|
||||
from ray.serve.http_adapters import json_to_ndarray
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/endpoint")
|
||||
async def endpoint(np_array = Depends(json_to_ndarray)):
|
||||
...
|
||||
|
||||
It has the following schema for input:
|
||||
|
||||
.. _serve-ndarray-schema:
|
||||
|
||||
.. autopydantic_model:: ray.serve.http_adapters.NdArray
|
||||
|
||||
|
||||
Here is a list of adapters. Please feel free to `contribute more <https://github.com/ray-project/ray/issues/new/choose>`_!
|
||||
|
||||
.. automodule:: ray.serve.http_adapters
|
||||
:members: json_to_ndarray, image_to_ndarray
|
||||
|
||||
|
||||
Configuring HTTP Server Locations
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
|||
import numpy as np
|
||||
|
||||
from ray.serve.utils import require_packages
|
||||
from ray.ml.predictor import DataBatchType
|
||||
import starlette.requests
|
||||
|
||||
|
||||
|
@ -41,8 +40,8 @@ class NdArray(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
def array_to_databatch(payload: NdArray) -> DataBatchType:
|
||||
"""Accepts an NdArray from an HTTP body and converts it to a DataBatchType."""
|
||||
def json_to_ndarray(payload: NdArray) -> np.ndarray:
|
||||
"""Accepts an NdArray JSON from an HTTP body and converts it to a numpy array."""
|
||||
arr = np.array(payload.array)
|
||||
if payload.shape:
|
||||
arr = arr.reshape(*payload.shape)
|
||||
|
@ -60,9 +59,9 @@ def starlette_request(
|
|||
|
||||
|
||||
@require_packages(["PIL"])
|
||||
def image_to_databatch(img: bytes = File(...)) -> DataBatchType:
|
||||
def image_to_ndarray(img: bytes = File(...)) -> np.ndarray:
|
||||
"""Accepts a PIL-readable file from an HTTP form and converts
|
||||
it to a DataBatchType.
|
||||
it to a numpy array.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -39,36 +39,42 @@ def _load_predictor_cls(
|
|||
|
||||
|
||||
class ModelWrapper(SimpleSchemaIngress):
|
||||
"""Serve any Ray AIR predictor from an AIR checkpoint.
|
||||
|
||||
Args:
|
||||
predictor_cls(str, Type[Predictor]): The class or path for predictor class.
|
||||
The type must be a subclass of :class:`ray.ml.predictor.Predictor`.
|
||||
checkpoint(Checkpoint, dict): The checkpoint object or a dictionary describing
|
||||
the object.
|
||||
|
||||
- The checkpoint object must be a subclass of
|
||||
:class:`ray.ml.checkpoint.Checkpoint`.
|
||||
- The dictionary should be in the form of
|
||||
``{"checkpoint_cls": "import.path.MyCheckpoint",
|
||||
"uri": "uri_to_load_from"}``.
|
||||
Serve will then call ``MyCheckpoint.from_uri("uri_to_load_from")`` to
|
||||
instantiate the object.
|
||||
|
||||
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
|
||||
function. By default, Serve will use the
|
||||
:ref:`NdArray <serve-ndarray-schema>` schema and convert to numpy array.
|
||||
You can pass in any FastAPI dependency resolver that returns
|
||||
an array. When you pass in a string, Serve will import it.
|
||||
Please refer to :ref:`Serve HTTP adapters <serve-http-adapters>`
|
||||
documentation to learn more.
|
||||
batching_params(dict, None, False): override the default parameters to
|
||||
:func:`ray.serve.batch`. Pass ``False`` to disable batching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
predictor_cls: Union[str, Type[Predictor]],
|
||||
checkpoint: Union[Checkpoint, Dict],
|
||||
input_schema: Union[
|
||||
str, InputSchemaFn
|
||||
] = "ray.serve.http_adapters.array_to_databatch",
|
||||
] = "ray.serve.http_adapters.json_to_ndarray",
|
||||
batching_params: Optional[Union[Dict[str, int], bool]] = None,
|
||||
):
|
||||
"""Serve any Ray ML predictor from checkpoint.
|
||||
|
||||
Args:
|
||||
predictor_cls(str, Type[Predictor]): The class or path for predictor class.
|
||||
The type must be a subclass of ray.ml `Predictor`.
|
||||
checkpoint(Checkpoint, dict): The checkpoint object or a dictionary describe
|
||||
the object.
|
||||
- The checkpoint object must be a subclass of ray.ml `Checkpoint`.
|
||||
- The dictionary should be in the form of
|
||||
{"checkpoint_cls": "import.path.MyCheckpoint",
|
||||
"uri": "uri_to_load_from"}.
|
||||
Serve will then call `MyCheckpoint.from_uri("uri_to_load_from")` to
|
||||
instantiate the object.
|
||||
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
|
||||
function. By default, Serve will use the `NdArray` schema and convert to
|
||||
numpy array. You can pass in any FastAPI dependency resolver that returns
|
||||
an array. When you pass in a string, Serve will import it.
|
||||
Please refer to Serve HTTP adatper documentation to learn more.
|
||||
batching_params(dict, None, False): override the default parameters to
|
||||
serve.batch. Pass `False` to disable batching.
|
||||
"""
|
||||
predictor_cls = _load_predictor_cls(predictor_cls)
|
||||
checkpoint = _load_checkpoint(checkpoint)
|
||||
|
||||
|
@ -96,3 +102,8 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
async def predict(self, inp):
|
||||
"""Perform inference directly without HTTP."""
|
||||
return await self.batched_predict(inp)
|
||||
|
||||
|
||||
@serve.deployment
|
||||
class ModelWrapperDeployment(ModelWrapper):
|
||||
"""Ray Serve Deployment of the ModelWrapper class."""
|
||||
|
|
|
@ -3,7 +3,7 @@ import io
|
|||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from ray.serve.http_adapters import NdArray, array_to_databatch, image_to_databatch
|
||||
from ray.serve.http_adapters import NdArray, json_to_ndarray, image_to_ndarray
|
||||
from ray.serve.utils import require_packages
|
||||
|
||||
|
||||
|
@ -16,28 +16,28 @@ def test_require_packages():
|
|||
func()
|
||||
|
||||
|
||||
def test_array_to_databatch():
|
||||
def test_json_to_ndarray():
|
||||
np.testing.assert_equal(
|
||||
array_to_databatch(NdArray(array=[1, 2], shape=None, dtype=None)),
|
||||
json_to_ndarray(NdArray(array=[1, 2], shape=None, dtype=None)),
|
||||
np.array([1, 2]),
|
||||
)
|
||||
np.testing.assert_equal(
|
||||
array_to_databatch(NdArray(array=[[1], [2]], shape=None, dtype=None)),
|
||||
json_to_ndarray(NdArray(array=[[1], [2]], shape=None, dtype=None)),
|
||||
np.array([[1], [2]]),
|
||||
)
|
||||
np.testing.assert_equal(
|
||||
array_to_databatch(NdArray(array=[[1], [2]], shape=[1, 2], dtype=None)),
|
||||
json_to_ndarray(NdArray(array=[[1], [2]], shape=[1, 2], dtype=None)),
|
||||
np.array([[1, 2]]),
|
||||
)
|
||||
np.testing.assert_equal(
|
||||
array_to_databatch(NdArray(array=[[1.9], [2.1]], shape=[1, 2], dtype="int")),
|
||||
json_to_ndarray(NdArray(array=[[1.9], [2.1]], shape=[1, 2], dtype="int")),
|
||||
np.array([[1.9, 2.1]]).astype("int"),
|
||||
)
|
||||
|
||||
|
||||
def test_image_to_databatch():
|
||||
def test_image_to_ndarray():
|
||||
buffer = io.BytesIO()
|
||||
arr = (np.random.rand(100, 100, 3) * 255).astype("uint8")
|
||||
image = Image.fromarray(arr).convert("RGB")
|
||||
image.save(buffer, format="png")
|
||||
np.testing.assert_almost_equal(image_to_databatch(buffer.getvalue()), arr)
|
||||
np.testing.assert_almost_equal(image_to_ndarray(buffer.getvalue()), arr)
|
||||
|
|
|
@ -9,11 +9,11 @@ from requests.adapters import HTTPAdapter, Retry
|
|||
from ray._private.test_utils import wait_for_condition
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.predictor import DataBatchType, Predictor
|
||||
from ray.serve.model_wrappers import ModelWrapper
|
||||
from ray.serve.model_wrappers import ModelWrapperDeployment
|
||||
from ray.serve.pipeline.api import build
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray.serve.api import RayServeDAGHandle
|
||||
from ray.serve.http_adapters import array_to_databatch
|
||||
from ray.serve.http_adapters import json_to_ndarray
|
||||
import ray
|
||||
from ray import serve
|
||||
|
||||
|
@ -24,7 +24,12 @@ class AdderPredictor(Predictor):
|
|||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, checkpoint: "AdderCheckpoint") -> "Predictor":
|
||||
return cls(checkpoint.increment)
|
||||
if checkpoint._data_dict:
|
||||
return cls(checkpoint._data_dict["increment"])
|
||||
elif checkpoint._local_path: # uri case
|
||||
with open(checkpoint._local_path) as f:
|
||||
return cls(json.load(f))
|
||||
raise Exception("Unreachable")
|
||||
|
||||
def predict(self, data: DataBatchType) -> DataBatchType:
|
||||
return [
|
||||
|
@ -34,17 +39,7 @@ class AdderPredictor(Predictor):
|
|||
|
||||
|
||||
class AdderCheckpoint(Checkpoint):
|
||||
def __init__(self, increment: int):
|
||||
self.increment = increment
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "Checkpoint":
|
||||
return cls(data["increment"])
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> "Checkpoint":
|
||||
with open(uri) as f:
|
||||
return cls(json.load(f))
|
||||
pass
|
||||
|
||||
|
||||
def adder_schema(query_param_arg: int) -> DataBatchType:
|
||||
|
@ -57,7 +52,7 @@ def send_request(**requests_kargs):
|
|||
|
||||
|
||||
def test_simple_adder(serve_instance):
|
||||
serve.deployment(name="Adder")(ModelWrapper).deploy(
|
||||
ModelWrapperDeployment.options(name="Adder").deploy(
|
||||
predictor_cls=AdderPredictor,
|
||||
checkpoint=AdderCheckpoint.from_dict({"increment": 2}),
|
||||
)
|
||||
|
@ -66,7 +61,7 @@ def test_simple_adder(serve_instance):
|
|||
|
||||
|
||||
def test_batching(serve_instance):
|
||||
serve.deployment(name="Adder")(ModelWrapper).deploy(
|
||||
ModelWrapperDeployment.options(name="Adder").deploy(
|
||||
predictor_cls=AdderPredictor,
|
||||
checkpoint=AdderCheckpoint.from_dict({"increment": 2}),
|
||||
batching_params=dict(max_batch_size=2, batch_wait_timeout_s=1000),
|
||||
|
@ -87,7 +82,7 @@ class Ingress:
|
|||
self.dag = dag
|
||||
|
||||
@app.post("/")
|
||||
async def predict(self, data=Depends(array_to_databatch)):
|
||||
async def predict(self, data=Depends(json_to_ndarray)):
|
||||
return await self.dag.remote(data)
|
||||
|
||||
|
||||
|
@ -100,7 +95,7 @@ def test_model_wrappers_in_pipeline(serve_instance):
|
|||
checkpoint_cls = "ray.serve.tests.test_model_wrappers.AdderCheckpoint"
|
||||
|
||||
with InputNode() as dag_input:
|
||||
m1 = ray.remote(ModelWrapper).bind(
|
||||
m1 = ModelWrapperDeployment.bind(
|
||||
predictor_cls=predictor_cls, # TODO: can't be the raw class right now?
|
||||
checkpoint={ # TODO: can't be the raw object right now?
|
||||
"checkpoint_cls": checkpoint_cls,
|
||||
|
@ -140,7 +135,7 @@ def test_yaml_compatibility(serve_instance):
|
|||
"deployments": [
|
||||
{
|
||||
"name": "Adder",
|
||||
"import_path": "ray.serve.model_wrappers.ModelWrapper",
|
||||
"import_path": "ray.serve.model_wrappers.ModelWrapperDeployment",
|
||||
"init_kwargs": {
|
||||
"predictor_cls": predictor_cls,
|
||||
"checkpoint": {
|
||||
|
|
Loading…
Add table
Reference in a new issue