mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[serve] Implement serve.build
(#23232)
The Serve REST API relies on YAML config files to specify and deploy deployments. This change introduces `serve.build()` and `serve build`, which translate Pipelines to YAML files. Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
This commit is contained in:
parent
be216a0e8c
commit
cf7b4e65c2
10 changed files with 343 additions and 114 deletions
|
@ -1061,6 +1061,7 @@ def import_attr(full_path: str):
|
||||||
"""Given a full import path to a module attr, return the imported attr.
|
"""Given a full import path to a module attr, return the imported attr.
|
||||||
|
|
||||||
For example, the following are equivalent:
|
For example, the following are equivalent:
|
||||||
|
MyClass = import_attr("module.submodule:MyClass")
|
||||||
MyClass = import_attr("module.submodule.MyClass")
|
MyClass = import_attr("module.submodule.MyClass")
|
||||||
from module.submodule import MyClass
|
from module.submodule import MyClass
|
||||||
|
|
||||||
|
@ -1069,9 +1070,19 @@ def import_attr(full_path: str):
|
||||||
"""
|
"""
|
||||||
if full_path is None:
|
if full_path is None:
|
||||||
raise TypeError("import path cannot be None")
|
raise TypeError("import path cannot be None")
|
||||||
last_period_idx = full_path.rfind(".")
|
|
||||||
attr_name = full_path[last_period_idx + 1 :]
|
if ":" in full_path:
|
||||||
module_name = full_path[:last_period_idx]
|
if full_path.count(":") > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f'Got invalid import path "{full_path}". An '
|
||||||
|
"import path may have at most one colon."
|
||||||
|
)
|
||||||
|
module_name, attr_name = full_path.split(":")
|
||||||
|
else:
|
||||||
|
last_period_idx = full_path.rfind(".")
|
||||||
|
module_name = full_path[:last_period_idx]
|
||||||
|
attr_name = full_path[last_period_idx + 1 :]
|
||||||
|
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
return getattr(module, attr_name)
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,7 @@ from ray.serve.utils import (
|
||||||
get_current_node_resource_key,
|
get_current_node_resource_key,
|
||||||
get_random_letters,
|
get_random_letters,
|
||||||
get_deployment_import_path,
|
get_deployment_import_path,
|
||||||
|
in_interactive_shell,
|
||||||
logger,
|
logger,
|
||||||
DEFAULT,
|
DEFAULT,
|
||||||
)
|
)
|
||||||
|
@ -1792,9 +1793,7 @@ class Application:
|
||||||
Returns:
|
Returns:
|
||||||
Dict: The Application's deployments formatted in a dictionary.
|
Dict: The Application's deployments formatted in a dictionary.
|
||||||
"""
|
"""
|
||||||
return ServeApplicationSchema(
|
return serve_application_to_schema(self._deployments.values()).dict()
|
||||||
deployments=[deployment_to_schema(d) for d in self._deployments.values()]
|
|
||||||
).dict()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: Dict) -> "Application":
|
def from_dict(cls, d: Dict) -> "Application":
|
||||||
|
@ -1811,8 +1810,7 @@ class Application:
|
||||||
Application: a new application object containing the deployments.
|
Application: a new application object containing the deployments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
schema = ServeApplicationSchema.parse_obj(d)
|
return cls(schema_to_serve_application(ServeApplicationSchema.parse_obj(d)))
|
||||||
return cls([schema_to_deployment(s) for s in schema.deployments])
|
|
||||||
|
|
||||||
def to_yaml(self, f: Optional[TextIO] = None) -> Optional[str]:
|
def to_yaml(self, f: Optional[TextIO] = None) -> Optional[str]:
|
||||||
"""Returns this application's deployments as a YAML string.
|
"""Returns this application's deployments as a YAML string.
|
||||||
|
@ -1834,6 +1832,7 @@ class Application:
|
||||||
Optional[String]: The deployments' YAML string. The output is from
|
Optional[String]: The deployments' YAML string. The output is from
|
||||||
yaml.safe_dump(). Returned only if no file pointer is passed in.
|
yaml.safe_dump(). Returned only if no file pointer is passed in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return yaml.safe_dump(
|
return yaml.safe_dump(
|
||||||
self.to_dict(), stream=f, default_flow_style=False, sort_keys=False
|
self.to_dict(), stream=f, default_flow_style=False, sort_keys=False
|
||||||
)
|
)
|
||||||
|
@ -1872,7 +1871,7 @@ def run(
|
||||||
*,
|
*,
|
||||||
host: str = DEFAULT_HTTP_HOST,
|
host: str = DEFAULT_HTTP_HOST,
|
||||||
port: int = DEFAULT_HTTP_PORT,
|
port: int = DEFAULT_HTTP_PORT,
|
||||||
) -> RayServeHandle:
|
) -> Optional[RayServeHandle]:
|
||||||
"""Run a Serve application and return a ServeHandle to the ingress.
|
"""Run a Serve application and return a ServeHandle to the ingress.
|
||||||
|
|
||||||
Either a DeploymentNode, DeploymentFunctionNode, or a pre-built application
|
Either a DeploymentNode, DeploymentFunctionNode, or a pre-built application
|
||||||
|
@ -1954,13 +1953,13 @@ def run(
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI(stability="alpha")
|
@PublicAPI(stability="alpha")
|
||||||
def build(target: DeploymentNode) -> Application:
|
def build(target: Union[DeploymentNode, DeploymentFunctionNode]) -> Application:
|
||||||
"""Builds a Serve application into a static application.
|
"""Builds a Serve application into a static application.
|
||||||
|
|
||||||
Takes in a DeploymentNode and converts it to a Serve application
|
Takes in a DeploymentNode or DeploymentFunctionNode and converts it to a
|
||||||
consisting of one or more deployments. This is intended to be used for
|
Serve application consisting of one or more deployments. This is intended
|
||||||
production scenarios and deployed via the Serve REST API or CLI, so there
|
to be used for production scenarios and deployed via the Serve REST API or
|
||||||
are some restrictions placed on the deployments:
|
CLI, so there are some restrictions placed on the deployments:
|
||||||
1) All of the deployments must be importable. That is, they cannot be
|
1) All of the deployments must be importable. That is, they cannot be
|
||||||
defined in __main__ or inline defined. The deployments will be
|
defined in __main__ or inline defined. The deployments will be
|
||||||
imported in production using the same import path they were here.
|
imported in production using the same import path they were here.
|
||||||
|
@ -1969,9 +1968,19 @@ def build(target: DeploymentNode) -> Application:
|
||||||
The returned Application object can be exported to a dictionary or YAML
|
The returned Application object can be exported to a dictionary or YAML
|
||||||
config.
|
config.
|
||||||
"""
|
"""
|
||||||
|
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
|
||||||
|
from ray.serve.pipeline.api import build as pipeline_build
|
||||||
|
|
||||||
|
if in_interactive_shell():
|
||||||
|
raise RuntimeError(
|
||||||
|
"serve.build cannot be called from an interactive shell like "
|
||||||
|
"IPython or Jupyter because it requires all deployments to be "
|
||||||
|
"importable to run the app after building."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(edoakes): this should accept host and port, but we don't
|
# TODO(edoakes): this should accept host and port, but we don't
|
||||||
# currently support them in the REST API.
|
# currently support them in the REST API.
|
||||||
raise NotImplementedError()
|
return Application(pipeline_build(target))
|
||||||
|
|
||||||
|
|
||||||
def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
||||||
|
@ -1983,6 +1992,7 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
||||||
init_args and init_kwargs must also be JSON-serializable or this call will
|
init_args and init_kwargs must also be JSON-serializable or this call will
|
||||||
fail.
|
fail.
|
||||||
"""
|
"""
|
||||||
|
from ray.serve.pipeline.json_serde import convert_to_json_safe_obj
|
||||||
|
|
||||||
if d.ray_actor_options is not None:
|
if d.ray_actor_options is not None:
|
||||||
ray_actor_options_schema = RayActorOptionsSchema.parse_obj(d.ray_actor_options)
|
ray_actor_options_schema = RayActorOptionsSchema.parse_obj(d.ray_actor_options)
|
||||||
|
@ -1991,9 +2001,11 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
|
||||||
|
|
||||||
return DeploymentSchema(
|
return DeploymentSchema(
|
||||||
name=d.name,
|
name=d.name,
|
||||||
import_path=get_deployment_import_path(d),
|
import_path=get_deployment_import_path(
|
||||||
init_args=d.init_args,
|
d, enforce_importable=True, replace_main=True
|
||||||
init_kwargs=d.init_kwargs,
|
),
|
||||||
|
init_args=convert_to_json_safe_obj(d.init_args, err_key="init_args"),
|
||||||
|
init_kwargs=convert_to_json_safe_obj(d.init_kwargs, err_key="init_kwargs"),
|
||||||
num_replicas=d.num_replicas,
|
num_replicas=d.num_replicas,
|
||||||
route_prefix=d.route_prefix,
|
route_prefix=d.route_prefix,
|
||||||
max_concurrent_queries=d.max_concurrent_queries,
|
max_concurrent_queries=d.max_concurrent_queries,
|
||||||
|
@ -2017,12 +2029,12 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
|
||||||
|
|
||||||
return deployment(
|
return deployment(
|
||||||
name=s.name,
|
name=s.name,
|
||||||
init_args=convert_from_json_safe_obj(s.init_args),
|
init_args=convert_from_json_safe_obj(s.init_args, err_key="init_args"),
|
||||||
init_kwargs=convert_from_json_safe_obj(s.init_kwargs),
|
init_kwargs=convert_from_json_safe_obj(s.init_kwargs, err_key="init_kwargs"),
|
||||||
num_replicas=s.num_replicas,
|
num_replicas=s.num_replicas,
|
||||||
route_prefix=s.route_prefix,
|
route_prefix=s.route_prefix,
|
||||||
max_concurrent_queries=s.max_concurrent_queries,
|
max_concurrent_queries=s.max_concurrent_queries,
|
||||||
user_config=convert_from_json_safe_obj(s.user_config),
|
user_config=s.user_config,
|
||||||
_autoscaling_config=s.autoscaling_config,
|
_autoscaling_config=s.autoscaling_config,
|
||||||
_graceful_shutdown_wait_loop_s=s.graceful_shutdown_wait_loop_s,
|
_graceful_shutdown_wait_loop_s=s.graceful_shutdown_wait_loop_s,
|
||||||
_graceful_shutdown_timeout_s=s.graceful_shutdown_timeout_s,
|
_graceful_shutdown_timeout_s=s.graceful_shutdown_timeout_s,
|
||||||
|
@ -2035,8 +2047,9 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
|
||||||
def serve_application_to_schema(
    deployments: List[Deployment],
) -> ServeApplicationSchema:
    """Wrap the given deployments in a ServeApplicationSchema.

    Each deployment is converted with deployment_to_schema, in iteration
    order, and the resulting schemas become the application's
    ``deployments`` field.
    """
    deployment_schemas = []
    for current in deployments:
        deployment_schemas.append(deployment_to_schema(current))
    return ServeApplicationSchema(deployments=deployment_schemas)
|
||||||
|
|
||||||
|
|
||||||
def schema_to_serve_application(schema: ServeApplicationSchema) -> List[Deployment]:
|
def schema_to_serve_application(schema: ServeApplicationSchema) -> List[Deployment]:
|
||||||
|
|
|
@ -31,13 +31,36 @@ from ray.serve.api import RayServeDAGHandle
|
||||||
|
|
||||||
|
|
||||||
def convert_to_json_safe_obj(obj: Any, *, err_key: str) -> Any:
    """Converts the provided object into a JSON-safe version of it.

    The returned object can safely be `json.dumps`'d to a string.

    Uses the Ray Serve encoder to serialize special objects such as
    ServeHandles and DAGHandles.

    Args:
        obj: The object to convert.
        err_key: Name of the field being converted. It is only used to
            make the error message point at the offending field.

    Returns:
        The result of dumping ``obj`` with DAGNodeEncoder and loading the
        resulting string back with ``json.loads``.

    Raises:
        TypeError: if the object contains fields that cannot be
            JSON-serialized. The underlying exception is attached as
            ``__cause__``.
    """
    try:
        return json.loads(json.dumps(obj, cls=DAGNodeEncoder))
    except Exception as e:
        # Chain explicitly so the traceback presents the encoder failure as
        # the direct cause instead of "another exception occurred".
        raise TypeError(
            "All provided fields must be JSON-serializable to build the "
            f"Serve app. Failed while serializing {err_key}:\n{e}"
        ) from e
|
||||||
|
|
||||||
|
|
||||||
def convert_from_json_safe_obj(obj: Any, *, err_key: str) -> Any:
    """Converts a JSON-safe object to one that contains Serve special types.

    The provided object should have been serialized using
    convert_to_json_safe_obj. Any special-cased objects such as ServeHandles
    will be recovered on this pass.

    Args:
        obj: The JSON-safe object to convert.
        err_key: Name of the field being converted. It is only used to
            make the error message point at the offending field.

    Returns:
        The result of dumping ``obj`` to a JSON string and loading it back
        with ``dagnode_from_json`` as the ``object_hook``.

    Raises:
        ValueError: if the round trip fails. The underlying exception is
            attached as ``__cause__``.
    """
    try:
        return json.loads(json.dumps(obj), object_hook=dagnode_from_json)
    except Exception as e:
        # Chain explicitly so the original decoding failure stays visible
        # as the direct cause in the traceback.
        raise ValueError(f"Failed to convert {err_key} from JSON:\n{e}") from e
|
||||||
|
|
||||||
|
|
||||||
class DAGNodeEncoder(json.JSONEncoder):
|
class DAGNodeEncoder(json.JSONEncoder):
|
||||||
|
@ -81,17 +104,7 @@ class DAGNodeEncoder(json.JSONEncoder):
|
||||||
elif isinstance(obj, DAGNode):
|
elif isinstance(obj, DAGNode):
|
||||||
return obj.to_json(DAGNodeEncoder)
|
return obj.to_json(DAGNodeEncoder)
|
||||||
else:
|
else:
|
||||||
# Let the base class default method raise the TypeError
|
return json.JSONEncoder.default(self, obj)
|
||||||
try:
|
|
||||||
return json.JSONEncoder.default(self, obj)
|
|
||||||
except Exception as e:
|
|
||||||
raise TypeError(
|
|
||||||
"All args and kwargs used in Ray DAG building for serve "
|
|
||||||
"deployment need to be JSON serializable. Please JSON "
|
|
||||||
"serialize your args to make your ray application "
|
|
||||||
"deployment ready."
|
|
||||||
f"\n Original exception message: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
|
def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
|
||||||
|
|
|
@ -97,24 +97,6 @@ def test_non_json_serializable_args():
|
||||||
|
|
||||||
ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2))
|
ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2))
|
||||||
# General context
|
# General context
|
||||||
with pytest.raises(
|
|
||||||
TypeError,
|
|
||||||
match=(
|
|
||||||
"All args and kwargs used in Ray DAG building for serve "
|
|
||||||
"deployment need to be JSON serializable."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
|
||||||
# User actionable item
|
|
||||||
with pytest.raises(
|
|
||||||
TypeError,
|
|
||||||
match=(
|
|
||||||
"Please JSON serialize your args to make your ray application "
|
|
||||||
"deployment ready"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
|
||||||
# Original error message
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match=r"Object of type .* is not JSON serializable",
|
match=r"Object of type .* is not JSON serializable",
|
||||||
|
|
|
@ -5,6 +5,7 @@ import pathlib
|
||||||
import click
|
import click
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Optional, Union
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
@ -22,6 +23,8 @@ from ray.dashboard.modules.serve.sdk import ServeSubmissionClient
|
||||||
from ray.autoscaler._private.cli_logger import cli_logger
|
from ray.autoscaler._private.cli_logger import cli_logger
|
||||||
from ray.serve.api import (
|
from ray.serve.api import (
|
||||||
Application,
|
Application,
|
||||||
|
DeploymentFunctionNode,
|
||||||
|
DeploymentNode,
|
||||||
get_deployment_statuses,
|
get_deployment_statuses,
|
||||||
serve_application_status_to_schema,
|
serve_application_status_to_schema,
|
||||||
)
|
)
|
||||||
|
@ -378,3 +381,51 @@ def delete(address: str, yes: bool):
|
||||||
cli_logger.newline()
|
cli_logger.newline()
|
||||||
cli_logger.success("\nSent delete request successfully!\n")
|
cli_logger.success("\nSent delete request successfully!\n")
|
||||||
cli_logger.newline()
|
cli_logger.newline()
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(
    short_help="Writes a Pipeline's config file.",
    help=(
        "Imports the DeploymentNode or DeploymentFunctionNode at IMPORT_PATH "
        "and generates a structured config for it that can be used by "
        "`serve deploy` or the REST API. "
    ),
)
@click.option(
    "--app-dir",
    "-d",
    default=".",
    type=str,
    help=APP_DIR_HELP_STR,
)
@click.option(
    "--output-path",
    "-o",
    default=None,
    type=str,
    help=(
        "Local path where the output config will be written in YAML format. "
        "If not provided, the config will be printed to STDOUT."
    ),
)
@click.argument("import_path")
def build(app_dir: str, output_path: Optional[str], import_path: str):
    # CLI entry point for `serve build`.
    #
    # app_dir:     directory prepended to sys.path before importing.
    # output_path: destination YAML file, or None to print to STDOUT.
    # import_path: import path of the node to build.
    #
    # Raises ValueError if output_path does not end with ".yaml", and
    # TypeError if the imported attribute is not a DeploymentNode or
    # DeploymentFunctionNode.

    # Validate the destination first so a bad path fails fast, before any
    # user code is imported or the application is built.
    if output_path is not None and not output_path.endswith(".yaml"):
        # Name the real option; there is no FILE_PATH argument on this command.
        raise ValueError("--output-path must end with '.yaml'.")

    # Make modules under app_dir importable for the import below.
    sys.path.insert(0, app_dir)

    node: Union[DeploymentNode, DeploymentFunctionNode] = import_attr(import_path)
    if not isinstance(node, (DeploymentNode, DeploymentFunctionNode)):
        raise TypeError(
            f"Expected '{import_path}' to be DeploymentNode or "
            f"DeploymentFunctionNode, but got {type(node)}."
        )

    app = serve.build(node)

    if output_path is not None:
        with open(output_path, "w") as f:
            app.to_yaml(f)
    else:
        # to_yaml() returns the YAML string when no stream is given.
        print(app.to_yaml(), end="")
|
||||||
|
|
|
@ -5,6 +5,7 @@ import sys
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
import requests
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import serve
|
from ray import serve
|
||||||
|
@ -42,7 +43,7 @@ class TestApplicationConstruction:
|
||||||
Application([self.f, 5, "hello"])
|
Application([self.f, 5, "hello"])
|
||||||
|
|
||||||
|
|
||||||
class TestRun:
|
class TestServeRun:
|
||||||
@serve.deployment
|
@serve.deployment
|
||||||
def f():
|
def f():
|
||||||
return "f reached"
|
return "f reached"
|
||||||
|
@ -272,6 +273,35 @@ class DecoratedClass:
|
||||||
return "got decorated class"
|
return "got decorated class"
|
||||||
|
|
||||||
|
|
||||||
|
class TestServeBuild:
    """Error-path checks for serve.build() followed by Application.to_dict()."""

    @serve.deployment
    class A:
        pass

    def test_build_non_json_serializable_args(self, serve_instance):
        """A numpy array as a positional init arg is rejected with TypeError."""
        expected_msg = "must be JSON-serializable to build.*init_args"
        with pytest.raises(TypeError, match=expected_msg):
            node = self.A.bind(np.zeros(100))
            serve.build(node).to_dict()

    def test_build_non_json_serializable_kwargs(self, serve_instance):
        """A numpy array as a keyword init arg is rejected with TypeError."""
        expected_msg = "must be JSON-serializable to build.*init_kwargs"
        with pytest.raises(TypeError, match=expected_msg):
            node = self.A.bind(kwarg=np.zeros(100))
            serve.build(node).to_dict()

    def test_build_non_importable(self, serve_instance):
        """A deployment defined inside a function is rejected with RuntimeError."""

        def gen_deployment():
            @serve.deployment
            def f():
                pass

            return f

        with pytest.raises(RuntimeError, match="must be importable"):
            inline_node = gen_deployment().bind()
            serve.build(inline_node).to_dict()
|
||||||
|
|
||||||
|
|
||||||
def compare_specified_options(deployments1: Dict, deployments2: Dict):
|
def compare_specified_options(deployments1: Dict, deployments2: Dict):
|
||||||
"""
|
"""
|
||||||
Helper method that takes 2 deployment dictionaries in the REST API
|
Helper method that takes 2 deployment dictionaries in the REST API
|
||||||
|
|
|
@ -6,12 +6,13 @@ import sys
|
||||||
import signal
|
import signal
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import serve
|
from ray import serve
|
||||||
from ray.tests.conftest import tmp_working_dir # noqa: F401, E501
|
from ray.tests.conftest import tmp_working_dir # noqa: F401, E501
|
||||||
from ray._private.test_utils import wait_for_condition
|
from ray._private.test_utils import wait_for_condition
|
||||||
from ray.serve.api import Application
|
from ray.serve.api import Application, RayServeDAGHandle
|
||||||
|
|
||||||
|
|
||||||
def ping_endpoint(endpoint: str, params: str = ""):
|
def ping_endpoint(endpoint: str, params: str = ""):
|
||||||
|
@ -372,6 +373,45 @@ def test_run_runtime_env(ray_start_stop):
|
||||||
p.wait()
|
p.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@serve.deployment
def global_f(*args):
    """Module-level function deployment that always returns a fixed greeting."""
    greeting = "wonderful world"
    return greeting
|
||||||
|
|
||||||
|
|
||||||
|
@serve.deployment
class NoArgDriver:
    """Driver deployment that forwards an argument-less call to a DAG handle."""

    def __init__(self, dag: RayServeDAGHandle):
        # Handle that __call__ awaits on each request.
        self.dag = dag

    async def __call__(self):
        result = await self.dag.remote()
        return result
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level nodes so that test_build can reference them by import path
# (ray.serve.tests.test_cli.<name>) when invoking `serve build`.
TestBuildFNode = global_f.bind()
TestBuildDagNode = NoArgDriver.bind(TestBuildFNode)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
@pytest.mark.parametrize("node", ["TestBuildFNode", "TestBuildDagNode"])
def test_build(ray_start_stop, node):
    """Build a config with `serve build`, deploy it, ping it, then delete it."""
    with NamedTemporaryFile(mode="w+", suffix=".yaml") as config_file:

        # Build an app
        build_cmd = [
            "serve",
            "build",
            f"ray.serve.tests.test_cli.{node}",
            "-o",
            config_file.name,
        ]
        subprocess.check_output(build_cmd)

        # Deploy the generated config and check the endpoint responds.
        deploy_cmd = ["serve", "deploy", config_file.name]
        subprocess.check_output(deploy_cmd)
        assert ping_endpoint("") == "wonderful world"

        # Tear the app down and check the endpoint is gone.
        delete_cmd = ["serve", "delete", "-y"]
        subprocess.check_output(delete_cmd)
        assert ping_endpoint("") == "connection error"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
|
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
|
||||||
@pytest.mark.parametrize("use_command", [True, False])
|
@pytest.mark.parametrize("use_command", [True, False])
|
||||||
def test_idempotence_after_controller_death(ray_start_stop, use_command: bool):
|
def test_idempotence_after_controller_death(ray_start_stop, use_command: bool):
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import TypeVar
|
from typing import TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import serve
|
from ray import serve
|
||||||
from ray.serve.api import RayServeDAGHandle
|
|
||||||
from ray.experimental.dag.input_node import InputNode
|
from ray.experimental.dag.input_node import InputNode
|
||||||
|
from ray.serve.api import Application, DeploymentNode, RayServeDAGHandle
|
||||||
from ray.serve.pipeline.api import build as pipeline_build
|
from ray.serve.pipeline.api import build as pipeline_build
|
||||||
from ray.serve.drivers import DAGDriver
|
from ray.serve.drivers import DAGDriver
|
||||||
import starlette.requests
|
import starlette.requests
|
||||||
|
@ -19,6 +19,15 @@ RayHandleLike = TypeVar("RayHandleLike")
|
||||||
NESTED_HANDLE_KEY = "nested_handle"
|
NESTED_HANDLE_KEY = "nested_handle"
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_build(
    node: DeploymentNode, use_build: bool
) -> Union[Application, DeploymentNode]:
    """Return ``node`` as-is, or a rebuilt Application when ``use_build`` is set.

    With ``use_build`` true, the node is passed through serve.build(), dumped
    to a dict, and reloaded with Application.from_dict().
    """
    if not use_build:
        return node

    app_config = serve.build(node).to_dict()
    return Application.from_dict(app_config)
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment
|
@serve.deployment
|
||||||
class ClassHello:
|
class ClassHello:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -113,11 +122,12 @@ class NoargDriver:
|
||||||
return await self.dag.remote()
|
return await self.dag.remote()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_build", [False, True])
def test_single_func_no_input(serve_instance, use_build):
    """A single function behind NoargDriver answers via handle and HTTP."""
    dag = fn_hello.bind()
    serve_dag = NoargDriver.bind(dag)

    target = maybe_build(serve_dag, use_build)
    handle = serve.run(target)

    handle_result = ray.get(handle.remote())
    assert handle_result == "hello"

    http_response = requests.get("http://127.0.0.1:8000/")
    assert http_response.text == "hello"
|
||||||
|
|
||||||
|
@ -126,7 +136,8 @@ async def json_resolver(request: starlette.requests.Request):
|
||||||
return await request.json()
|
return await request.json()
|
||||||
|
|
||||||
|
|
||||||
def test_single_func_deployment_dag(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_single_func_deployment_dag(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
||||||
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
|
||||||
|
@ -135,7 +146,8 @@ def test_single_func_deployment_dag(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
|
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
|
||||||
|
|
||||||
|
|
||||||
def test_chained_function(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_chained_function(serve_instance, use_build):
|
||||||
@serve.deployment
|
@serve.deployment
|
||||||
def func_1(input):
|
def func_1(input):
|
||||||
return input
|
return input
|
||||||
|
@ -156,7 +168,8 @@ def test_chained_function(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
|
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
|
||||||
|
|
||||||
|
|
||||||
def test_simple_class_with_class_method(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_simple_class_with_class_method(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
model = Model.bind(2, ratio=0.3)
|
model = Model.bind(2, ratio=0.3)
|
||||||
dag = model.forward.bind(dag_input)
|
dag = model.forward.bind(dag_input)
|
||||||
|
@ -166,7 +179,8 @@ def test_simple_class_with_class_method(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
|
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
|
||||||
|
|
||||||
|
|
||||||
def test_func_class_with_class_method(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_func_class_with_class_method(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
m1 = Model.bind(1)
|
m1 = Model.bind(1)
|
||||||
m2 = Model.bind(2)
|
m2 = Model.bind(2)
|
||||||
|
@ -180,7 +194,8 @@ def test_func_class_with_class_method(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
|
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
|
||||||
|
|
||||||
|
|
||||||
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
m1 = Model.bind(2)
|
m1 = Model.bind(2)
|
||||||
m2 = Model.bind(3)
|
m2 = Model.bind(3)
|
||||||
|
@ -193,7 +208,8 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
|
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
|
||||||
|
|
||||||
|
|
||||||
def test_shared_deployment_handle(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_shared_deployment_handle(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
m = Model.bind(2)
|
m = Model.bind(2)
|
||||||
combine = Combine.bind(m, m2=m)
|
combine = Combine.bind(m, m2=m)
|
||||||
|
@ -205,7 +221,8 @@ def test_shared_deployment_handle(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
|
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
|
||||||
|
|
||||||
|
|
||||||
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use_build):
|
||||||
with InputNode() as dag_input:
|
with InputNode() as dag_input:
|
||||||
m1 = Model.bind(2)
|
m1 = Model.bind(2)
|
||||||
m2 = Model.bind(3)
|
m2 = Model.bind(3)
|
||||||
|
@ -238,13 +255,15 @@ class Echo:
|
||||||
return self._s
|
return self._s
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_build", [False, True])
def test_single_node_deploy_success(serve_instance, use_build):
    """A lone Adder node can be run directly and queried through its handle."""
    adder_node = Adder.bind(1)
    target = maybe_build(adder_node, use_build)

    handle = serve.run(target)
    result = ray.get(handle.remote(41))
    assert result == 42
|
||||||
|
|
||||||
|
|
||||||
def test_single_node_driver_sucess(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_single_node_driver_sucess(serve_instance, use_build):
|
||||||
m1 = Adder.bind(1)
|
m1 = Adder.bind(1)
|
||||||
m2 = Adder.bind(2)
|
m2 = Adder.bind(2)
|
||||||
with InputNode() as input_node:
|
with InputNode() as input_node:
|
||||||
|
@ -280,7 +299,8 @@ class TakeHandle:
|
||||||
return ray.get(self.handle.remote(inp))
|
return ray.get(self.handle.remote(inp))
|
||||||
|
|
||||||
|
|
||||||
def test_passing_handle(serve_instance):
|
@pytest.mark.parametrize("use_build", [False, True])
|
||||||
|
def test_passing_handle(serve_instance, use_build):
|
||||||
child = Adder.bind(1)
|
child = Adder.bind(1)
|
||||||
parent = TakeHandle.bind(child)
|
parent = TakeHandle.bind(child)
|
||||||
driver = DAGDriver.bind(parent, input_schema=json_resolver)
|
driver = DAGDriver.bind(parent, input_schema=json_resolver)
|
||||||
|
@ -289,74 +309,92 @@ def test_passing_handle(serve_instance):
|
||||||
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
|
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment
class DictParent:
    """Deployment that holds a dict and dispatches each call by key."""

    def __init__(self, d):
        self._d = d

    async def __call__(self, key):
        selected = self._d[key]
        return await selected.remote()


@pytest.mark.parametrize("use_build", [False, True])
def test_passing_handle_in_obj(serve_instance, use_build):
    """Child nodes nested inside a dict init arg are reachable by key."""
    children = {
        "child1": Echo.bind("ed"),
        "child2": Echo.bind("simon"),
    }
    parent = maybe_build(DictParent.bind(children), use_build)

    handle = serve.run(parent)
    assert ray.get(handle.remote("child1")) == "ed"
    assert ray.get(handle.remote("child2")) == "simon"
|
||||||
|
|
||||||
|
|
||||||
def test_pass_handle_to_multiple(serve_instance):
|
@serve.deployment
class Child:
    """Deployment whose call returns the PID of the process serving it."""

    def __call__(self, *args):
        # The PID lets callers tell whether two requests hit the same process.
        return os.getpid()
|
||||||
return os.getpid()
|
|
||||||
|
|
||||||
@serve.deployment
|
|
||||||
class Parent:
|
|
||||||
def __init__(self, child):
|
|
||||||
self._child = child
|
|
||||||
|
|
||||||
def __call__(self, *args):
|
@serve.deployment
class Parent:
    """Deployment that relays every call to the child it was given."""

    def __init__(self, child):
        self._child = child

    def __call__(self, *args):
        # Block on the child's result and return it unchanged.
        return ray.get(self._child.remote())
|
||||||
def __init__(self, child, parent):
|
|
||||||
self._child = child
|
|
||||||
self._parent = parent
|
|
||||||
|
|
||||||
def __call__(self, *args):
|
|
||||||
# Check that the grandparent and parent are talking to the same child.
|
@serve.deployment
class GrandParent:
    """Deployment that checks its child and its parent reach the same child."""

    def __init__(self, child, parent):
        self._child = child
        self._parent = parent

    def __call__(self, *args):
        # Check that the grandparent and parent are talking to the same child.
        assert ray.get(self._child.remote()) == ray.get(self._parent.remote())
        return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_build", [False, True])
def test_pass_handle_to_multiple(serve_instance, use_build):
    """A single bound child can be passed to several other deployments.

    The same Child node is given to Parent and GrandParent; GrandParent
    compares the result it gets directly from the child with the one relayed
    by Parent and answers "ok" only if they match.
    """
    child = Child.bind()
    parent = Parent.bind(child)
    # maybe_build is defined elsewhere in this file; presumably it routes the
    # node through serve.build when use_build is true (verify).
    grandparent = maybe_build(GrandParent.bind(child, parent), use_build)

    handle = serve.run(grandparent)
    assert ray.get(handle.remote()) == "ok"
|
||||||
|
|
||||||
|
|
||||||
def test_non_json_serializable_args(serve_instance):
|
def test_run_non_json_serializable_args(serve_instance):
    """serve.run accepts deployments bound to non-JSON-serializable values.

    NumPy arrays reach the deployment three ways: as a positional init
    argument, as a keyword-only init argument, and as a variable captured
    from the enclosing scope.
    """
    # Test that we can capture and bind non-json-serializable arguments.
    arr1 = np.zeros(100)
    arr2 = np.zeros(200)
    arr3 = np.zeros(300)

    @serve.deployment
    class A:
        def __init__(self, arr1, *, arr2):
            self.arr1 = arr1
            self.arr2 = arr2
            # arr3 is not a parameter: it is captured from the test's scope.
            self.arr3 = arr3

        def __call__(self, *args):
            return self.arr1, self.arr2, self.arr3

    handle = serve.run(A.bind(arr1, arr2=arr2))
    ret1, ret2, ret3 = ray.get(handle.remote())
    # One assert per array so a failure identifies which one did not match.
    assert np.array_equal(ret1, arr1)
    assert np.array_equal(ret2, arr2)
    assert np.array_equal(ret3, arr3)
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment
|
@serve.deployment
|
||||||
|
@ -406,8 +444,5 @@ def test_unsupported_remote():
|
||||||
_ = func.bind().remote()
|
_ = func.bind().remote()
|
||||||
|
|
||||||
|
|
||||||
# TODO: check that serve.build raises an exception.
|
|
||||||
|
|
||||||
|
|
||||||
# Running this file directly hands it to pytest (verbose, output not
# captured) and exits with pytest's return code.
if __name__ == "__main__":
    sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||||
|
|
|
@ -39,7 +39,44 @@ class DecoratedActor:
|
||||||
return "reached decorated_actor"
|
return "reached decorated_actor"
|
||||||
|
|
||||||
|
|
||||||
|
def gen_func():
    """Return a function deployment that is defined inside this function.

    Because `f` is a local definition, its ``__qualname__`` contains
    ``<locals>``, so it cannot be reached through a plain import path.
    """

    @serve.deployment
    def f():
        pass

    return f
|
||||||
|
|
||||||
|
|
||||||
|
def gen_class():
    """Return a class deployment that is defined inside this function.

    Because `A` is a local definition, its ``__qualname__`` contains
    ``<locals>``, so it cannot be reached through a plain import path.
    """

    @serve.deployment
    class A:
        pass

    return A
|
||||||
|
|
||||||
|
|
||||||
class TestGetDeploymentImportPath:
|
class TestGetDeploymentImportPath:
|
||||||
|
def test_invalid_inline_defined(self):
    """Locally defined deployments are rejected when importability is enforced.

    Covers a function and a class, each defined inline in the test and
    returned from a helper function; all four must raise RuntimeError.
    """

    @serve.deployment
    def inline_f():
        pass

    with pytest.raises(RuntimeError, match="must be importable"):
        get_deployment_import_path(inline_f, enforce_importable=True)

    with pytest.raises(RuntimeError, match="must be importable"):
        get_deployment_import_path(gen_func(), enforce_importable=True)

    @serve.deployment
    class InlineCls:
        pass

    with pytest.raises(RuntimeError, match="must be importable"):
        get_deployment_import_path(InlineCls, enforce_importable=True)

    with pytest.raises(RuntimeError, match="must be importable"):
        get_deployment_import_path(gen_class(), enforce_importable=True)
|
||||||
|
|
||||||
def test_get_import_path_basic(self):
|
def test_get_import_path_basic(self):
|
||||||
d = decorated_f.options()
|
d = decorated_f.options()
|
||||||
|
|
||||||
|
|
|
@ -255,7 +255,9 @@ def msgpack_serialize(obj):
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
|
|
||||||
def get_deployment_import_path(deployment, replace_main=False):
|
def get_deployment_import_path(
|
||||||
|
deployment, replace_main=False, enforce_importable=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Gets the import path for deployment's func_or_class.
|
Gets the import path for deployment's func_or_class.
|
||||||
|
|
||||||
|
@ -275,8 +277,15 @@ def get_deployment_import_path(deployment, replace_main=False):
|
||||||
|
|
||||||
import_path = f"{body.__module__}.{body.__qualname__}"
|
import_path = f"{body.__module__}.{body.__qualname__}"
|
||||||
|
|
||||||
if replace_main:
|
if enforce_importable and "<locals>" in body.__qualname__:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Deployment definitions must be importable to build the Serve app, "
|
||||||
|
f"but deployment '{deployment.name}' is inline defined or returned "
|
||||||
|
"from another function. Please restructure your code so that "
|
||||||
|
f"'{import_path}' can be imported (i.e., put it in a module)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if replace_main:
|
||||||
# Replaces __main__ with its file name. E.g. suppose the import path
|
# Replaces __main__ with its file name. E.g. suppose the import path
|
||||||
# is __main__.classname and classname is defined in filename.py.
|
# is __main__.classname and classname is defined in filename.py.
|
||||||
# Its import path becomes filename.classname.
|
# Its import path becomes filename.classname.
|
||||||
|
@ -365,3 +374,11 @@ def require_packages(packages: List[str]):
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def in_interactive_shell():
    """Return True when the ``__main__`` module has no ``__file__`` attribute.

    The absence of ``__main__.__file__`` is used as the signal that the code
    is running in an interactive session rather than from a script file.

    Taken from:
    https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
    """
    # Imported here so the entry-point module is looked up at call time.
    import __main__ as main

    return not hasattr(main, "__file__")
|
||||||
|
|
Loading…
Add table
Reference in a new issue