mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
This reverts commit 91a1c3411f
.
This commit is contained in:
parent
81dcf9ff35
commit
e507aa5758
7 changed files with 95 additions and 246 deletions
|
@ -331,14 +331,6 @@ py_test(
|
|||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_pipeline_driver",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive", "team:serve"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_pipeline_dag",
|
||||
size = "medium",
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
import inspect
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import starlette
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from ray._private.utils import import_attr
|
||||
from ray.serve.api import RayServeDAGHandle
|
||||
from ray.serve.http_util import ASGIHTTPSender
|
||||
from ray import serve
|
||||
|
||||
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.starlette_request"
|
||||
InputSchemaFn = Callable[[Any], Any]
|
||||
|
||||
|
||||
def load_input_schema(
    input_schema: Optional[Union[str, InputSchemaFn]]
) -> InputSchemaFn:
    """Resolve and validate an input schema function.

    Args:
        input_schema: ``None`` (fall back to ``DEFAULT_INPUT_SCHEMA``), a
            dotted import path that is loaded with ``import_attr``, or the
            function itself.

    Returns:
        The resolved function.

    Raises:
        ValueError: If the resolved object is not a plain function
            (``inspect.isfunction``), or if any of its parameters is
            missing a type annotation.
    """
    resolved = DEFAULT_INPUT_SCHEMA if input_schema is None else input_schema
    if isinstance(resolved, str):
        resolved = import_attr(resolved)

    if not inspect.isfunction(resolved):
        raise ValueError("input schema must be a callable function.")

    # Every parameter must be annotated; reject on the first one that is not.
    for param in inspect.signature(resolved).parameters.values():
        if param.annotation == inspect.Parameter.empty:
            raise ValueError(
                "input schema function's signature should be type annotated."
            )
    return resolved
|
||||
|
||||
|
||||
class SimpleSchemaIngress:
    """Ingress base class that routes GET/POST "/" to ``predict``.

    Each instance builds its own FastAPI app; subclasses supply ``predict``.
    """

    def __init__(self, input_schema: Optional[Union[str, InputSchemaFn]] = None):
        """Create a FastAPI endpoint annotated with input_schema dependency.

        Args:
            input_schema(str, InputSchemaFn, None): The FastAPI input conversion
              function. By default, Serve will directly pass in the request object
              starlette.requests.Request. You can pass in any FastAPI dependency
              resolver. When you pass in a string, Serve will import it.
              Please refer to Serve HTTP adapter documentation to learn more.
        """
        schema_fn = load_input_schema(input_schema)
        self.app = FastAPI()

        # The same handler serves both verbs; the schema function is injected
        # as a FastAPI dependency and its result is forwarded to predict().
        @self.app.get("/")
        @self.app.post("/")
        async def handle_request(inp=Depends(schema_fn)):
            return await self.predict(inp)

    @abstractmethod
    async def predict(self, inp):
        """Handle one already-converted input; must be overridden."""
        raise NotImplementedError()

    async def __call__(self, request: starlette.requests.Request):
        """Run the request through the per-instance FastAPI app."""
        # NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
        # generate FastAPI on the fly, we should find a way to unify the two.
        asgi_sender = ASGIHTTPSender()
        await self.app(request.scope, receive=request.receive, send=asgi_sender)
        return asgi_sender.build_asgi_response()
|
||||
|
||||
|
||||
@serve.deployment(route_prefix="/")
class DAGDriver(SimpleSchemaIngress):
    """Deployment that forwards each converted input to a DAG handle."""

    def __init__(
        self,
        dag_handle: RayServeDAGHandle,
        *,
        input_schema: Optional[Union[str, Callable]] = None,
    ):
        """Store the DAG handle, then let the base class build the HTTP app.

        Args:
            dag_handle: Handle whose ``remote`` method receives each input.
            input_schema: Passed through unchanged to ``SimpleSchemaIngress``.
        """
        self.dag_handle = dag_handle
        super().__init__(input_schema)

    async def predict(self, inp):
        """Perform inference directly without HTTP."""
        dag_result = await self.dag_handle.remote(inp)
        return dag_result
|
|
@ -7,7 +7,6 @@ import numpy as np
|
|||
|
||||
from ray.serve.utils import require_packages
|
||||
from ray.ml.predictor import DataBatchType
|
||||
import starlette.requests
|
||||
|
||||
|
||||
_1DArray = List[float]
|
||||
|
@ -51,14 +50,6 @@ def array_to_databatch(payload: NdArray) -> DataBatchType:
|
|||
return arr
|
||||
|
||||
|
||||
def starlette_request(
    request: starlette.requests.Request,
) -> starlette.requests.Request:
    """Identity adapter: hand back the incoming request object unchanged."""
    # NOTE(simon): This adapter is used for ease of getting started.
    return request
|
||||
|
||||
|
||||
@require_packages(["PIL"])
|
||||
def image_to_databatch(img: bytes = File(...)) -> DataBatchType:
|
||||
"""Accepts a PIL-readable file from an HTTP form and converts
|
||||
|
|
21
python/ray/serve/ingress.py
Normal file
21
python/ray/serve/ingress.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
import starlette
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.api import DAGHandle
|
||||
|
||||
|
||||
@serve.deployment
class DAGDriver:
    """Placeholder driver deployment; neither method is implemented yet."""

    def __init__(
        self,
        dag_handle: DAGHandle,
        *,
        input_schema: Optional[Union[str, Callable]] = None,
    ):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    async def __call__(self, request: starlette.requests.Request):
        """Parse input schema and pass the result to the DAG handle.

        Not implemented; always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
|
|
@ -1,11 +1,31 @@
|
|||
from typing import Dict, Optional, Type, Union
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
import starlette.requests
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from ray._private.utils import import_attr
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.predictor import Predictor
|
||||
from ray.serve.drivers import InputSchemaFn, SimpleSchemaIngress
|
||||
import ray
|
||||
from ray.ml.predictor import Predictor, DataBatchType
|
||||
from ray.serve.http_util import ASGIHTTPSender
|
||||
from ray import serve
|
||||
import ray
|
||||
|
||||
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.array_to_databatch"
|
||||
InputSchemaFn = Callable[[Any], DataBatchType]
|
||||
|
||||
|
||||
def _load_input_schema(
    input_schema: Optional[Union[str, InputSchemaFn]]
) -> InputSchemaFn:
    """Resolve the input schema to a function and validate it.

    Args:
        input_schema: ``None`` (fall back to ``DEFAULT_INPUT_SCHEMA``), a
            dotted import path that is loaded with ``import_attr``, or the
            function itself.

    Returns:
        The resolved function.

    Raises:
        AssertionError: If the resolved object is not a plain function
            (``inspect.isfunction``).
    """
    if input_schema is None:
        input_schema = DEFAULT_INPUT_SCHEMA

    if isinstance(input_schema, str):
        input_schema = import_attr(input_schema)

    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped when Python runs with -O, which would let an invalid schema
    # through. The exception type and message are kept so existing callers
    # see the same error.
    if not inspect.isfunction(input_schema):
        raise AssertionError("input schema must be a callable function.")
    return input_schema
|
||||
|
||||
|
||||
def _load_checkpoint(
|
||||
|
@ -38,14 +58,12 @@ def _load_predictor_cls(
|
|||
return predictor_cls
|
||||
|
||||
|
||||
class ModelWrapper(SimpleSchemaIngress):
|
||||
class ModelWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
predictor_cls: Union[str, Type[Predictor]],
|
||||
checkpoint: Union[Checkpoint, Dict],
|
||||
input_schema: Union[
|
||||
str, InputSchemaFn
|
||||
] = "ray.serve.http_adapters.array_to_databatch",
|
||||
input_schema: Optional[Union[str, InputSchemaFn]] = None,
|
||||
batching_params: Optional[Union[Dict[str, int], bool]] = None,
|
||||
):
|
||||
"""Serve any Ray ML predictor from checkpoint.
|
||||
|
@ -73,6 +91,7 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
checkpoint = _load_checkpoint(checkpoint)
|
||||
|
||||
self.model = predictor_cls.from_checkpoint(checkpoint)
|
||||
self.app = FastAPI()
|
||||
|
||||
# Configure Batching
|
||||
if batching_params is False:
|
||||
|
@ -91,8 +110,21 @@ class ModelWrapper(SimpleSchemaIngress):
|
|||
|
||||
self.batched_predict = batched_predict
|
||||
|
||||
super().__init__(input_schema)
|
||||
# Configure Input Schema
|
||||
input_schema = _load_input_schema(input_schema)
|
||||
|
||||
@self.app.get("/")
|
||||
@self.app.post("/")
|
||||
async def handle_request(inp=Depends(input_schema)):
|
||||
return await batched_predict(inp)
|
||||
|
||||
async def __call__(self, request: starlette.requests.Request):
|
||||
# NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
|
||||
# generate FastAPI on the fly, we should find a way to unify the two.
|
||||
sender = ASGIHTTPSender()
|
||||
await self.app(request.scope, receive=request.receive, send=sender)
|
||||
return sender.build_asgi_response()
|
||||
|
||||
async def predict(self, inp):
|
||||
"""Perform inference directly without HTTP."""
|
||||
"""Performing inference directly without HTTP."""
|
||||
return await self.batched_predict(inp)
|
||||
|
|
|
@ -1,19 +1,16 @@
|
|||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import TypeVar
|
||||
import requests
|
||||
from typing import TypeVar, Any
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.api import RayServeDAGHandle
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
from ray.serve.drivers import DAGDriver
|
||||
import starlette.requests
|
||||
|
||||
|
||||
RayHandleLike = TypeVar("RayHandleLike")
|
||||
NESTED_HANDLE_KEY = "nested_handle"
|
||||
|
@ -104,6 +101,16 @@ class Adder:
|
|||
__call__ = forward
|
||||
|
||||
|
||||
@serve.deployment
class Driver:
    """Test driver deployment that logs its input and forwards it to a DAG."""

    def __init__(self, dag: RayServeDAGHandle):
        """Keep the DAG handle for later calls."""
        self.dag = dag

    async def __call__(self, inp: Any) -> Any:
        """Print the input, then return the awaited result of the DAG call."""
        print(f"Driver got {inp}")
        dag_output = await self.dag.remote(inp)
        return dag_output
|
||||
|
||||
|
||||
@serve.deployment
|
||||
class NoargDriver:
|
||||
def __init__(self, dag: RayServeDAGHandle):
|
||||
|
@ -122,17 +129,13 @@ def test_single_func_no_input(serve_instance):
|
|||
assert requests.get("http://127.0.0.1:8000/").text == "hello"
|
||||
|
||||
|
||||
async def json_resolver(request: starlette.requests.Request):
    """Return the awaited result of ``request.json()``."""
    payload = await request.json()
    return payload
|
||||
|
||||
|
||||
def test_single_func_deployment_dag(serve_instance):
|
||||
with InputNode() as dag_input:
|
||||
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(dag)
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote([1, 2])) == 4
|
||||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
|
||||
assert ray.get(handle.remote([1, 2])) == 4
|
||||
# TODO (simon): ModelWrapper and HTTP adapter ?
|
||||
|
||||
|
||||
def test_chained_function(serve_instance):
|
||||
|
@ -151,19 +154,17 @@ def test_chained_function(serve_instance):
|
|||
with pytest.raises(ValueError, match="Please provide a driver class"):
|
||||
_ = serve.run(serve_dag)
|
||||
|
||||
handle = serve.run(DAGDriver.bind(serve_dag, input_schema=json_resolver))
|
||||
assert ray.get(handle.predict.remote(2)) == 6 # 2 + 2*2
|
||||
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
|
||||
handle = serve.run(Driver.bind(serve_dag))
|
||||
assert ray.get(handle.remote(2)) == 6 # 2 + 2*2
|
||||
|
||||
|
||||
def test_simple_class_with_class_method(serve_instance):
|
||||
with InputNode() as dag_input:
|
||||
model = Model.bind(2, ratio=0.3)
|
||||
dag = model.forward.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(dag)
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 0.6
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
|
||||
assert ray.get(handle.remote(1)) == 0.6
|
||||
|
||||
|
||||
def test_func_class_with_class_method(serve_instance):
|
||||
|
@ -173,11 +174,11 @@ def test_func_class_with_class_method(serve_instance):
|
|||
m1_output = m1.forward.bind(dag_input[0])
|
||||
m2_output = m2.forward.bind(dag_input[1])
|
||||
combine_output = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(combine_output)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote([1, 2, 3])) == 8
|
||||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
|
||||
assert ray.get(handle.remote([1, 2, 3])) == 8
|
||||
# TODO (simon): ModelWrapper and HTTP adapter ?
|
||||
|
||||
|
||||
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
||||
|
@ -186,11 +187,10 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
|||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2=m2)
|
||||
combine_output = combine.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(combine_output)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 5
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
|
||||
assert ray.get(handle.remote(1)) == 5
|
||||
|
||||
|
||||
def test_shared_deployment_handle(serve_instance):
|
||||
|
@ -198,11 +198,10 @@ def test_shared_deployment_handle(serve_instance):
|
|||
m = Model.bind(2)
|
||||
combine = Combine.bind(m, m2=m)
|
||||
combine_output = combine.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(combine_output)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 4
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
|
||||
assert ray.get(handle.remote(1)) == 4
|
||||
|
||||
|
||||
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
|
||||
|
@ -211,11 +210,10 @@ def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
|
|||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
output = combine.bind(dag_input)
|
||||
serve_dag = DAGDriver.bind(output, input_schema=json_resolver)
|
||||
serve_dag = Driver.bind(output)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
assert ray.get(handle.predict.remote(1)) == 5
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
|
||||
assert ray.get(handle.remote(1)) == 5
|
||||
|
||||
|
||||
def test_class_factory(serve_instance):
|
||||
|
@ -250,10 +248,9 @@ def test_single_node_driver_sucess(serve_instance):
|
|||
with InputNode() as input_node:
|
||||
out = m1.forward.bind(input_node)
|
||||
out = m2.forward.bind(out)
|
||||
driver = DAGDriver.bind(out, input_schema=json_resolver)
|
||||
driver = Driver.bind(out)
|
||||
handle = serve.run(driver)
|
||||
assert ray.get(handle.predict.remote(39)) == 42
|
||||
assert requests.post("http://127.0.0.1:8000/", json=39).json() == 42
|
||||
assert ray.get(handle.remote(39)) == 42
|
||||
|
||||
|
||||
def test_options_and_names(serve_instance):
|
||||
|
@ -283,10 +280,9 @@ class TakeHandle:
|
|||
def test_passing_handle(serve_instance):
|
||||
child = Adder.bind(1)
|
||||
parent = TakeHandle.bind(child)
|
||||
driver = DAGDriver.bind(parent, input_schema=json_resolver)
|
||||
driver = Driver.bind(parent)
|
||||
handle = serve.run(driver)
|
||||
assert ray.get(handle.predict.remote(1)) == 2
|
||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
|
||||
assert ray.get(handle.remote(1)) == 2
|
||||
|
||||
|
||||
def test_passing_handle_in_obj(serve_instance):
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
import sys
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import starlette.requests
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_input_schema
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray import serve
|
||||
import ray
|
||||
|
||||
|
||||
def my_resolver(a: int):
    """Return the argument unchanged.

    Kept at module level so it can be loaded by its dotted import path.
    """
    return a
|
||||
|
||||
|
||||
def test_loading_check():
    """Check load_input_schema's rejection paths and import-by-string path."""

    def unannotated(a):
        return a

    # A non-function object is refused.
    with pytest.raises(ValueError, match="callable"):
        load_input_schema(["not function"])

    # A function whose parameter has no annotation is refused.
    with pytest.raises(ValueError, match="type annotated"):
        load_input_schema(unannotated)

    # A dotted path resolves to the module-level function it names.
    imported = load_input_schema("ray.serve.tests.test_pipeline_driver.my_resolver")
    assert imported == my_resolver
|
||||
|
||||
|
||||
def test_unit_schema_injection():
    """Exercise a SimpleSchemaIngress subclass through Starlette's TestClient.

    Asserts that a missing query parameter gives 422, that a supplied one is
    echoed back, and that the OpenAPI document lists the parameter.
    """

    class Impl(SimpleSchemaIngress):
        async def predict(self, inp):
            return inp

    async def resolver(my_custom_param: int):
        return my_custom_param

    client = TestClient(Impl(input_schema=resolver).app)

    # No query parameter supplied.
    missing = client.post("/")
    assert missing.status_code == 422

    # Parameter supplied: the value is echoed back in the body.
    echoed = client.post("/?my_custom_param=1")
    assert echoed.status_code == 200
    assert echoed.text == "1"

    # The resolver's parameter appears in the generated OpenAPI schema.
    expected_param = {
        "required": True,
        "schema": {"title": "My Custom Param", "type": "integer"},
        "name": "my_custom_param",
        "in": "query",
    }
    openapi = client.get("/openapi.json")
    assert openapi.status_code == 200
    assert openapi.json()["paths"]["/"]["get"]["parameters"][0] == expected_param
|
||||
|
||||
|
||||
@serve.deployment
def echo(inp):
    """Return the input, or the marker string for a raw Starlette request."""
    # FastAPI can't handle this.
    return "starlette!" if isinstance(inp, starlette.requests.Request) else inp
|
||||
|
||||
|
||||
def test_dag_driver_default(serve_instance):
    """Run DAGDriver with no input_schema and check handle and HTTP paths."""
    with InputNode() as inp:
        dag = echo.bind(inp)

    driver = DAGDriver.bind(dag)
    handle = serve.run(driver)

    # Direct handle call: the value goes straight through echo.
    direct_result = ray.get(handle.predict.remote(42))
    assert direct_result == 42

    # HTTP call: echo reports that it received a Starlette request.
    resp = requests.post("http://127.0.0.1:8000/", json={"array": [1]})
    print(resp.text)

    resp.raise_for_status()
    assert resp.json() == "starlette!"
|
||||
|
||||
|
||||
async def resolver(my_custom_param: int):
    """Return ``my_custom_param`` unchanged."""
    return my_custom_param
|
||||
|
||||
|
||||
def test_dag_driver_custom_schema(serve_instance):
    """Run DAGDriver with ``resolver`` as input_schema over handle and HTTP."""
    with InputNode() as inp:
        dag = echo.bind(inp)

    driver = DAGDriver.bind(dag, input_schema=resolver)
    handle = serve.run(driver)

    # Direct handle call: the value goes straight through echo.
    direct_result = ray.get(handle.predict.remote(42))
    assert direct_result == 42

    # HTTP call: the query parameter value comes back as the JSON body.
    resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
    print(resp.text)
    resp.raise_for_status()
    assert resp.json() == 100
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
Loading…
Add table
Reference in a new issue