diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD
index 179564fcb..e8f3949cd 100644
--- a/python/ray/serve/BUILD
+++ b/python/ray/serve/BUILD
@@ -331,14 +331,6 @@ py_test(
     deps = [":serve_lib"],
 )

-py_test(
-    name = "test_pipeline_driver",
-    size = "small",
-    srcs = serve_tests_srcs,
-    tags = ["exclusive", "team:serve"],
-    deps = [":serve_lib"],
-)
-
 py_test(
     name = "test_pipeline_dag",
     size = "medium",
diff --git a/python/ray/serve/drivers.py b/python/ray/serve/drivers.py
deleted file mode 100644
index 2615ea5d5..000000000
--- a/python/ray/serve/drivers.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import inspect
-from abc import abstractmethod
-from typing import Any, Callable, Optional, Union
-
-import starlette
-from fastapi import Depends, FastAPI
-
-from ray._private.utils import import_attr
-from ray.serve.api import RayServeDAGHandle
-from ray.serve.http_util import ASGIHTTPSender
-from ray import serve
-
-DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.starlette_request"
-InputSchemaFn = Callable[[Any], Any]
-
-
-def load_input_schema(
-    input_schema: Optional[Union[str, InputSchemaFn]]
-) -> InputSchemaFn:
-    if input_schema is None:
-        input_schema = DEFAULT_INPUT_SCHEMA
-
-    if isinstance(input_schema, str):
-        input_schema = import_attr(input_schema)
-
-    if not inspect.isfunction(input_schema):
-        raise ValueError("input schema must be a callable function.")
-
-    if any(
-        param.annotation == inspect.Parameter.empty
-        for param in inspect.signature(input_schema).parameters.values()
-    ):
-        raise ValueError("input schema function's signature should be type annotated.")
-    return input_schema
-
-
-class SimpleSchemaIngress:
-    def __init__(self, input_schema: Optional[Union[str, InputSchemaFn]] = None):
-        """Create a FastAPI endpoint annotated with input_schema dependency.
-
-        Args:
-            input_schema(str, InputSchemaFn, None): The FastAPI input conversion
-              function. By default, Serve will directly pass in the request object
-              starlette.requests.Request. You can pass in any FastAPI dependency
-              resolver. When you pass in a string, Serve will import it.
-              Please refer to Serve HTTP adatper documentation to learn more.
-        """
-        input_schema = load_input_schema(input_schema)
-        self.app = FastAPI()
-
-        @self.app.get("/")
-        @self.app.post("/")
-        async def handle_request(inp=Depends(input_schema)):
-            resp = await self.predict(inp)
-            return resp
-
-    @abstractmethod
-    async def predict(self, inp):
-        raise NotImplementedError()
-
-    async def __call__(self, request: starlette.requests.Request):
-        # NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
-        # generate FastAPI on the fly, we should find a way to unify the two.
-        sender = ASGIHTTPSender()
-        await self.app(request.scope, receive=request.receive, send=sender)
-        return sender.build_asgi_response()
-
-
-@serve.deployment(route_prefix="/")
-class DAGDriver(SimpleSchemaIngress):
-    def __init__(
-        self,
-        dag_handle: RayServeDAGHandle,
-        *,
-        input_schema: Optional[Union[str, Callable]] = None,
-    ):
-        self.dag_handle = dag_handle
-        super().__init__(input_schema)
-
-    async def predict(self, inp):
-        """Perform inference directly without HTTP."""
-        return await self.dag_handle.remote(inp)
diff --git a/python/ray/serve/http_adapters.py b/python/ray/serve/http_adapters.py
index 359aaf457..d9d250c3d 100644
--- a/python/ray/serve/http_adapters.py
+++ b/python/ray/serve/http_adapters.py
@@ -7,7 +7,6 @@ import numpy as np

 from ray.serve.utils import require_packages
 from ray.ml.predictor import DataBatchType
-import starlette.requests


 _1DArray = List[float]
@@ -51,14 +50,6 @@ def array_to_databatch(payload: NdArray) -> DataBatchType:
     return arr


-def starlette_request(
-    request: starlette.requests.Request,
-) -> starlette.requests.Request:
-    """Returns the raw request object."""
-    # NOTE(simon): This adapter is used for ease of getting started.
-    return request
-
-
 @require_packages(["PIL"])
 def image_to_databatch(img: bytes = File(...)) -> DataBatchType:
     """Accepts a PIL-readable file from an HTTP form and converts
diff --git a/python/ray/serve/ingress.py b/python/ray/serve/ingress.py
new file mode 100644
index 000000000..8d1724ac4
--- /dev/null
+++ b/python/ray/serve/ingress.py
@@ -0,0 +1,21 @@
+from typing import Callable, Optional, Union
+
+import starlette
+
+from ray import serve
+from ray.serve.api import DAGHandle
+
+
+@serve.deployment
+class DAGDriver:
+    def __init__(
+        self,
+        dag_handle: DAGHandle,
+        *,
+        input_schema: Optional[Union[str, Callable]] = None,
+    ):
+        raise NotImplementedError()
+
+    async def __call__(self, request: starlette.requests.Request):
+        """Parse input schema and pass the result to the DAG handle."""
+        raise NotImplementedError()
diff --git a/python/ray/serve/model_wrappers.py b/python/ray/serve/model_wrappers.py
index 4e2b4263b..13fd5e989 100644
--- a/python/ray/serve/model_wrappers.py
+++ b/python/ray/serve/model_wrappers.py
@@ -1,11 +1,31 @@
-from typing import Dict, Optional, Type, Union
+import inspect
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+import starlette.requests
+from fastapi import Depends, FastAPI

 from ray._private.utils import import_attr
 from ray.ml.checkpoint import Checkpoint
-from ray.ml.predictor import Predictor
-from ray.serve.drivers import InputSchemaFn, SimpleSchemaIngress
-import ray
+from ray.ml.predictor import Predictor, DataBatchType
+from ray.serve.http_util import ASGIHTTPSender
 from ray import serve
+import ray
+
+DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.array_to_databatch"
+InputSchemaFn = Callable[[Any], DataBatchType]
+
+
+def _load_input_schema(
+    input_schema: Optional[Union[str, InputSchemaFn]]
+) -> InputSchemaFn:
+    if input_schema is None:
+        input_schema = DEFAULT_INPUT_SCHEMA
+
+    if isinstance(input_schema, str):
+        input_schema = import_attr(input_schema)
+
+    assert inspect.isfunction(input_schema), "input schema must be a callable function."
+    return input_schema


 def _load_checkpoint(
@@ -38,14 +58,12 @@ def _load_predictor_cls(
     return predictor_cls


-class ModelWrapper(SimpleSchemaIngress):
+class ModelWrapper:
     def __init__(
         self,
         predictor_cls: Union[str, Type[Predictor]],
         checkpoint: Union[Checkpoint, Dict],
-        input_schema: Union[
-            str, InputSchemaFn
-        ] = "ray.serve.http_adapters.array_to_databatch",
+        input_schema: Optional[Union[str, InputSchemaFn]] = None,
         batching_params: Optional[Union[Dict[str, int], bool]] = None,
     ):
         """Serve any Ray ML predictor from checkpoint.
@@ -73,6 +91,7 @@ class ModelWrapper(SimpleSchemaIngress):

         checkpoint = _load_checkpoint(checkpoint)
         self.model = predictor_cls.from_checkpoint(checkpoint)
+        self.app = FastAPI()

         # Configure Batching
         if batching_params is False:
@@ -91,8 +110,21 @@ class ModelWrapper(SimpleSchemaIngress):

         self.batched_predict = batched_predict

-        super().__init__(input_schema)
+        # Configure Input Schema
+        input_schema = _load_input_schema(input_schema)
+
+        @self.app.get("/")
+        @self.app.post("/")
+        async def handle_request(inp=Depends(input_schema)):
+            return await batched_predict(inp)
+
+    async def __call__(self, request: starlette.requests.Request):
+        # NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
+        # generate FastAPI on the fly, we should find a way to unify the two.
+        sender = ASGIHTTPSender()
+        await self.app(request.scope, receive=request.receive, send=sender)
+        return sender.build_asgi_response()

     async def predict(self, inp):
-        """Perform inference directly without HTTP."""
+        """Performing inference directly without HTTP."""
         return await self.batched_predict(inp)
diff --git a/python/ray/serve/tests/test_pipeline_dag.py b/python/ray/serve/tests/test_pipeline_dag.py
index 1506b211b..70df0e19a 100644
--- a/python/ray/serve/tests/test_pipeline_dag.py
+++ b/python/ray/serve/tests/test_pipeline_dag.py
@@ -1,19 +1,16 @@
 import pytest
 import os
 import sys
-from typing import TypeVar
+import requests
+from typing import TypeVar, Any

 import numpy as np
-import requests

 import ray
 from ray import serve
 from ray.serve.api import RayServeDAGHandle
 from ray.experimental.dag.input_node import InputNode
 from ray.serve.pipeline.api import build as pipeline_build
-from ray.serve.drivers import DAGDriver
-import starlette.requests
-

 RayHandleLike = TypeVar("RayHandleLike")
 NESTED_HANDLE_KEY = "nested_handle"
@@ -104,6 +101,16 @@ class Adder:
     __call__ = forward


+@serve.deployment
+class Driver:
+    def __init__(self, dag: RayServeDAGHandle):
+        self.dag = dag
+
+    async def __call__(self, inp: Any) -> Any:
+        print(f"Driver got {inp}")
+        return await self.dag.remote(inp)
+
+
 @serve.deployment
 class NoargDriver:
     def __init__(self, dag: RayServeDAGHandle):
@@ -122,17 +129,13 @@ def test_single_func_no_input(serve_instance):
     assert requests.get("http://127.0.0.1:8000/").text == "hello"


-async def json_resolver(request: starlette.requests.Request):
-    return await request.json()
-
-
 def test_single_func_deployment_dag(serve_instance):
     with InputNode() as dag_input:
         dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
-        serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
+        serve_dag = Driver.bind(dag)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote([1, 2])) == 4
-    assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
+    assert ray.get(handle.remote([1, 2])) == 4
+    # TODO (simon): ModelWrapper and HTTP adapter ?


 def test_chained_function(serve_instance):
@@ -151,19 +154,17 @@
     with pytest.raises(ValueError, match="Please provide a driver class"):
         _ = serve.run(serve_dag)

-    handle = serve.run(DAGDriver.bind(serve_dag, input_schema=json_resolver))
-    assert ray.get(handle.predict.remote(2)) == 6  # 2 + 2*2
-    assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
+    handle = serve.run(Driver.bind(serve_dag))
+    assert ray.get(handle.remote(2)) == 6  # 2 + 2*2


 def test_simple_class_with_class_method(serve_instance):
     with InputNode() as dag_input:
         model = Model.bind(2, ratio=0.3)
         dag = model.forward.bind(dag_input)
-        serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
+        serve_dag = Driver.bind(dag)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote(1)) == 0.6
-    assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
+    assert ray.get(handle.remote(1)) == 0.6


 def test_func_class_with_class_method(serve_instance):
@@ -173,11 +174,11 @@
         m1_output = m1.forward.bind(dag_input[0])
         m2_output = m2.forward.bind(dag_input[1])
         combine_output = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
-        serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
+        serve_dag = Driver.bind(combine_output)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote([1, 2, 3])) == 8
-    assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
+    assert ray.get(handle.remote([1, 2, 3])) == 8
+    # TODO (simon): ModelWrapper and HTTP adapter ?


 def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
@@ -186,11 +187,10 @@
         m2 = Model.bind(3)
         combine = Combine.bind(m1, m2=m2)
         combine_output = combine.bind(dag_input)
-        serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
+        serve_dag = Driver.bind(combine_output)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote(1)) == 5
-    assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
+    assert ray.get(handle.remote(1)) == 5


 def test_shared_deployment_handle(serve_instance):
@@ -198,11 +198,10 @@
         m = Model.bind(2)
         combine = Combine.bind(m, m2=m)
         combine_output = combine.bind(dag_input)
-        serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
+        serve_dag = Driver.bind(combine_output)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote(1)) == 4
-    assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
+    assert ray.get(handle.remote(1)) == 4


 def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
@@ -211,11 +210,10 @@
         m2 = Model.bind(3)
         combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
         output = combine.bind(dag_input)
-        serve_dag = DAGDriver.bind(output, input_schema=json_resolver)
+        serve_dag = Driver.bind(output)

     handle = serve.run(serve_dag)
-    assert ray.get(handle.predict.remote(1)) == 5
-    assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
+    assert ray.get(handle.remote(1)) == 5


 def test_class_factory(serve_instance):
@@ -250,10 +248,9 @@ def test_single_node_driver_sucess(serve_instance):
     with InputNode() as input_node:
         out = m1.forward.bind(input_node)
         out = m2.forward.bind(out)
-        driver = DAGDriver.bind(out, input_schema=json_resolver)
+        driver = Driver.bind(out)

     handle = serve.run(driver)
-    assert ray.get(handle.predict.remote(39)) == 42
-    assert requests.post("http://127.0.0.1:8000/", json=39).json() == 42
+    assert ray.get(handle.remote(39)) == 42


 def test_options_and_names(serve_instance):
@@ -283,10 +280,9 @@ class TakeHandle:
 def test_passing_handle(serve_instance):
     child = Adder.bind(1)
     parent = TakeHandle.bind(child)
-    driver = DAGDriver.bind(parent, input_schema=json_resolver)
+    driver = Driver.bind(parent)
     handle = serve.run(driver)
-    assert ray.get(handle.predict.remote(1)) == 2
-    assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
+    assert ray.get(handle.remote(1)) == 2


 def test_passing_handle_in_obj(serve_instance):
diff --git a/python/ray/serve/tests/test_pipeline_driver.py b/python/ray/serve/tests/test_pipeline_driver.py
deleted file mode 100644
index c6b147aba..000000000
--- a/python/ray/serve/tests/test_pipeline_driver.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import sys
-
-import pytest
-import requests
-import starlette.requests
-from starlette.testclient import TestClient
-
-from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_input_schema
-from ray.experimental.dag.input_node import InputNode
-from ray import serve
-import ray
-
-
-def my_resolver(a: int):
-    return a
-
-
-def test_loading_check():
-    with pytest.raises(ValueError, match="callable"):
-        load_input_schema(["not function"])
-    with pytest.raises(ValueError, match="type annotated"):
-
-        def func(a):
-            return a
-
-        load_input_schema(func)
-    assert (
-        load_input_schema("ray.serve.tests.test_pipeline_driver.my_resolver")
-        == my_resolver
-    )
-
-
-def test_unit_schema_injection():
-    class Impl(SimpleSchemaIngress):
-        async def predict(self, inp):
-            return inp
-
-    async def resolver(my_custom_param: int):
-        return my_custom_param
-
-    server = Impl(input_schema=resolver)
-    client = TestClient(server.app)
-
-    response = client.post("/")
-    assert response.status_code == 422
-
-    response = client.post("/?my_custom_param=1")
-    assert response.status_code == 200
-    assert response.text == "1"
-
-    response = client.get("/openapi.json")
-    assert response.status_code == 200
-    assert response.json()["paths"]["/"]["get"]["parameters"][0] == {
-        "required": True,
-        "schema": {"title": "My Custom Param", "type": "integer"},
-        "name": "my_custom_param",
-        "in": "query",
-    }
-
-
-@serve.deployment
-def echo(inp):
-    # FastAPI can't handle this.
-    if isinstance(inp, starlette.requests.Request):
-        return "starlette!"
-    return inp
-
-
-def test_dag_driver_default(serve_instance):
-    with InputNode() as inp:
-        dag = echo.bind(inp)
-
-    handle = serve.run(DAGDriver.bind(dag))
-    assert ray.get(handle.predict.remote(42)) == 42
-
-    resp = requests.post("http://127.0.0.1:8000/", json={"array": [1]})
-    print(resp.text)
-
-    resp.raise_for_status()
-    assert resp.json() == "starlette!"
-
-
-async def resolver(my_custom_param: int):
-    return my_custom_param
-
-
-def test_dag_driver_custom_schema(serve_instance):
-    with InputNode() as inp:
-        dag = echo.bind(inp)
-
-    handle = serve.run(DAGDriver.bind(dag, input_schema=resolver))
-    assert ray.get(handle.predict.remote(42)) == 42
-
-    resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
-    print(resp.text)
-    resp.raise_for_status()
-    assert resp.json() == 100
-
-
-if __name__ == "__main__":
-    sys.exit(pytest.main(["-v", "-s", __file__]))
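
For reference, below is a minimal sketch of the handle-only driver pattern the updated tests now rely on. It mirrors the Driver deployment and test_single_func_deployment_dag from the diff above; the combine function here is a simplified stand-in (the real test also passes kwargs_output), and a running Ray/Serve instance is assumed, as provided by the serve_instance fixture in the tests.

import ray
from ray import serve
from ray.serve.api import RayServeDAGHandle
from ray.experimental.dag.input_node import InputNode


@serve.deployment
def combine(x, y):
    # Simplified stand-in for a DAG node; any deployment function works here.
    return x + y


@serve.deployment
class Driver:
    def __init__(self, dag: RayServeDAGHandle):
        self.dag = dag

    async def __call__(self, inp):
        # Hand the raw input straight to the DAG handle; no HTTP adapter involved.
        return await self.dag.remote(inp)


with InputNode() as dag_input:
    dag = combine.bind(dag_input[0], dag_input[1])

handle = serve.run(Driver.bind(dag))
assert ray.get(handle.remote([1, 2])) == 3  # dag_input[0] + dag_input[1]

HTTP ingress with an input schema is intentionally out of scope for this pattern; the TODO comments in the tests point at ModelWrapper plus the HTTP adapters for that path.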