Restore "[Serve] Deploy Serve deployment graphs via REST API" (#25073) (#25333)

This commit is contained in:
shrekris-anyscale 2022-06-02 11:06:53 -07:00 committed by GitHub
parent ab8785ca5c
commit 16bdfe6a39
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 549 additions and 91 deletions

View file

@ -56,8 +56,17 @@ class ServeHead(dashboard_utils.DashboardHeadModule):
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True) @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
async def put_all_deployments(self, req: Request) -> Response: async def put_all_deployments(self, req: Request) -> Response:
from ray import serve from ray import serve
from ray.serve.context import get_global_client
from ray.serve.schema import ServeApplicationSchema
from ray.serve.application import Application from ray.serve.application import Application
config = ServeApplicationSchema.parse_obj(await req.json())
if config.import_path is not None:
client = get_global_client(_override_controller_namespace="serve")
client.deploy_app(config)
else:
# TODO (shrekris-anyscale): Remove this conditional path
app = Application.from_dict(await req.json()) app = Application.from_dict(await req.json())
serve.run(app, _blocking=False) serve.run(app, _blocking=False)

View file

@ -153,6 +153,39 @@ def test_put_get_success(ray_start_stop):
) )
def test_put_new_rest_api(ray_start_stop):
    """PUT an import_path-based config and check the deployed graph's replies.

    The config overrides the ``user_config`` of the ``Multiplier`` and
    ``Adder`` deployments, then the HTTP endpoint is polled until each
    operation returns the expected order string.
    """
    deployment_overrides = [
        {
            "name": "Multiplier",
            "user_config": {
                "factor": 1,
            },
        },
        {
            "name": "Adder",
            "user_config": {
                "increment": 1,
            },
        },
    ]
    config = {
        "import_path": "ray.serve.tests.test_config_files.pizza.serve_dag",
        "deployments": deployment_overrides,
    }

    put_response = requests.put(GET_OR_PUT_URL, json=config, timeout=30)
    assert put_response.status_code == 200

    # (operation, expected reply) pairs, polled in the same order as before.
    expectations = (
        ("ADD", "3 pizzas please!"),
        ("MUL", "2 pizzas please!"),
    )
    for operation, expected_reply in expectations:
        # Bind the loop variables as defaults so the lambda does not depend
        # on late-binding closure semantics.
        wait_for_condition(
            lambda operation=operation, expected_reply=expected_reply: requests.post(
                "http://localhost:8000/", json=[operation, 2]
            ).json()
            == expected_reply,
            timeout=30,
        )
def test_delete_success(ray_start_stop): def test_delete_success(ray_start_stop):
ray_actor_options = { ray_actor_options = {
"runtime_env": { "runtime_env": {

View file

@ -399,7 +399,7 @@ def test_ordinal_encoder_no_encode_list():
"unique_values(B)": {"cold": 0, "hot": 1, "warm": 2}, "unique_values(B)": {"cold": 0, "hot": 1, "warm": 2},
"unique_values(C)": {1: 0, 5: 1, 10: 2}, "unique_values(C)": {1: 0, 5: 1, 10: 2},
"unique_values(D)": { "unique_values(D)": {
(): 0, tuple(): 0,
("cold", "cold"): 1, ("cold", "cold"): 1,
("hot", "warm", "cold"): 2, ("hot", "warm", "cold"): 2,
("warm",): 3, ("warm",): 3,
@ -475,7 +475,7 @@ def test_one_hot_encoder():
"unique_values(B)": {"cold": 0, "hot": 1, "warm": 2}, "unique_values(B)": {"cold": 0, "hot": 1, "warm": 2},
"unique_values(C)": {1: 0, 5: 1, 10: 2}, "unique_values(C)": {1: 0, 5: 1, 10: 2},
"unique_values(D)": { "unique_values(D)": {
(): 0, tuple(): 0,
("cold", "cold"): 1, ("cold", "cold"): 1,
("hot", "warm", "cold"): 2, ("hot", "warm", "cold"): 2,
("warm",): 3, ("warm",): 3,

View file

@ -28,6 +28,7 @@ from ray.serve.config import (
HTTPOptions, HTTPOptions,
ReplicaConfig, ReplicaConfig,
) )
from ray.serve.schema import ServeApplicationSchema
from ray.serve.constants import ( from ray.serve.constants import (
MAX_CACHED_HANDLES, MAX_CACHED_HANDLES,
CLIENT_POLLING_INTERVAL_S, CLIENT_POLLING_INTERVAL_S,
@ -325,6 +326,16 @@ class ServeControllerClient:
) )
self.delete_deployments(deployment_names_to_delete, blocking=_blocking) self.delete_deployments(deployment_names_to_delete, blocking=_blocking)
@_ensure_connected
def deploy_app(self, config: ServeApplicationSchema) -> None:
    """Ask the Serve controller to deploy the application described by config.

    Args:
        config: Parsed application schema. Its ``import_path`` and
            ``runtime_env`` are forwarded as-is; the ``deployments``
            entries are forwarded as alias-keyed dicts containing only the
            fields that were explicitly set (an empty list if none).

    Blocks until the controller's ``deploy_app`` call has returned.
    """
    exported_config = config.dict(by_alias=True, exclude_unset=True)
    deployment_overrides = exported_config.get("deployments", [])

    request_ref = self._controller.deploy_app.remote(
        config.import_path,
        config.runtime_env,
        deployment_overrides,
    )
    ray.get(request_ref)
@_ensure_connected @_ensure_connected
def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None: def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None:
ray.get(self._controller.delete_deployments.remote(names)) ray.get(self._controller.delete_deployments.remote(names))

View file

@ -136,7 +136,7 @@ class StatusOverview:
) )
@classmethod @classmethod
def from_proto(cls, proto: StatusOverviewProto): def from_proto(cls, proto: StatusOverviewProto) -> "StatusOverview":
# Recreate Serve Application info # Recreate Serve Application info
app_status = ApplicationStatusInfo.from_proto(proto.app_status) app_status = ApplicationStatusInfo.from_proto(proto.app_status)

View file

@ -3,12 +3,16 @@ from collections import defaultdict
from copy import copy from copy import copy
import json import json
import logging import logging
import traceback
import time import time
import os import os
from typing import Dict, Iterable, List, Optional, Tuple, Any from typing import Dict, Iterable, List, Optional, Tuple, Any
import ray import ray
from ray.types import ObjectRef
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray._private.utils import import_attr
from ray.exceptions import RayTaskError
from ray.serve.autoscaling_metrics import InMemoryMetricsStore from ray.serve.autoscaling_metrics import InMemoryMetricsStore
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
@ -119,6 +123,12 @@ class ServeController:
_override_controller_namespace=_override_controller_namespace, _override_controller_namespace=_override_controller_namespace,
) )
# Reference to Ray task executing most recent deployment request
self.config_deployment_request_ref: ObjectRef = None
# Unix timestamp of latest config deployment request. Defaults to 0.
self.deployment_timestamp = 0
# TODO(simon): move autoscaling related stuff into a manager. # TODO(simon): move autoscaling related stuff into a manager.
self.autoscaling_metrics_store = InMemoryMetricsStore() self.autoscaling_metrics_store = InMemoryMetricsStore()
self.handle_metrics_store = InMemoryMetricsStore() self.handle_metrics_store = InMemoryMetricsStore()
@ -408,6 +418,36 @@ class ServeController:
return [self.deploy(**args) for args in deployment_args_list] return [self.deploy(**args) for args in deployment_args_list]
def deploy_app(
    self,
    import_path: str,
    runtime_env: Dict,
    deployment_override_options: List[Dict],
) -> None:
    """Kicks off a task that deploys a Serve application.

    Cancels any previous in-progress task that is deploying a Serve
    application, then records the new task's reference and the time the
    request was made. This method does not wait for the deployment task
    to finish.

    Args:
        import_path: Serve deployment graph's import path.
        runtime_env: runtime_env dictionary to run the deployment graph
            in. It is passed straight to the task's ``options()`` call.
        deployment_override_options: All dictionaries should
            contain argument-value options that can be passed directly
            into a set_options() call. Overrides deployment options set
            in the graph itself.
    """
    if self.config_deployment_request_ref is not None:
        # Only the most recent config request should be allowed to run.
        ray.cancel(self.config_deployment_request_ref)
        logger.debug("Canceled existing config deployment request.")

    self.config_deployment_request_ref = run_graph.options(
        runtime_env=runtime_env
    ).remote(import_path, deployment_override_options)

    # Unix timestamp of this request; reported by get_serve_status().
    self.deployment_timestamp = time.time()
def delete_deployment(self, name: str): def delete_deployment(self, name: str):
self.endpoint_state.delete_endpoint(name) self.endpoint_state.delete_endpoint(name)
return self.deployment_state_manager.delete_deployment(name) return self.deployment_state_manager.delete_deployment(name)
@ -494,12 +534,25 @@ class ServeController:
) )
return deployment_route_list.SerializeToString() return deployment_route_list.SerializeToString()
def get_serve_status(self) -> bytes: async def get_serve_status(self) -> bytes:
# TODO (shrekris-anyscale): Replace defaults with actual REST API status
serve_app_status = ApplicationStatus.RUNNING serve_app_status = ApplicationStatus.RUNNING
serve_app_message = "" serve_app_message = ""
deployment_timestamp = time.time() deployment_timestamp = self.deployment_timestamp
if self.config_deployment_request_ref:
finished, pending = ray.wait(
[self.config_deployment_request_ref], timeout=0
)
if pending:
serve_app_status = ApplicationStatus.DEPLOYING
else:
try:
await finished[0]
except RayTaskError:
serve_app_status = ApplicationStatus.DEPLOY_FAILED
serve_app_message = f"Deployment failed:\n{traceback.format_exc()}"
app_status = ApplicationStatusInfo( app_status = ApplicationStatusInfo(
serve_app_status, serve_app_message, deployment_timestamp serve_app_status, serve_app_message, deployment_timestamp
@ -512,3 +565,23 @@ class ServeController:
) )
return status_info.to_proto().SerializeToString() return status_info.to_proto().SerializeToString()
@ray.remote(max_calls=1)
def run_graph(import_path: str, deployment_override_options: List[Dict]):
    """Deploys a Serve application to the controller's Ray cluster.

    Args:
        import_path: Import path of the deployment graph to build.
        deployment_override_options: One dict per deployment to override.
            Each dict must have a "name" key identifying the deployment and
            is passed in full to that deployment's set_options().
    """
    from ray import serve
    from ray.serve.api import build

    # Import the graph object and build it into an application.
    app = build(import_attr(import_path))

    # Apply the per-deployment overrides on top of the options in the code.
    for overrides in deployment_override_options:
        target = app.deployments[overrides["name"]]
        target.set_options(**overrides)

    # Run the graph locally on the cluster
    serve.start(_override_controller_namespace="serve")
    serve.run(app)

View file

@ -1,13 +1,17 @@
import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from ray import cloudpickle from ray import cloudpickle
from ray.serve.common import EndpointInfo, EndpointTag from ray.serve.common import EndpointInfo, EndpointTag
from ray.serve.constants import SERVE_LOGGER_NAME
from ray.serve.long_poll import LongPollNamespace from ray.serve.long_poll import LongPollNamespace
from ray.serve.storage.kv_store import KVStoreBase from ray.serve.storage.kv_store import KVStoreBase
from ray.serve.long_poll import LongPollHost from ray.serve.long_poll import LongPollHost
CHECKPOINT_KEY = "serve-endpoint-state-checkpoint" CHECKPOINT_KEY = "serve-endpoint-state-checkpoint"
logger = logging.getLogger(SERVE_LOGGER_NAME)
class EndpointState: class EndpointState:
"""Manages all state for endpoints in the system. """Manages all state for endpoints in the system.
@ -54,15 +58,19 @@ class EndpointState:
updated to match the given parameters. Calling this twice with the same updated to match the given parameters. Calling this twice with the same
arguments is a no-op. arguments is a no-op.
""" """
if self._endpoints.get(endpoint) == endpoint_info:
return
existing_route_endpoint = self._get_endpoint_for_route(endpoint_info.route) existing_route_endpoint = self._get_endpoint_for_route(endpoint_info.route)
if existing_route_endpoint is not None and existing_route_endpoint != endpoint: if existing_route_endpoint is not None and existing_route_endpoint != endpoint:
raise ValueError( logger.warn(
f"route_prefix '{endpoint_info.route}' is already registered." f'route_prefix "{endpoint_info.route}" is currently '
f'registered to deployment "{existing_route_endpoint}". '
f'Re-registering route_prefix "{endpoint_info.route}" to '
f'deployment "{endpoint}".'
) )
del self._endpoints[existing_route_endpoint]
if endpoint in self._endpoints:
if self._endpoints[endpoint] == endpoint_info:
return
self._endpoints[endpoint] = endpoint_info self._endpoints[endpoint] = endpoint_info

View file

@ -78,7 +78,9 @@ class RayActorOptionsSchema(BaseModel, extra=Extra.forbid):
return v return v
class DeploymentSchema(BaseModel, extra=Extra.forbid): class DeploymentSchema(
BaseModel, extra=Extra.forbid, allow_population_by_field_name=True
):
name: str = Field( name: str = Field(
..., description=("Globally-unique name identifying this deployment.") ..., description=("Globally-unique name identifying this deployment.")
) )
@ -153,6 +155,7 @@ class DeploymentSchema(BaseModel, extra=Extra.forbid):
"replicas; the number of replicas will be fixed at " "replicas; the number of replicas will be fixed at "
"num_replicas." "num_replicas."
), ),
alias="_autoscaling_config",
) )
graceful_shutdown_wait_loop_s: float = Field( graceful_shutdown_wait_loop_s: float = Field(
default=None, default=None,
@ -162,6 +165,7 @@ class DeploymentSchema(BaseModel, extra=Extra.forbid):
"default if null." "default if null."
), ),
ge=0, ge=0,
alias="_graceful_shutdown_wait_loop_s",
) )
graceful_shutdown_timeout_s: float = Field( graceful_shutdown_timeout_s: float = Field(
default=None, default=None,
@ -171,6 +175,7 @@ class DeploymentSchema(BaseModel, extra=Extra.forbid):
"default if null." "default if null."
), ),
ge=0, ge=0,
alias="_graceful_shutdown_timeout_s",
) )
health_check_period_s: float = Field( health_check_period_s: float = Field(
default=None, default=None,
@ -179,6 +184,7 @@ class DeploymentSchema(BaseModel, extra=Extra.forbid):
"replicas. Uses a default if null." "replicas. Uses a default if null."
), ),
gt=0, gt=0,
alias="_health_check_period_s",
) )
health_check_timeout_s: float = Field( health_check_timeout_s: float = Field(
default=None, default=None,
@ -188,66 +194,12 @@ class DeploymentSchema(BaseModel, extra=Extra.forbid):
"unhealthy. Uses a default if null." "unhealthy. Uses a default if null."
), ),
gt=0, gt=0,
alias="_health_check_timeout_s",
) )
ray_actor_options: RayActorOptionsSchema = Field( ray_actor_options: RayActorOptionsSchema = Field(
default=None, description="Options set for each replica actor." default=None, description="Options set for each replica actor."
) )
@root_validator
def application_sufficiently_specified(cls, values):
"""
Some application information, such as the path to the function or class
must be specified. Additionally, some attributes only work in specific
languages (e.g. init_args and init_kwargs make sense in Python but not
Java). Specifying attributes that belong to different languages is
invalid.
"""
# Ensure that an application path is set
application_paths = {"import_path"}
specified_path = None
for path in application_paths:
if path in values and values[path] is not None:
specified_path = path
if specified_path is None:
raise ValueError(
"A path to the application's class or function must be specified."
)
# Ensure that only attributes belonging to the application path's
# language are specified.
# language_attributes contains all attributes in this schema related to
# the application's language
language_attributes = {"import_path", "init_args", "init_kwargs"}
# corresponding_attributes maps application_path attributes to all the
# attributes that may be set in that path's language
corresponding_attributes = {
# Python
"import_path": {"import_path", "init_args", "init_kwargs"}
}
possible_attributes = corresponding_attributes[specified_path]
for attribute in values:
if (
attribute not in possible_attributes
and attribute in language_attributes
):
raise ValueError(
f'Got "{values[specified_path]}" for '
f"{specified_path} and {values[attribute]} "
f"for {attribute}. {specified_path} and "
f"{attribute} do not belong to the same "
f"language and cannot be specified at the "
f"same time. Expected one of these to be "
f"null."
)
return values
@root_validator @root_validator
def num_replicas_and_autoscaling_config_mutually_exclusive(cls, values): def num_replicas_and_autoscaling_config_mutually_exclusive(cls, values):
if ( if (
@ -345,7 +297,10 @@ class ServeApplicationSchema(BaseModel, extra=Extra.forbid):
"and py_modules may contain only remote URIs." "and py_modules may contain only remote URIs."
), ),
) )
deployments: List[DeploymentSchema] = Field(...) deployments: List[DeploymentSchema] = Field(
default=[],
description=("Deployment options that override options specified in the code."),
)
@validator("runtime_env") @validator("runtime_env")
def runtime_env_contains_remote_uris(cls, v): def runtime_env_contains_remote_uris(cls, v):
@ -396,6 +351,8 @@ class ServeApplicationSchema(BaseModel, extra=Extra.forbid):
"import path may not start or end with a dot." "import path may not start or end with a dot."
) )
return v
class ServeStatusSchema(BaseModel, extra=Extra.forbid): class ServeStatusSchema(BaseModel, extra=Extra.forbid):
app_status: ApplicationStatusInfo = Field( app_status: ApplicationStatusInfo = Field(

View file

@ -212,16 +212,6 @@ def test_user_config(serve_instance):
wait_for_condition(lambda: check("456", 3)) wait_for_condition(lambda: check("456", 3))
def test_reject_duplicate_route(serve_instance):
@serve.deployment(name="A", route_prefix="/api")
class A:
pass
A.deploy()
with pytest.raises(ValueError):
A.options(name="B").deploy()
def test_scaling_replicas(serve_instance): def test_scaling_replicas(serve_instance):
@serve.deployment(name="counter", num_replicas=2) @serve.deployment(name="counter", num_replicas=2)
class Counter: class Counter:

View file

@ -0,0 +1,87 @@
from enum import Enum
from typing import List, Dict, TypeVar
import ray
from ray import serve
import starlette.requests
from ray.serve.drivers import DAGDriver
from ray.serve.deployment_graph import InputNode
RayHandleLike = TypeVar("RayHandleLike")
class Operation(str, Enum):
    """Arithmetic operations understood by the router.

    Members are also ``str`` instances, so they compare equal to their
    plain string values ("ADD", "MUL").
    """

    ADDITION = "ADD"
    MULTIPLICATION = "MUL"
@serve.deployment(ray_actor_options={"num_cpus": 0.15})
class Router:
    """Forwards a request to the adder or the multiplier based on the op."""

    def __init__(self, multiplier: RayHandleLike, adder: RayHandleLike):
        self.adder = adder
        self.multiplier = multiplier

    def route(self, op: Operation, input: int) -> int:
        """Send ``input`` to the handle matching ``op`` and return its result.

        Falls through and returns None when ``op`` is neither ADDITION nor
        MULTIPLICATION.
        """
        if op == Operation.ADDITION:
            pending = self.adder.add.remote(input)
        elif op == Operation.MULTIPLICATION:
            pending = self.multiplier.multiply.remote(input)
        else:
            return None
        return ray.get(pending)
@serve.deployment(
    user_config={
        "factor": 3,
    },
    ray_actor_options={"num_cpus": 0.15},
)
class Multiplier:
    """Multiplies its input by a configurable factor."""

    def __init__(self, factor: int):
        self.factor = factor

    def reconfigure(self, config: Dict):
        """Replace the factor with config["factor"], or -1 if it is absent."""
        new_factor = config.get("factor", -1)
        self.factor = new_factor

    def multiply(self, input_factor: int) -> int:
        """Return ``input_factor`` times the current factor."""
        product = input_factor * self.factor
        return product
@serve.deployment(
    user_config={
        "increment": 2,
    },
    ray_actor_options={"num_cpus": 0.15},
)
class Adder:
    """Adds a configurable increment to its input."""

    def __init__(self, increment: int):
        self.increment = increment

    def reconfigure(self, config: Dict):
        """Replace the increment with config["increment"], or -1 if absent."""
        new_increment = config.get("increment", -1)
        self.increment = new_increment

    def add(self, input: int) -> int:
        """Return ``input`` plus the current increment."""
        total = input + self.increment
        return total
@serve.deployment(ray_actor_options={"num_cpus": 0.15})
def create_order(amount: int) -> str:
    """Format the computed amount as an order sentence."""
    order_text = f"{amount} pizzas please!"
    return order_text
async def json_resolver(request: starlette.requests.Request) -> List:
    """HTTP adapter: return the request body parsed as JSON."""
    payload = await request.json()
    return payload
# Constructor arguments bound into the graph below.
# Overwritten by user_config
ORIGINAL_INCREMENT = 1
ORIGINAL_FACTOR = 1

with InputNode() as inp:
    # The request payload is a two-element sequence: [operation, amount].
    operation, amount_input = inp[0], inp[1]

    # Leaf deployments, then the router that dispatches between them.
    multiplier = Multiplier.bind(ORIGINAL_FACTOR)
    adder = Adder.bind(ORIGINAL_INCREMENT)
    router = Router.bind(multiplier, adder)

    # Route the request, then turn the numeric result into an order string.
    amount = router.route.bind(operation, amount_input)
    order = create_order.bind(amount)

# Graph entry point referenced by the test configs' import_path.
serve_dag = DAGDriver.bind(order, http_adapter=json_resolver)

View file

@ -0,0 +1,20 @@
from ray import serve
from ray.serve.deployment_graph import RayServeDAGHandle
@serve.deployment(ray_actor_options={"num_cpus": 0.1})
def f(*args):
    """Ignore all arguments and return a fixed greeting."""
    greeting = "wonderful world"
    return greeting
@serve.deployment(ray_actor_options={"num_cpus": 0.1})
class BasicDriver:
    """Minimal driver that forwards every call to the wrapped DAG handle."""

    def __init__(self, dag: RayServeDAGHandle):
        self.dag = dag

    async def __call__(self):
        """Invoke the DAG with no arguments and return its awaited result."""
        result = await self.dag.remote()
        return result
# Leaf node of the graph: the bound ``f`` deployment.
FNode = f.bind()

# Graph entry point: a BasicDriver wrapping the leaf node.
DagNode = BasicDriver.bind(FNode)

View file

@ -37,8 +37,7 @@ def test_path_validation(serve_instance):
D4.deploy() D4.deploy()
# Reject duplicate route. # Allow duplicate route.
with pytest.raises(ValueError):
D4.options(name="test2").deploy() D4.options(name="test2").deploy()

View file

@ -272,12 +272,8 @@ class TestDeploymentSchema:
# Python requires an import path # Python requires an import path
deployment_schema = self.get_minimal_deployment_schema() deployment_schema = self.get_minimal_deployment_schema()
del deployment_schema["import_path"]
with pytest.raises(ValueError, match="must be specified"): # DeploymentSchema should be generated with valid import_paths
DeploymentSchema.parse_obj(deployment_schema)
# DeploymentSchema should be generated once import_path is set
for path in get_valid_import_paths(): for path in get_valid_import_paths():
deployment_schema["import_path"] = path deployment_schema["import_path"] = path
DeploymentSchema.parse_obj(deployment_schema) DeploymentSchema.parse_obj(deployment_schema)
@ -504,6 +500,67 @@ class TestServeApplicationSchema:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
ServeApplicationSchema.parse_obj(serve_application_schema) ServeApplicationSchema.parse_obj(serve_application_schema)
def test_serve_application_aliasing(self):
    """Check aliasing behavior for schemas."""

    # Private options as spelled with their leading-underscore aliases.
    private_options = {
        "_autoscaling_config",
        "_graceful_shutdown_wait_loop_s",
        "_graceful_shutdown_timeout_s",
        "_health_check_period_s",
        "_health_check_timeout_s",
    }
    # The same options spelled with their plain field names.
    public_options = {option[1:] for option in private_options}

    # Check that private options can optionally include underscore:
    # each deployment mixes aliased and un-aliased spellings.
    first_deployment = {
        "name": "d1",
        "max_concurrent_queries": 3,
        "autoscaling_config": {},
        "_graceful_shutdown_wait_loop_s": 30,
        "graceful_shutdown_timeout_s": 10,
        "_health_check_period_s": 5,
        "health_check_timeout_s": 7,
    }
    second_deployment = {
        "name": "d2",
        "max_concurrent_queries": 6,
        "_autoscaling_config": {},
        "graceful_shutdown_wait_loop_s": 50,
        "_graceful_shutdown_timeout_s": 15,
        "health_check_period_s": 53,
        "_health_check_timeout_s": 73,
    }
    app_dict = {
        "import_path": "module.graph",
        "runtime_env": {},
        "deployments": [first_deployment, second_deployment],
    }

    schema = ServeApplicationSchema.parse_obj(app_dict)

    # Exporting by alias must yield only the underscore-prefixed keys.
    for deployment in schema.dict(by_alias=True)["deployments"]:
        assert private_options <= deployment.keys()
        assert public_options.isdisjoint(deployment)

    # Exporting by field name must yield only the un-prefixed keys.
    for deployment in schema.dict()["deployments"]:
        assert public_options <= deployment.keys()
        assert private_options.isdisjoint(deployment)
class TestServeStatusSchema: class TestServeStatusSchema:
def get_valid_serve_status_schema(self): def get_valid_serve_status_schema(self):

View file

@ -1,8 +1,10 @@
from contextlib import contextmanager
import sys import sys
import os import os
import subprocess import subprocess
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import requests import requests
from typing import Dict
import pytest import pytest
from ray.cluster_utils import AutoscalingCluster from ray.cluster_utils import AutoscalingCluster
@ -12,6 +14,9 @@ import ray
import ray.state import ray.state
from ray import serve from ray import serve
from ray.serve.context import get_global_client from ray.serve.context import get_global_client
from ray.serve.schema import ServeApplicationSchema
from ray.serve.client import ServeControllerClient
from ray.serve.common import ApplicationStatus
from ray._private.test_utils import wait_for_condition from ray._private.test_utils import wait_for_condition
from ray.tests.conftest import call_ray_stop_only # noqa: F401 from ray.tests.conftest import call_ray_stop_only # noqa: F401
@ -25,13 +30,25 @@ def shutdown_ray():
ray.shutdown() ray.shutdown()
@pytest.fixture @contextmanager
def start_and_shutdown_ray_cli(): def start_and_shutdown_ray_cli():
subprocess.check_output(["ray", "start", "--head"]) subprocess.check_output(["ray", "start", "--head"])
yield yield
subprocess.check_output(["ray", "stop", "--force"]) subprocess.check_output(["ray", "stop", "--force"])
@pytest.fixture(scope="function")
def start_and_shutdown_ray_cli_function():
    """Function-scoped fixture wrapping start_and_shutdown_ray_cli().

    The context manager is entered before the test body runs and exited
    once the fixture is torn down.
    """
    cluster_lifetime = start_and_shutdown_ray_cli()
    with cluster_lifetime:
        yield
@pytest.fixture(scope="class")
def start_and_shutdown_ray_cli_class():
    """Class-scoped fixture wrapping start_and_shutdown_ray_cli().

    The context manager is entered once for the requesting test class and
    exited when the class-scoped fixture is torn down.
    """
    cluster_lifetime = start_and_shutdown_ray_cli()
    with cluster_lifetime:
        yield
def test_standalone_actor_outside_serve(): def test_standalone_actor_outside_serve():
# https://github.com/ray-project/ray/issues/20066 # https://github.com/ray-project/ray/issues/20066
@ -215,7 +232,204 @@ def test_get_serve_status(shutdown_ray):
ray.shutdown() ray.shutdown()
def test_shutdown_remote(start_and_shutdown_ray_cli): @pytest.mark.usefixtures("start_and_shutdown_ray_cli_class")
class TestDeployApp:
    """End-to-end checks for ServeControllerClient.deploy_app().

    Each test deploys a config-described graph through the client and polls
    the HTTP endpoint on localhost:8000 until the expected reply appears.
    """

    @pytest.fixture()
    def client(self):
        """Connect to the running cluster and start a detached Serve instance.

        Serve and the Ray connection are shut down after each test.
        """
        ray.init(address="auto", namespace="serve")
        serve_client = serve.start(detached=True)
        yield serve_client
        serve.shutdown()
        ray.shutdown()

    def get_test_config(self) -> Dict:
        """Return a fresh minimal config dict that points at the pizza graph."""
        return {"import_path": "ray.serve.tests.test_config_files.pizza.serve_dag"}

    @staticmethod
    def _pizza_reply_is(payload, expected_reply) -> bool:
        """POST payload as JSON to the root route and compare the JSON reply."""
        reply = requests.post("http://localhost:8000/", json=payload).json()
        return reply == expected_reply

    def test_deploy_app_basic(self, client: ServeControllerClient):
        """Deploy the unmodified pizza graph and check both operations."""
        base_config = ServeApplicationSchema.parse_obj(self.get_test_config())
        client.deploy_app(base_config)

        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "4 pizzas please!")
        )
        wait_for_condition(
            lambda: self._pizza_reply_is(["MUL", 3], "9 pizzas please!")
        )

    def test_deploy_app_with_overriden_config(self, client: ServeControllerClient):
        """Deploy with user_config overrides for both arithmetic deployments."""
        overridden = self.get_test_config()
        overridden["deployments"] = [
            {
                "name": "Multiplier",
                "user_config": {
                    "factor": 4,
                },
            },
            {
                "name": "Adder",
                "user_config": {
                    "increment": 5,
                },
            },
        ]
        client.deploy_app(ServeApplicationSchema.parse_obj(overridden))

        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 0], "5 pizzas please!")
        )
        wait_for_condition(
            lambda: self._pizza_reply_is(["MUL", 2], "8 pizzas please!")
        )

    def test_deploy_app_update_config(self, client: ServeControllerClient):
        """Redeploy a live app with a changed Adder user_config."""
        base_config = ServeApplicationSchema.parse_obj(self.get_test_config())
        client.deploy_app(base_config)
        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "4 pizzas please!")
        )

        updated = self.get_test_config()
        updated["deployments"] = [
            {
                "name": "Adder",
                "user_config": {
                    "increment": -1,
                },
            },
        ]
        client.deploy_app(ServeApplicationSchema.parse_obj(updated))
        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "1 pizzas please!")
        )

    def test_deploy_app_update_num_replicas(self, client: ServeControllerClient):
        """Redeploy with more replicas and check the named-actor count grows."""
        base_config = ServeApplicationSchema.parse_obj(self.get_test_config())
        client.deploy_app(base_config)

        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "4 pizzas please!")
        )
        wait_for_condition(
            lambda: self._pizza_reply_is(["MUL", 3], "9 pizzas please!")
        )

        # Snapshot of named actors before scaling up.
        actors = ray.util.list_named_actors(all_namespaces=True)

        scaled = self.get_test_config()
        scaled["deployments"] = [
            {
                "name": "Adder",
                "num_replicas": 2,
                "user_config": {
                    "increment": 0,
                },
                "ray_actor_options": {"num_cpus": 0.1},
            },
            {
                "name": "Multiplier",
                "num_replicas": 3,
                "user_config": {
                    "factor": 0,
                },
                "ray_actor_options": {"num_cpus": 0.1},
            },
        ]
        client.deploy_app(ServeApplicationSchema.parse_obj(scaled))

        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "2 pizzas please!")
        )
        wait_for_condition(
            lambda: self._pizza_reply_is(["MUL", 3], "0 pizzas please!")
        )

        wait_for_condition(
            lambda: client.get_serve_status().app_status.status
            == ApplicationStatus.RUNNING,
            timeout=15,
        )

        # Adder 1 -> 2 and Multiplier 1 -> 3 would account for the three new
        # actors (assumes both started with a single replica).
        updated_actors = ray.util.list_named_actors(all_namespaces=True)
        assert len(updated_actors) == len(actors) + 3

    def test_deploy_app_update_timestamp(self, client: ServeControllerClient):
        """Check that each deploy_app call advances the reported timestamp."""
        assert client.get_serve_status().app_status.deployment_timestamp == 0

        base_config = ServeApplicationSchema.parse_obj(self.get_test_config())
        client.deploy_app(base_config)

        wait_for_condition(
            lambda: client.get_serve_status().app_status.deployment_timestamp > 0
        )

        first_deploy_time = client.get_serve_status().app_status.deployment_timestamp

        scaled = self.get_test_config()
        scaled["deployments"] = [
            {
                "name": "Adder",
                "num_replicas": 2,
            },
        ]
        client.deploy_app(ServeApplicationSchema.parse_obj(scaled))

        wait_for_condition(
            lambda: client.get_serve_status().app_status.deployment_timestamp
            > first_deploy_time
        )

        acceptable_statuses = {
            ApplicationStatus.DEPLOYING,
            ApplicationStatus.RUNNING,
        }
        assert client.get_serve_status().app_status.status in acceptable_statuses

    def test_deploy_app_overwrite_apps(self, client: ServeControllerClient):
        """Check that overwriting a live app with a new one works."""

        # Launch first graph. Its driver's route_prefix should be "/".
        world_config = ServeApplicationSchema.parse_obj(
            {
                "import_path": "ray.serve.tests.test_config_files.world.DagNode",
            }
        )
        client.deploy_app(world_config)

        wait_for_condition(
            lambda: requests.get("http://localhost:8000/").text == "wonderful world"
        )

        # Launch second graph. Its driver's route_prefix should also be "/".
        # "/" should lead to the new driver.
        pizza_config = ServeApplicationSchema.parse_obj(
            {
                "import_path": "ray.serve.tests.test_config_files.pizza.serve_dag",
            }
        )
        client.deploy_app(pizza_config)

        wait_for_condition(
            lambda: self._pizza_reply_is(["ADD", 2], "4 pizzas please!")
        )
def test_shutdown_remote(start_and_shutdown_ray_cli_function):
"""Check that serve.shutdown() works on a remote Ray cluster.""" """Check that serve.shutdown() works on a remote Ray cluster."""
deploy_serve_script = ( deploy_serve_script = (