mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[serve] Implement serve.build
(#23232)
The Serve REST API relies on YAML config files to specify and deploy deployments. This change introduces `serve.build()` and `serve build`, which translate Pipelines to YAML files. Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
This commit is contained in:
parent
be216a0e8c
commit
cf7b4e65c2
10 changed files with 343 additions and 114 deletions
|
@ -1061,6 +1061,7 @@ def import_attr(full_path: str):
|
|||
"""Given a full import path to a module attr, return the imported attr.
|
||||
|
||||
For example, the following are equivalent:
|
||||
MyClass = import_attr("module.submodule:MyClass")
|
||||
MyClass = import_attr("module.submodule.MyClass")
|
||||
from module.submodule import MyClass
|
||||
|
||||
|
@ -1069,9 +1070,19 @@ def import_attr(full_path: str):
|
|||
"""
|
||||
if full_path is None:
|
||||
raise TypeError("import path cannot be None")
|
||||
|
||||
if ":" in full_path:
|
||||
if full_path.count(":") > 1:
|
||||
raise ValueError(
|
||||
f'Got invalid import path "{full_path}". An '
|
||||
"import path may have at most one colon."
|
||||
)
|
||||
module_name, attr_name = full_path.split(":")
|
||||
else:
|
||||
last_period_idx = full_path.rfind(".")
|
||||
attr_name = full_path[last_period_idx + 1 :]
|
||||
module_name = full_path[:last_period_idx]
|
||||
attr_name = full_path[last_period_idx + 1 :]
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, attr_name)
|
||||
|
||||
|
|
|
@ -67,6 +67,7 @@ from ray.serve.utils import (
|
|||
get_current_node_resource_key,
|
||||
get_random_letters,
|
||||
get_deployment_import_path,
|
||||
in_interactive_shell,
|
||||
logger,
|
||||
DEFAULT,
|
||||
)
|
||||
|
@ -1792,9 +1793,7 @@ class Application:
|
|||
Returns:
|
||||
Dict: The Application's deployments formatted in a dictionary.
|
||||
"""
|
||||
return ServeApplicationSchema(
|
||||
deployments=[deployment_to_schema(d) for d in self._deployments.values()]
|
||||
).dict()
|
||||
return serve_application_to_schema(self._deployments.values()).dict()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict) -> "Application":
|
||||
|
@ -1811,8 +1810,7 @@ class Application:
|
|||
Application: a new application object containing the deployments.
|
||||
"""
|
||||
|
||||
schema = ServeApplicationSchema.parse_obj(d)
|
||||
return cls([schema_to_deployment(s) for s in schema.deployments])
|
||||
return cls(schema_to_serve_application(ServeApplicationSchema.parse_obj(d)))
|
||||
|
||||
def to_yaml(self, f: Optional[TextIO] = None) -> Optional[str]:
|
||||
"""Returns this application's deployments as a YAML string.
|
||||
|
@ -1834,6 +1832,7 @@ class Application:
|
|||
Optional[String]: The deployments' YAML string. The output is from
|
||||
yaml.safe_dump(). Returned only if no file pointer is passed in.
|
||||
"""
|
||||
|
||||
return yaml.safe_dump(
|
||||
self.to_dict(), stream=f, default_flow_style=False, sort_keys=False
|
||||
)
|
||||
|
@ -1872,7 +1871,7 @@ def run(
|
|||
*,
|
||||
host: str = DEFAULT_HTTP_HOST,
|
||||
port: int = DEFAULT_HTTP_PORT,
|
||||
) -> RayServeHandle:
|
||||
) -> Optional[RayServeHandle]:
|
||||
"""Run a Serve application and return a ServeHandle to the ingress.
|
||||
|
||||
Either a DeploymentNode, DeploymentFunctionNode, or a pre-built application
|
||||
|
@ -1954,13 +1953,13 @@ def run(
|
|||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
def build(target: DeploymentNode) -> Application:
|
||||
def build(target: Union[DeploymentNode, DeploymentFunctionNode]) -> Application:
|
||||
"""Builds a Serve application into a static application.
|
||||
|
||||
Takes in a DeploymentNode and converts it to a Serve application
|
||||
consisting of one or more deployments. This is intended to be used for
|
||||
production scenarios and deployed via the Serve REST API or CLI, so there
|
||||
are some restrictions placed on the deployments:
|
||||
Takes in a DeploymentNode or DeploymentFunctionNode and converts it to a
|
||||
Serve application consisting of one or more deployments. This is intended
|
||||
to be used for production scenarios and deployed via the Serve REST API or
|
||||
CLI, so there are some restrictions placed on the deployments:
|
||||
1) All of the deployments must be importable. That is, they cannot be
|
||||
defined in __main__ or inline defined. The deployments will be
|
||||
imported in production using the same import path they were here.
|
||||
|
@ -1969,9 +1968,19 @@ def build(target: DeploymentNode) -> Application:
|
|||
The returned Application object can be exported to a dictionary or YAML
|
||||
config.
|
||||
"""
|
||||
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
|
||||
if in_interactive_shell():
|
||||
raise RuntimeError(
|
||||
"serve.build cannot be called from an interactive shell like "
|
||||
"IPython or Jupyter because it requires all deployments to be "
|
||||
"importable to run the app after building."
|
||||
)
|
||||
|
||||
# TODO(edoakes): this should accept host and port, but we don't
|
||||
# currently support them in the REST API.
|
||||
raise NotImplementedError()
|
||||
return Application(pipeline_build(target))
|
||||
|
||||
|
||||
def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
||||
|
@ -1983,6 +1992,7 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
|||
init_args and init_kwargs must also be JSON-serializable or this call will
|
||||
fail.
|
||||
"""
|
||||
from ray.serve.pipeline.json_serde import convert_to_json_safe_obj
|
||||
|
||||
if d.ray_actor_options is not None:
|
||||
ray_actor_options_schema = RayActorOptionsSchema.parse_obj(d.ray_actor_options)
|
||||
|
@ -1991,9 +2001,11 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
|||
|
||||
return DeploymentSchema(
|
||||
name=d.name,
|
||||
import_path=get_deployment_import_path(d),
|
||||
init_args=d.init_args,
|
||||
init_kwargs=d.init_kwargs,
|
||||
import_path=get_deployment_import_path(
|
||||
d, enforce_importable=True, replace_main=True
|
||||
),
|
||||
init_args=convert_to_json_safe_obj(d.init_args, err_key="init_args"),
|
||||
init_kwargs=convert_to_json_safe_obj(d.init_kwargs, err_key="init_kwargs"),
|
||||
num_replicas=d.num_replicas,
|
||||
route_prefix=d.route_prefix,
|
||||
max_concurrent_queries=d.max_concurrent_queries,
|
||||
|
@ -2017,12 +2029,12 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
|
|||
|
||||
return deployment(
|
||||
name=s.name,
|
||||
init_args=convert_from_json_safe_obj(s.init_args),
|
||||
init_kwargs=convert_from_json_safe_obj(s.init_kwargs),
|
||||
init_args=convert_from_json_safe_obj(s.init_args, err_key="init_args"),
|
||||
init_kwargs=convert_from_json_safe_obj(s.init_kwargs, err_key="init_kwargs"),
|
||||
num_replicas=s.num_replicas,
|
||||
route_prefix=s.route_prefix,
|
||||
max_concurrent_queries=s.max_concurrent_queries,
|
||||
user_config=convert_from_json_safe_obj(s.user_config),
|
||||
user_config=s.user_config,
|
||||
_autoscaling_config=s.autoscaling_config,
|
||||
_graceful_shutdown_wait_loop_s=s.graceful_shutdown_wait_loop_s,
|
||||
_graceful_shutdown_timeout_s=s.graceful_shutdown_timeout_s,
|
||||
|
@ -2035,8 +2047,9 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
|
|||
def serve_application_to_schema(
|
||||
deployments: List[Deployment],
|
||||
) -> ServeApplicationSchema:
|
||||
schemas = [deployment_to_schema(d) for d in deployments]
|
||||
return ServeApplicationSchema(deployments=schemas)
|
||||
return ServeApplicationSchema(
|
||||
deployments=[deployment_to_schema(d) for d in deployments]
|
||||
)
|
||||
|
||||
|
||||
def schema_to_serve_application(schema: ServeApplicationSchema) -> List[Deployment]:
|
||||
|
|
|
@ -31,13 +31,36 @@ from ray.serve.api import RayServeDAGHandle
|
|||
|
||||
|
||||
def convert_to_json_safe_obj(obj: Any, *, err_key: str) -> Any:
    """Converts the provided object into a JSON-safe version of it.

    The returned object can safely be `json.dumps`'d to a string.

    Uses the Ray Serve encoder to serialize special objects such as
    ServeHandles and DAGHandles.

    Args:
        obj: The object to convert.
        err_key: Label of the field being converted (for example
            "init_args"). It is only used to build the error message.

    Returns:
        The result of dumping ``obj`` with ``DAGNodeEncoder`` and loading the
        resulting string back with ``json.loads``.

    Raises:
        TypeError: if the object contains fields that cannot be
            JSON-serialized. The original exception is attached as the
            cause.
    """
    try:
        return json.loads(json.dumps(obj, cls=DAGNodeEncoder))
    except Exception as e:
        # Any failure during the round trip is reported as a TypeError that
        # names the offending field; chain it so the root cause stays visible.
        raise TypeError(
            "All provided fields must be JSON-serializable to build the "
            f"Serve app. Failed while serializing {err_key}:\n{e}"
        ) from e
|
||||
|
||||
|
||||
def convert_from_json_safe_obj(obj: Any, *, err_key: str) -> Any:
    """Converts a JSON-safe object to one that contains Serve special types.

    The provided object should have been serialized using
    convert_to_json_safe_obj. Any special-cased objects such as ServeHandles
    will be recovered on this pass.

    Args:
        obj: The JSON-safe object to convert.
        err_key: Label of the field being converted (for example
            "init_kwargs"). It is only used to build the error message.

    Returns:
        The result of dumping ``obj`` to a JSON string and loading it back
        with ``dagnode_from_json`` as the ``object_hook``.

    Raises:
        ValueError: if the round trip fails for any reason. The original
            exception is attached as the cause.
    """
    try:
        return json.loads(json.dumps(obj), object_hook=dagnode_from_json)
    except Exception as e:
        # Chain the original error so the traceback shows what actually failed.
        raise ValueError(f"Failed to convert {err_key} from JSON:\n{e}") from e
|
||||
|
||||
|
||||
class DAGNodeEncoder(json.JSONEncoder):
|
||||
|
@ -81,17 +104,7 @@ class DAGNodeEncoder(json.JSONEncoder):
|
|||
elif isinstance(obj, DAGNode):
|
||||
return obj.to_json(DAGNodeEncoder)
|
||||
else:
|
||||
# Let the base class default method raise the TypeError
|
||||
try:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
"All args and kwargs used in Ray DAG building for serve "
|
||||
"deployment need to be JSON serializable. Please JSON "
|
||||
"serialize your args to make your ray application "
|
||||
"deployment ready."
|
||||
f"\n Original exception message: {e}"
|
||||
)
|
||||
|
||||
|
||||
def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
|
||||
|
|
|
@ -97,24 +97,6 @@ def test_non_json_serializable_args():
|
|||
|
||||
ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2))
|
||||
# General context
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=(
|
||||
"All args and kwargs used in Ray DAG building for serve "
|
||||
"deployment need to be JSON serializable."
|
||||
),
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
# User actionable item
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=(
|
||||
"Please JSON serialize your args to make your ray application "
|
||||
"deployment ready"
|
||||
),
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
# Original error message
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=r"Object of type .* is not JSON serializable",
|
||||
|
|
|
@ -5,6 +5,7 @@ import pathlib
|
|||
import click
|
||||
import time
|
||||
import sys
|
||||
from typing import Optional, Union
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
|
@ -22,6 +23,8 @@ from ray.dashboard.modules.serve.sdk import ServeSubmissionClient
|
|||
from ray.autoscaler._private.cli_logger import cli_logger
|
||||
from ray.serve.api import (
|
||||
Application,
|
||||
DeploymentFunctionNode,
|
||||
DeploymentNode,
|
||||
get_deployment_statuses,
|
||||
serve_application_status_to_schema,
|
||||
)
|
||||
|
@ -378,3 +381,51 @@ def delete(address: str, yes: bool):
|
|||
cli_logger.newline()
|
||||
cli_logger.success("\nSent delete request successfully!\n")
|
||||
cli_logger.newline()
|
||||
|
||||
|
||||
@cli.command(
    short_help="Writes a Pipeline's config file.",
    help=(
        "Imports the DeploymentNode or DeploymentFunctionNode at IMPORT_PATH "
        "and generates a structured config for it that can be used by "
        "`serve deploy` or the REST API. "
    ),
)
@click.option(
    "--app-dir",
    "-d",
    default=".",
    type=str,
    help=APP_DIR_HELP_STR,
)
@click.option(
    "--output-path",
    "-o",
    default=None,
    type=str,
    help=(
        "Local path where the output config will be written in YAML format. "
        "If not provided, the config will be printed to STDOUT."
    ),
)
@click.argument("import_path")
def build(app_dir: str, output_path: Optional[str], import_path: str):
    # Imports the node at `import_path`, builds it with `serve.build`, and
    # emits the resulting application as YAML: to `output_path` when one is
    # given, otherwise to STDOUT.
    #
    # Raises:
    #   ValueError: if `output_path` is given and does not end with ".yaml".
    #   TypeError: if the imported attribute is neither a DeploymentNode nor
    #       a DeploymentFunctionNode.

    # Reject a bad output path up front, before the user's module is imported
    # and the app is built, so an obvious typo fails fast.
    if output_path is not None and not output_path.endswith(".yaml"):
        raise ValueError(
            f"--output-path must end with '.yaml', but got '{output_path}'."
        )

    # Put the app directory first on sys.path before resolving import_path.
    sys.path.insert(0, app_dir)

    node: Union[DeploymentNode, DeploymentFunctionNode] = import_attr(import_path)
    if not isinstance(node, (DeploymentNode, DeploymentFunctionNode)):
        raise TypeError(
            f"Expected '{import_path}' to be DeploymentNode or "
            f"DeploymentFunctionNode, but got {type(node)}."
        )

    app = serve.build(node)

    if output_path is not None:
        with open(output_path, "w") as f:
            app.to_yaml(f)
    else:
        # The YAML is printed with end="" so print adds no newline of its own.
        print(app.to_yaml(), end="")
|
||||
|
|
|
@ -5,6 +5,7 @@ import sys
|
|||
import os
|
||||
import yaml
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
|
@ -42,7 +43,7 @@ class TestApplicationConstruction:
|
|||
Application([self.f, 5, "hello"])
|
||||
|
||||
|
||||
class TestRun:
|
||||
class TestServeRun:
|
||||
@serve.deployment
|
||||
def f():
|
||||
return "f reached"
|
||||
|
@ -272,6 +273,35 @@ class DecoratedClass:
|
|||
return "got decorated class"
|
||||
|
||||
|
||||
class TestServeBuild:
    """Checks the errors raised when a built app is exported with to_dict()."""

    @serve.deployment
    class A:
        pass

    def test_build_non_json_serializable_args(self, serve_instance):
        # A NumPy array passed positionally must be reported under init_args.
        expected_msg = "must be JSON-serializable to build.*init_args"
        with pytest.raises(TypeError, match=expected_msg):
            node = self.A.bind(np.zeros(100))
            app = serve.build(node)
            app.to_dict()

    def test_build_non_json_serializable_kwargs(self, serve_instance):
        # A NumPy array passed by keyword must be reported under init_kwargs.
        expected_msg = "must be JSON-serializable to build.*init_kwargs"
        with pytest.raises(TypeError, match=expected_msg):
            node = self.A.bind(kwarg=np.zeros(100))
            app = serve.build(node)
            app.to_dict()

    def test_build_non_importable(self, serve_instance):
        # The deployment is created inside a function, i.e. it is defined
        # inline and returned from another function.
        def gen_deployment():
            @serve.deployment
            def f():
                pass

            return f

        with pytest.raises(RuntimeError, match="must be importable"):
            node = gen_deployment().bind()
            app = serve.build(node)
            app.to_dict()
|
||||
|
||||
|
||||
def compare_specified_options(deployments1: Dict, deployments2: Dict):
|
||||
"""
|
||||
Helper method that takes 2 deployment dictionaries in the REST API
|
||||
|
|
|
@ -6,12 +6,13 @@ import sys
|
|||
import signal
|
||||
import pytest
|
||||
import requests
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.tests.conftest import tmp_working_dir # noqa: F401, E501
|
||||
from ray._private.test_utils import wait_for_condition
|
||||
from ray.serve.api import Application
|
||||
from ray.serve.api import Application, RayServeDAGHandle
|
||||
|
||||
|
||||
def ping_endpoint(endpoint: str, params: str = ""):
|
||||
|
@ -372,6 +373,45 @@ def test_run_runtime_env(ray_start_stop):
|
|||
p.wait()
|
||||
|
||||
|
||||
@serve.deployment
def global_f(*args):
    # Module-level deployment used by the `serve build` CLI test; any
    # positional arguments are accepted and ignored.
    greeting = "wonderful world"
    return greeting
|
||||
|
||||
|
||||
@serve.deployment
class NoArgDriver:
    """Driver deployment that forwards an argument-less call to its DAG handle."""

    def __init__(self, dag: RayServeDAGHandle):
        # Keep the handle so __call__ can invoke it later.
        self.dag = dag

    async def __call__(self):
        # Call the wrapped DAG without arguments and return its awaited result.
        result = await self.dag.remote()
        return result
|
||||
|
||||
|
||||
TestBuildFNode = global_f.bind()
|
||||
TestBuildDagNode = NoArgDriver.bind(TestBuildFNode)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
@pytest.mark.parametrize("node", ["TestBuildFNode", "TestBuildDagNode"])
def test_build(ray_start_stop, node):
    """Runs `serve build`, `serve deploy` and `serve delete` end to end.

    The config produced by `serve build` is written to a temporary YAML file,
    deployed from that file, pinged, deleted, and pinged again.
    """
    import_path = f"ray.serve.tests.test_cli.{node}"

    with NamedTemporaryFile(mode="w+", suffix=".yaml") as tmp:
        build_cmd = ["serve", "build", import_path, "-o", tmp.name]
        deploy_cmd = ["serve", "deploy", tmp.name]
        delete_cmd = ["serve", "delete", "-y"]

        # Build an app
        subprocess.check_output(build_cmd)

        # Deploy the generated config and check the endpoint answers.
        subprocess.check_output(deploy_cmd)
        assert ping_endpoint("") == "wonderful world"

        # After deletion the endpoint must no longer be reachable.
        subprocess.check_output(delete_cmd)
        assert ping_endpoint("") == "connection error"
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
|
||||
@pytest.mark.parametrize("use_command", [True, False])
|
||||
def test_idempotence_after_controller_death(ray_start_stop, use_command: bool):
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.api import RayServeDAGHandle
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray.serve.api import Application, DeploymentNode, RayServeDAGHandle
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
from ray.serve.drivers import DAGDriver
|
||||
import starlette.requests
|
||||
|
@ -19,6 +19,15 @@ RayHandleLike = TypeVar("RayHandleLike")
|
|||
NESTED_HANDLE_KEY = "nested_handle"
|
||||
|
||||
|
||||
def maybe_build(
    node: DeploymentNode, use_build: bool
) -> Union[Application, DeploymentNode]:
    """Optionally round-trips ``node`` through ``serve.build``.

    Args:
        node: The deployment node under test.
        use_build: When true, the node is built, exported with ``to_dict()``
            and re-created with ``Application.from_dict``.

    Returns:
        ``node`` itself when ``use_build`` is false, otherwise the
        reconstructed Application.
    """
    if not use_build:
        return node

    exported = serve.build(node).to_dict()
    return Application.from_dict(exported)
|
||||
|
||||
|
||||
@serve.deployment
|
||||
class ClassHello:
|
||||
def __init__(self):
|
||||
|
@ -113,11 +122,12 @@ class NoargDriver:
|
|||
return await self.dag.remote()
|
||||
|
||||
|
||||
def test_single_func_no_input(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_single_func_no_input(serve_instance, use_build):
|
||||
dag = fn_hello.bind()
|
||||
serve_dag = NoargDriver.bind(dag)
|
||||
|
||||
handle = serve.run(serve_dag)
|
||||
handle = serve.run(maybe_build(serve_dag, use_build))
|
||||
assert ray.get(handle.remote()) == "hello"
|
||||
assert requests.get("http://127.0.0.1:8000/").text == "hello"
|
||||
|
||||
|
@ -126,7 +136,8 @@ async def json_resolver(request: starlette.requests.Request):
|
|||
return await request.json()
|
||||
|
||||
|
||||
def test_single_func_deployment_dag(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_single_func_deployment_dag(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||
|
@ -135,7 +146,8 @@ def test_single_func_deployment_dag(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
|
||||
|
||||
|
||||
def test_chained_function(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_chained_function(serve_instance, use_build):
|
||||
@serve.deployment
|
||||
def func_1(input):
|
||||
return input
|
||||
|
@ -156,7 +168,8 @@ def test_chained_function(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
|
||||
|
||||
|
||||
def test_simple_class_with_class_method(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_simple_class_with_class_method(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
model = Model.bind(2, ratio=0.3)
|
||||
dag = model.forward.bind(dag_input)
|
||||
|
@ -166,7 +179,8 @@ def test_simple_class_with_class_method(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
|
||||
|
||||
|
||||
def test_func_class_with_class_method(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_func_class_with_class_method(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
|
@ -180,7 +194,8 @@ def test_func_class_with_class_method(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
|
||||
|
||||
|
||||
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
|
@ -193,7 +208,8 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
|
||||
|
||||
|
||||
def test_shared_deployment_handle(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_shared_deployment_handle(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
m = Model.bind(2)
|
||||
combine = Combine.bind(m, m2=m)
|
||||
|
@ -205,7 +221,8 @@ def test_shared_deployment_handle(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
|
||||
|
||||
|
||||
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use_build):
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
|
@ -238,13 +255,15 @@ class Echo:
|
|||
return self._s
|
||||
|
||||
|
||||
def test_single_node_deploy_success(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_single_node_deploy_success(serve_instance, use_build):
|
||||
m1 = Adder.bind(1)
|
||||
handle = serve.run(m1)
|
||||
handle = serve.run(maybe_build(m1, use_build))
|
||||
assert ray.get(handle.remote(41)) == 42
|
||||
|
||||
|
||||
def test_single_node_driver_sucess(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_single_node_driver_sucess(serve_instance, use_build):
|
||||
m1 = Adder.bind(1)
|
||||
m2 = Adder.bind(2)
|
||||
with InputNode() as input_node:
|
||||
|
@ -280,7 +299,8 @@ class TakeHandle:
|
|||
return ray.get(self.handle.remote(inp))
|
||||
|
||||
|
||||
def test_passing_handle(serve_instance):
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_passing_handle(serve_instance, use_build):
|
||||
child = Adder.bind(1)
|
||||
parent = TakeHandle.bind(child)
|
||||
driver = DAGDriver.bind(parent, input_schema=json_resolver)
|
||||
|
@ -289,30 +309,35 @@ def test_passing_handle(serve_instance):
|
|||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
|
||||
|
||||
|
||||
def test_passing_handle_in_obj(serve_instance):
|
||||
@serve.deployment
|
||||
class Parent:
|
||||
class DictParent:
|
||||
def __init__(self, d):
|
||||
self._d = d
|
||||
|
||||
async def __call__(self, key):
|
||||
return await self._d[key].remote()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_passing_handle_in_obj(serve_instance, use_build):
|
||||
|
||||
child1 = Echo.bind("ed")
|
||||
child2 = Echo.bind("simon")
|
||||
parent = Parent.bind({"child1": child1, "child2": child2})
|
||||
parent = maybe_build(
|
||||
DictParent.bind({"child1": child1, "child2": child2}), use_build
|
||||
)
|
||||
|
||||
handle = serve.run(parent)
|
||||
assert ray.get(handle.remote("child1")) == "ed"
|
||||
assert ray.get(handle.remote("child2")) == "simon"
|
||||
|
||||
|
||||
def test_pass_handle_to_multiple(serve_instance):
|
||||
@serve.deployment
|
||||
class Child:
|
||||
def __call__(self, *args):
|
||||
return os.getpid()
|
||||
|
||||
|
||||
@serve.deployment
|
||||
class Parent:
|
||||
def __init__(self, child):
|
||||
|
@ -321,6 +346,7 @@ def test_pass_handle_to_multiple(serve_instance):
|
|||
def __call__(self, *args):
|
||||
return ray.get(self._child.remote())
|
||||
|
||||
|
||||
@serve.deployment
|
||||
class GrandParent:
|
||||
def __init__(self, child, parent):
|
||||
|
@ -332,31 +358,43 @@ def test_pass_handle_to_multiple(serve_instance):
|
|||
assert ray.get(self._child.remote()) == ray.get(self._parent.remote())
|
||||
return "ok"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_build", [False, True])
|
||||
def test_pass_handle_to_multiple(serve_instance, use_build):
|
||||
|
||||
child = Child.bind()
|
||||
parent = Parent.bind(child)
|
||||
grandparent = GrandParent.bind(child, parent)
|
||||
grandparent = maybe_build(GrandParent.bind(child, parent), use_build)
|
||||
|
||||
handle = serve.run(grandparent)
|
||||
assert ray.get(handle.remote()) == "ok"
|
||||
|
||||
|
||||
def test_non_json_serializable_args(serve_instance):
|
||||
def test_run_non_json_serializable_args(serve_instance):
|
||||
# Test that we can capture and bind non-json-serializable arguments.
|
||||
arr1 = np.zeros(100)
|
||||
arr2 = np.zeros(200)
|
||||
arr3 = np.zeros(300)
|
||||
|
||||
@serve.deployment
|
||||
class A:
|
||||
def __init__(self, arr1):
|
||||
def __init__(self, arr1, *, arr2):
|
||||
self.arr1 = arr1
|
||||
self.arr2 = arr2
|
||||
self.arr3 = arr3
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.arr1, self.arr2
|
||||
return self.arr1, self.arr2, self.arr3
|
||||
|
||||
handle = serve.run(A.bind(arr1))
|
||||
ret1, ret2 = ray.get(handle.remote())
|
||||
assert np.array_equal(ret1, arr1) and np.array_equal(ret2, arr2)
|
||||
handle = serve.run(A.bind(arr1, arr2=arr2))
|
||||
ret1, ret2, ret3 = ray.get(handle.remote())
|
||||
assert all(
|
||||
[
|
||||
np.array_equal(ret1, arr1),
|
||||
np.array_equal(ret2, arr2),
|
||||
np.array_equal(ret3, arr3),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@serve.deployment
|
||||
|
@ -406,8 +444,5 @@ def test_unsupported_remote():
|
|||
_ = func.bind().remote()
|
||||
|
||||
|
||||
# TODO: check that serve.build raises an exception.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
|
|
@ -39,7 +39,44 @@ class DecoratedActor:
|
|||
return "reached decorated_actor"
|
||||
|
||||
|
||||
def gen_func():
    """Returns a function deployment that is defined inside this function."""

    def f():
        pass

    # Applying the decorator by hand is the same as writing @serve.deployment
    # above the nested definition.
    return serve.deployment(f)
|
||||
|
||||
|
||||
def gen_class():
    """Returns a class deployment that is defined inside this function."""

    class A:
        pass

    # Applying the decorator by hand is the same as writing @serve.deployment
    # above the nested definition.
    return serve.deployment(A)
|
||||
|
||||
|
||||
class TestGetDeploymentImportPath:
|
||||
def test_invalid_inline_defined(self):
|
||||
@serve.deployment
|
||||
def inline_f():
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError, match="must be importable"):
|
||||
get_deployment_import_path(inline_f, enforce_importable=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="must be importable"):
|
||||
get_deployment_import_path(gen_func(), enforce_importable=True)
|
||||
|
||||
@serve.deployment
|
||||
class InlineCls:
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError, match="must be importable"):
|
||||
get_deployment_import_path(InlineCls, enforce_importable=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="must be importable"):
|
||||
get_deployment_import_path(gen_class(), enforce_importable=True)
|
||||
|
||||
def test_get_import_path_basic(self):
|
||||
d = decorated_f.options()
|
||||
|
||||
|
|
|
@ -255,7 +255,9 @@ def msgpack_serialize(obj):
|
|||
return serialized
|
||||
|
||||
|
||||
def get_deployment_import_path(deployment, replace_main=False):
|
||||
def get_deployment_import_path(
|
||||
deployment, replace_main=False, enforce_importable=False
|
||||
):
|
||||
"""
|
||||
Gets the import path for deployment's func_or_class.
|
||||
|
||||
|
@ -275,8 +277,15 @@ def get_deployment_import_path(deployment, replace_main=False):
|
|||
|
||||
import_path = f"{body.__module__}.{body.__qualname__}"
|
||||
|
||||
if replace_main:
|
||||
if enforce_importable and "<locals>" in body.__qualname__:
|
||||
raise RuntimeError(
|
||||
"Deployment definitions must be importable to build the Serve app, "
|
||||
f"but deployment '{deployment.name}' is inline defined or returned "
|
||||
"from another function. Please restructure your code so that "
|
||||
f"'{import_path}' can be imported (i.e., put it in a module)."
|
||||
)
|
||||
|
||||
if replace_main:
|
||||
# Replaces __main__ with its file name. E.g. suppose the import path
|
||||
# is __main__.classname and classname is defined in filename.py.
|
||||
# Its import path becomes filename.classname.
|
||||
|
@ -365,3 +374,11 @@ def require_packages(packages: List[str]):
|
|||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def in_interactive_shell():
    """Returns True when the ``__main__`` module has no ``__file__`` attribute.

    This check is used as the signal for an interactive shell.
    """
    # Taken from:
    # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
    import __main__ as entry_module

    has_script_file = hasattr(entry_module, "__file__")
    return not has_script_file
|
||||
|
|
Loading…
Add table
Reference in a new issue