Revert "[Serve] Implement Default DAGDriver (#23301)" (#23358)

This reverts commit 91a1c3411f.
This commit is contained in:
Stephanie Wang 2022-03-21 08:54:52 -07:00 committed by GitHub
parent 81dcf9ff35
commit e507aa5758
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 246 deletions

View file

@ -331,14 +331,6 @@ py_test(
deps = [":serve_lib"],
)
py_test(
name = "test_pipeline_driver",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)
py_test(
name = "test_pipeline_dag",
size = "medium",

View file

@ -1,82 +0,0 @@
import inspect
from abc import abstractmethod
from typing import Any, Callable, Optional, Union
import starlette
from fastapi import Depends, FastAPI
from ray._private.utils import import_attr
from ray.serve.api import RayServeDAGHandle
from ray.serve.http_util import ASGIHTTPSender
from ray import serve
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.starlette_request"
InputSchemaFn = Callable[[Any], Any]
def load_input_schema(
input_schema: Optional[Union[str, InputSchemaFn]]
) -> InputSchemaFn:
if input_schema is None:
input_schema = DEFAULT_INPUT_SCHEMA
if isinstance(input_schema, str):
input_schema = import_attr(input_schema)
if not inspect.isfunction(input_schema):
raise ValueError("input schema must be a callable function.")
if any(
param.annotation == inspect.Parameter.empty
for param in inspect.signature(input_schema).parameters.values()
):
raise ValueError("input schema function's signature should be type annotated.")
return input_schema
class SimpleSchemaIngress:
def __init__(self, input_schema: Optional[Union[str, InputSchemaFn]] = None):
"""Create a FastAPI endpoint annotated with input_schema dependency.
Args:
input_schema(str, InputSchemaFn, None): The FastAPI input conversion
function. By default, Serve will directly pass in the request object
starlette.requests.Request. You can pass in any FastAPI dependency
resolver. When you pass in a string, Serve will import it.
Please refer to Serve HTTP adatper documentation to learn more.
"""
input_schema = load_input_schema(input_schema)
self.app = FastAPI()
@self.app.get("/")
@self.app.post("/")
async def handle_request(inp=Depends(input_schema)):
resp = await self.predict(inp)
return resp
@abstractmethod
async def predict(self, inp):
raise NotImplementedError()
async def __call__(self, request: starlette.requests.Request):
# NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
# generate FastAPI on the fly, we should find a way to unify the two.
sender = ASGIHTTPSender()
await self.app(request.scope, receive=request.receive, send=sender)
return sender.build_asgi_response()
@serve.deployment(route_prefix="/")
class DAGDriver(SimpleSchemaIngress):
def __init__(
self,
dag_handle: RayServeDAGHandle,
*,
input_schema: Optional[Union[str, Callable]] = None,
):
self.dag_handle = dag_handle
super().__init__(input_schema)
async def predict(self, inp):
"""Perform inference directly without HTTP."""
return await self.dag_handle.remote(inp)

View file

@ -7,7 +7,6 @@ import numpy as np
from ray.serve.utils import require_packages
from ray.ml.predictor import DataBatchType
import starlette.requests
_1DArray = List[float]
@ -51,14 +50,6 @@ def array_to_databatch(payload: NdArray) -> DataBatchType:
return arr
def starlette_request(
request: starlette.requests.Request,
) -> starlette.requests.Request:
"""Returns the raw request object."""
# NOTE(simon): This adapter is used for ease of getting started.
return request
@require_packages(["PIL"])
def image_to_databatch(img: bytes = File(...)) -> DataBatchType:
"""Accepts a PIL-readable file from an HTTP form and converts

View file

@ -0,0 +1,21 @@
from typing import Callable, Optional, Union
import starlette
from ray import serve
from ray.serve.api import DAGHandle
@serve.deployment
class DAGDriver:
def __init__(
self,
dag_handle: DAGHandle,
*,
input_schema: Optional[Union[str, Callable]] = None,
):
raise NotImplementedError()
async def __call__(self, request: starlette.requests.Request):
"""Parse input schema and pass the result to the DAG handle."""
raise NotImplementedError()

View file

@ -1,11 +1,31 @@
from typing import Dict, Optional, Type, Union
import inspect
from typing import Any, Callable, Dict, Optional, Type, Union
import starlette.requests
from fastapi import Depends, FastAPI
from ray._private.utils import import_attr
from ray.ml.checkpoint import Checkpoint
from ray.ml.predictor import Predictor
from ray.serve.drivers import InputSchemaFn, SimpleSchemaIngress
import ray
from ray.ml.predictor import Predictor, DataBatchType
from ray.serve.http_util import ASGIHTTPSender
from ray import serve
import ray
DEFAULT_INPUT_SCHEMA = "ray.serve.http_adapters.array_to_databatch"
InputSchemaFn = Callable[[Any], DataBatchType]
def _load_input_schema(
input_schema: Optional[Union[str, InputSchemaFn]]
) -> InputSchemaFn:
if input_schema is None:
input_schema = DEFAULT_INPUT_SCHEMA
if isinstance(input_schema, str):
input_schema = import_attr(input_schema)
assert inspect.isfunction(input_schema), "input schema must be a callable function."
return input_schema
def _load_checkpoint(
@ -38,14 +58,12 @@ def _load_predictor_cls(
return predictor_cls
class ModelWrapper(SimpleSchemaIngress):
class ModelWrapper:
def __init__(
self,
predictor_cls: Union[str, Type[Predictor]],
checkpoint: Union[Checkpoint, Dict],
input_schema: Union[
str, InputSchemaFn
] = "ray.serve.http_adapters.array_to_databatch",
input_schema: Optional[Union[str, InputSchemaFn]] = None,
batching_params: Optional[Union[Dict[str, int], bool]] = None,
):
"""Serve any Ray ML predictor from checkpoint.
@ -73,6 +91,7 @@ class ModelWrapper(SimpleSchemaIngress):
checkpoint = _load_checkpoint(checkpoint)
self.model = predictor_cls.from_checkpoint(checkpoint)
self.app = FastAPI()
# Configure Batching
if batching_params is False:
@ -91,8 +110,21 @@ class ModelWrapper(SimpleSchemaIngress):
self.batched_predict = batched_predict
super().__init__(input_schema)
# Configure Input Schema
input_schema = _load_input_schema(input_schema)
@self.app.get("/")
@self.app.post("/")
async def handle_request(inp=Depends(input_schema)):
return await batched_predict(inp)
async def __call__(self, request: starlette.requests.Request):
# NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
# generate FastAPI on the fly, we should find a way to unify the two.
sender = ASGIHTTPSender()
await self.app(request.scope, receive=request.receive, send=sender)
return sender.build_asgi_response()
async def predict(self, inp):
"""Perform inference directly without HTTP."""
"""Performing inference directly without HTTP."""
return await self.batched_predict(inp)

View file

@ -1,19 +1,16 @@
import pytest
import os
import sys
from typing import TypeVar
import requests
from typing import TypeVar, Any
import numpy as np
import requests
import ray
from ray import serve
from ray.serve.api import RayServeDAGHandle
from ray.experimental.dag.input_node import InputNode
from ray.serve.pipeline.api import build as pipeline_build
from ray.serve.drivers import DAGDriver
import starlette.requests
RayHandleLike = TypeVar("RayHandleLike")
NESTED_HANDLE_KEY = "nested_handle"
@ -104,6 +101,16 @@ class Adder:
__call__ = forward
@serve.deployment
class Driver:
def __init__(self, dag: RayServeDAGHandle):
self.dag = dag
async def __call__(self, inp: Any) -> Any:
print(f"Driver got {inp}")
return await self.dag.remote(inp)
@serve.deployment
class NoargDriver:
def __init__(self, dag: RayServeDAGHandle):
@ -122,17 +129,13 @@ def test_single_func_no_input(serve_instance):
assert requests.get("http://127.0.0.1:8000/").text == "hello"
async def json_resolver(request: starlette.requests.Request):
return await request.json()
def test_single_func_deployment_dag(serve_instance):
with InputNode() as dag_input:
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
serve_dag = Driver.bind(dag)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote([1, 2])) == 4
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
assert ray.get(handle.remote([1, 2])) == 4
# TODO (simon): ModelWrapper and HTTP adapter ?
def test_chained_function(serve_instance):
@ -151,19 +154,17 @@ def test_chained_function(serve_instance):
with pytest.raises(ValueError, match="Please provide a driver class"):
_ = serve.run(serve_dag)
handle = serve.run(DAGDriver.bind(serve_dag, input_schema=json_resolver))
assert ray.get(handle.predict.remote(2)) == 6 # 2 + 2*2
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
handle = serve.run(Driver.bind(serve_dag))
assert ray.get(handle.remote(2)) == 6 # 2 + 2*2
def test_simple_class_with_class_method(serve_instance):
with InputNode() as dag_input:
model = Model.bind(2, ratio=0.3)
dag = model.forward.bind(dag_input)
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
serve_dag = Driver.bind(dag)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 0.6
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
assert ray.get(handle.remote(1)) == 0.6
def test_func_class_with_class_method(serve_instance):
@ -173,11 +174,11 @@ def test_func_class_with_class_method(serve_instance):
m1_output = m1.forward.bind(dag_input[0])
m2_output = m2.forward.bind(dag_input[1])
combine_output = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = Driver.bind(combine_output)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote([1, 2, 3])) == 8
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
assert ray.get(handle.remote([1, 2, 3])) == 8
# TODO (simon): ModelWrapper and HTTP adapter ?
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
@ -186,11 +187,10 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
m2 = Model.bind(3)
combine = Combine.bind(m1, m2=m2)
combine_output = combine.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = Driver.bind(combine_output)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 5
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
assert ray.get(handle.remote(1)) == 5
def test_shared_deployment_handle(serve_instance):
@ -198,11 +198,10 @@ def test_shared_deployment_handle(serve_instance):
m = Model.bind(2)
combine = Combine.bind(m, m2=m)
combine_output = combine.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)
serve_dag = Driver.bind(combine_output)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 4
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
assert ray.get(handle.remote(1)) == 4
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
@ -211,11 +210,10 @@ def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
m2 = Model.bind(3)
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
output = combine.bind(dag_input)
serve_dag = DAGDriver.bind(output, input_schema=json_resolver)
serve_dag = Driver.bind(output)
handle = serve.run(serve_dag)
assert ray.get(handle.predict.remote(1)) == 5
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
assert ray.get(handle.remote(1)) == 5
def test_class_factory(serve_instance):
@ -250,10 +248,9 @@ def test_single_node_driver_sucess(serve_instance):
with InputNode() as input_node:
out = m1.forward.bind(input_node)
out = m2.forward.bind(out)
driver = DAGDriver.bind(out, input_schema=json_resolver)
driver = Driver.bind(out)
handle = serve.run(driver)
assert ray.get(handle.predict.remote(39)) == 42
assert requests.post("http://127.0.0.1:8000/", json=39).json() == 42
assert ray.get(handle.remote(39)) == 42
def test_options_and_names(serve_instance):
@ -283,10 +280,9 @@ class TakeHandle:
def test_passing_handle(serve_instance):
child = Adder.bind(1)
parent = TakeHandle.bind(child)
driver = DAGDriver.bind(parent, input_schema=json_resolver)
driver = Driver.bind(parent)
handle = serve.run(driver)
assert ray.get(handle.predict.remote(1)) == 2
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
assert ray.get(handle.remote(1)) == 2
def test_passing_handle_in_obj(serve_instance):

View file

@ -1,101 +0,0 @@
import sys
import pytest
import requests
import starlette.requests
from starlette.testclient import TestClient
from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, load_input_schema
from ray.experimental.dag.input_node import InputNode
from ray import serve
import ray
def my_resolver(a: int):
return a
def test_loading_check():
with pytest.raises(ValueError, match="callable"):
load_input_schema(["not function"])
with pytest.raises(ValueError, match="type annotated"):
def func(a):
return a
load_input_schema(func)
assert (
load_input_schema("ray.serve.tests.test_pipeline_driver.my_resolver")
== my_resolver
)
def test_unit_schema_injection():
class Impl(SimpleSchemaIngress):
async def predict(self, inp):
return inp
async def resolver(my_custom_param: int):
return my_custom_param
server = Impl(input_schema=resolver)
client = TestClient(server.app)
response = client.post("/")
assert response.status_code == 422
response = client.post("/?my_custom_param=1")
assert response.status_code == 200
assert response.text == "1"
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json()["paths"]["/"]["get"]["parameters"][0] == {
"required": True,
"schema": {"title": "My Custom Param", "type": "integer"},
"name": "my_custom_param",
"in": "query",
}
@serve.deployment
def echo(inp):
# FastAPI can't handle this.
if isinstance(inp, starlette.requests.Request):
return "starlette!"
return inp
def test_dag_driver_default(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)
handle = serve.run(DAGDriver.bind(dag))
assert ray.get(handle.predict.remote(42)) == 42
resp = requests.post("http://127.0.0.1:8000/", json={"array": [1]})
print(resp.text)
resp.raise_for_status()
assert resp.json() == "starlette!"
async def resolver(my_custom_param: int):
return my_custom_param
def test_dag_driver_custom_schema(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)
handle = serve.run(DAGDriver.bind(dag, input_schema=resolver))
assert ray.get(handle.predict.remote(42)) == 42
resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
print(resp.text)
resp.raise_for_status()
assert resp.json() == 100
if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))