[serve] Introduce context.py and client.py (#24067)

Serve stores context state, including `_INTERNAL_REPLICA_CONTEXT` and `_global_client`, in `api.py`. However, these data structures are referenced throughout the codebase, which causes circular dependencies. This change introduces two new files (a brief usage sketch follows the list):

* `context.py`
    * Intended to expose process-wide state to internal Serve code as well as `api.py`
    * Stores the `_INTERNAL_REPLICA_CONTEXT` and the `_global_client` global variables
* `client.py`
    * Stores the definition of the Serve `Client` object, now renamed to `ServeControllerClient`
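
A minimal usage sketch (not part of the diff) of how internal Serve code can consume the new modules after this change; the status call is illustrative and assumes a Serve controller is already running:

```python
# Sketch only: process-wide state is now imported from ray.serve.context
# instead of ray.serve.api, which removes the circular dependency on api.py.
from ray.serve.context import get_global_client, get_internal_replica_context

# The global client is a ServeControllerClient, defined in client.py.
client = get_global_client()
print(client.get_deployment_statuses())

# Inside a replica, the context exposes deployment metadata;
# outside a replica it is None.
ctx = get_internal_replica_context()
if ctx is not None:
    print(ctx.deployment, ctx.replica_tag)
```
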
Authored by shrekris-anyscale on 2022-04-21 16:35:09 -07:00; committed by GitHub
parent 8c5fe44542
commit b51d0aa8b1
23 changed files with 812 additions and 767 deletions


@ -56,7 +56,7 @@ class ServeHead(dashboard_utils.DashboardHeadModule):
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
async def put_all_deployments(self, req: Request) -> Response:
from ray import serve
from ray.serve.api import internal_get_global_client
from ray.serve.context import get_global_client
from ray.serve.application import Application
app = Application.from_dict(await req.json())
@ -69,7 +69,7 @@ class ServeHead(dashboard_utils.DashboardHeadModule):
all_deployments = serve.list_deployments()
all_names = set(all_deployments.keys())
names_to_delete = all_names.difference(new_names)
internal_get_global_client().delete_deployments(names_to_delete)
get_global_client().delete_deployments(names_to_delete)
return Response()


@ -1,50 +1,31 @@
import asyncio
import atexit
import collections
import inspect
import logging
import random
import re
import time
from dataclasses import dataclass
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
List,
Iterable,
overload,
)
from fastapi import APIRouter, FastAPI
from ray.exceptions import RayActorError
from starlette.requests import Request
from uvicorn.config import Config
from uvicorn.lifespan.on import LifespanOn
from ray.actor import ActorHandle
from ray.serve.common import (
DeploymentInfo,
DeploymentStatus,
DeploymentStatusInfo,
ReplicaTag,
)
from ray.serve.common import DeploymentStatusInfo
from ray.serve.config import (
AutoscalingConfig,
DeploymentConfig,
HTTPOptions,
ReplicaConfig,
)
from ray.serve.constants import (
DEFAULT_CHECKPOINT_PATH,
HTTP_PROXY_TIMEOUT,
SERVE_CONTROLLER_NAME,
MAX_CACHED_HANDLES,
CONTROLLER_MAX_CONCURRENCY,
DEFAULT_HTTP_HOST,
DEFAULT_HTTP_PORT,
@ -52,13 +33,8 @@ from ray.serve.constants import (
from ray.serve.controller import ServeController
from ray.serve.deployment import Deployment
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import (
DeploymentRoute,
DeploymentRouteList,
DeploymentStatusInfoList,
)
from ray.experimental.dag import DAGNode
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.handle import RayServeHandle
from ray.serve.http_util import ASGIHTTPSender, make_fastapi_class_based_view
from ray.serve.logging_utils import LoggingContext
from ray.serve.utils import (
@ -74,620 +50,15 @@ import ray
from ray import cloudpickle
from ray.serve.deployment_graph import ClassNode, FunctionNode
from ray.serve.application import Application
logger = logging.getLogger(__file__)
_INTERNAL_REPLICA_CONTEXT = None
_global_client: "Client" = None
_UUID_RE = re.compile(
"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12}"
from ray.serve.client import ServeControllerClient, get_controller_namespace
from ray.serve.context import (
set_global_client,
get_global_client,
get_internal_replica_context,
ReplicaContext,
)
# The polling interval for serve client to wait to deployment state
_CLIENT_POLLING_INTERVAL_S: float = 1
def _get_controller_namespace(
detached: bool, _override_controller_namespace: Optional[str] = None
):
"""Gets the controller's namespace.
Args:
detached (bool): Whether serve.start() was called with detached=True
_override_controller_namespace (Optional[str]): When set, this is the
controller's namespace
"""
if _override_controller_namespace is not None:
return _override_controller_namespace
controller_namespace = ray.get_runtime_context().namespace
if not detached:
return controller_namespace
# Start controller in "serve" namespace if detached and currently
# in anonymous namespace.
if _UUID_RE.fullmatch(controller_namespace) is not None:
controller_namespace = "serve"
return controller_namespace
def internal_get_global_client(
_override_controller_namespace: Optional[str] = None,
_health_check_controller: bool = False,
) -> "Client":
"""Gets the global client, which stores the controller's handle.
Args:
_override_controller_namespace (Optional[str]): If None and there's no
cached client, searches for the controller in this namespace.
_health_check_controller (bool): If True, run a health check on the
cached controller if it exists. If the check fails, try reconnecting
to the controller.
Raises:
RayServeException: if there is no Serve controller actor in the
expected namespace.
"""
try:
if _global_client is not None:
if _health_check_controller:
ray.get(_global_client._controller.check_alive.remote())
return _global_client
except RayActorError:
logger.info("The cached controller has died. Reconnecting.")
_set_global_client(None)
return _connect(_override_controller_namespace=_override_controller_namespace)
def _set_global_client(client):
global _global_client
_global_client = client
@dataclass
class ReplicaContext:
"""Stores data for Serve API calls from within deployments."""
deployment: str
replica_tag: ReplicaTag
_internal_controller_name: str
_internal_controller_namespace: str
servable_object: Callable
def _set_internal_replica_context(
deployment: str,
replica_tag: ReplicaTag,
controller_name: str,
controller_namespace: str,
servable_object: Callable,
):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
deployment, replica_tag, controller_name, controller_namespace, servable_object
)
def _ensure_connected(f: Callable) -> Callable:
@wraps(f)
def check(self, *args, **kwargs):
if self._shutdown:
raise RayServeException("Client has already been shut down.")
return f(self, *args, **kwargs)
return check
class Client:
def __init__(
self,
controller: ActorHandle,
controller_name: str,
detached: bool = False,
_override_controller_namespace: Optional[str] = None,
):
self._controller: ServeController = controller
self._controller_name = controller_name
self._detached = detached
self._override_controller_namespace = _override_controller_namespace
self._shutdown = False
self._http_config: HTTPOptions = ray.get(controller.get_http_config.remote())
self._root_url = ray.get(controller.get_root_url.remote())
self._checkpoint_path = ray.get(controller.get_checkpoint_path.remote())
# Each handle has the overhead of long poll client, therefore cached.
self.handle_cache = dict()
self._evicted_handle_keys = set()
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
# throws a nasty stacktrace).
if not self._detached:
def shutdown_serve_client():
self.shutdown()
atexit.register(shutdown_serve_client)
@property
def root_url(self):
return self._root_url
@property
def http_config(self):
return self._http_config
@property
def checkpoint_path(self):
return self._checkpoint_path
def __del__(self):
if not self._detached:
logger.debug(
"Shutting down Ray Serve because client went out of "
"scope. To prevent this, either keep a reference to "
"the client or use serve.start(detached=True)."
)
self.shutdown()
def __reduce__(self):
raise RayServeException(("Ray Serve client cannot be serialized."))
def shutdown(self) -> None:
"""Completely shut down the connected Serve instance.
Shuts down all processes and deletes all state associated with the
instance.
"""
if ray.is_initialized() and not self._shutdown:
ray.get(self._controller.shutdown.remote())
self._wait_for_deployments_shutdown()
ray.kill(self._controller, no_restart=True)
# Wait for the named actor entry to be removed as well.
started = time.time()
while True:
try:
controller_namespace = _get_controller_namespace(
self._detached,
self._override_controller_namespace,
)
ray.get_actor(self._controller_name, namespace=controller_namespace)
if time.time() - started > 5:
logger.warning(
"Waited 5s for Serve to shutdown gracefully but "
"the controller is still not cleaned up. "
"You can ignore this warning if you are shutting "
"down the Ray cluster."
)
break
except ValueError: # actor name is removed
break
self._shutdown = True
def _wait_for_deployments_shutdown(self, timeout_s: int = 60):
"""Waits for all deployments to be shut down and deleted.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = self.get_deployment_statuses()
if len(statuses) == 0:
break
else:
logger.debug(
f"Waiting for shutdown, {len(statuses)} deployments still alive."
)
time.sleep(_CLIENT_POLLING_INTERVAL_S)
else:
live_names = list(statuses.keys())
raise TimeoutError(
f"Shutdown didn't complete after {timeout_s}s. "
f"Deployments still alive: {live_names}."
)
def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1):
"""Waits for the named deployment to enter "HEALTHY" status.
Raises RuntimeError if the deployment enters the "UNHEALTHY" status
instead.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s or timeout_s < 0:
statuses = self.get_deployment_statuses()
try:
status = statuses[name]
except KeyError:
raise RuntimeError(
f"Waiting for deployment {name} to be HEALTHY, "
"but deployment doesn't exist."
) from None
if status.status == DeploymentStatus.HEALTHY:
break
elif status.status == DeploymentStatus.UNHEALTHY:
raise RuntimeError(f"Deployment {name} is UNHEALTHY: {status.message}")
else:
# Guard against new unhandled statuses being added.
assert status.status == DeploymentStatus.UPDATING
logger.debug(
f"Waiting for {name} to be healthy, current status: {status.status}."
)
time.sleep(_CLIENT_POLLING_INTERVAL_S)
else:
raise TimeoutError(
f"Deployment {name} did not become HEALTHY after {timeout_s}s."
)
def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60):
"""Waits for the named deployment to be shut down and deleted.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = self.get_deployment_statuses()
if name not in statuses:
break
else:
curr_status = statuses[name].status
logger.debug(
f"Waiting for {name} to be deleted, current status: {curr_status}."
)
time.sleep(_CLIENT_POLLING_INTERVAL_S)
else:
raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.")
@_ensure_connected
def deploy(
self,
name: str,
deployment_def: Union[Callable, Type[Callable], str],
init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
prev_version: Optional[str] = None,
route_prefix: Optional[str] = None,
url: Optional[str] = None,
_blocking: Optional[bool] = True,
):
controller_deploy_args = self.get_deploy_args(
name=name,
deployment_def=deployment_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options,
config=config,
version=version,
prev_version=prev_version,
route_prefix=route_prefix,
)
updating = ray.get(self._controller.deploy.remote(**controller_deploy_args))
tag = self.log_deployment_update_status(name, version, updating)
if _blocking:
self._wait_for_deployment_healthy(name)
self.log_deployment_ready(name, version, url, tag)
@_ensure_connected
def deploy_group(self, deployments: List[Dict], _blocking: bool = True):
deployment_args_list = []
for deployment in deployments:
deployment_args_list.append(
self.get_deploy_args(
deployment["name"],
deployment["func_or_class"],
deployment["init_args"],
deployment["init_kwargs"],
ray_actor_options=deployment["ray_actor_options"],
config=deployment["config"],
version=deployment["version"],
prev_version=deployment["prev_version"],
route_prefix=deployment["route_prefix"],
)
)
updating_list = ray.get(
self._controller.deploy_group.remote(deployment_args_list)
)
tags = []
for i, updating in enumerate(updating_list):
deployment = deployments[i]
name, version = deployment["name"], deployment["version"]
tags.append(self.log_deployment_update_status(name, version, updating))
for i, deployment in enumerate(deployments):
name = deployment["name"]
url = deployment["url"]
if _blocking:
self._wait_for_deployment_healthy(name)
self.log_deployment_ready(name, version, url, tags[i])
@_ensure_connected
def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None:
ray.get(self._controller.delete_deployments.remote(names))
if blocking:
for name in names:
self._wait_for_deployment_deleted(name)
@_ensure_connected
def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]:
deployment_route = DeploymentRoute.FromString(
ray.get(self._controller.get_deployment_info.remote(name))
)
return (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)
@_ensure_connected
def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]:
deployment_route_list = DeploymentRouteList.FromString(
ray.get(self._controller.list_deployments.remote())
)
return {
deployment_route.deployment_info.name: (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)
for deployment_route in deployment_route_list.deployment_routes
}
@_ensure_connected
def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]:
proto = DeploymentStatusInfoList.FromString(
ray.get(self._controller.get_deployment_statuses.remote())
)
return {
deployment_status_info.name: DeploymentStatusInfo.from_proto(
deployment_status_info
)
for deployment_status_info in proto.deployment_status_infos
}
@_ensure_connected
def get_handle(
self,
deployment_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True,
_internal_pickled_http_request: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.
Args:
deployment_name (str): A registered service deployment.
missing_ok (bool): If true, then Serve won't check the deployment
is registered. False by default.
sync (bool): If true, then Serve will return a ServeHandle that
works everywhere. Otherwise, Serve will return a ServeHandle
that's only usable in asyncio loop.
Returns:
RayServeHandle
"""
cache_key = (deployment_name, missing_ok, sync)
if cache_key in self.handle_cache:
cached_handle = self.handle_cache[cache_key]
if cached_handle.is_polling and cached_handle.is_same_loop:
return cached_handle
all_endpoints = ray.get(self._controller.get_all_endpoints.remote())
if not missing_ok and deployment_name not in all_endpoints:
raise KeyError(f"Deployment '{deployment_name}' does not exist.")
try:
asyncio_loop_running = asyncio.get_event_loop().is_running()
except RuntimeError as ex:
if "There is no current event loop in thread" in str(ex):
asyncio_loop_running = False
else:
raise ex
if asyncio_loop_running and sync:
logger.warning(
"You are retrieving a sync handle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance. Learn more at https://docs.ray.io/en/master/"
"serve/http-servehandle.html#sync-and-async-handles"
)
if not asyncio_loop_running and not sync:
logger.warning(
"You are retrieving an async handle outside an asyncio loop. "
"You should make sure client.get_handle is called inside a "
"running event loop. Or call client.get_handle(.., sync=True) "
"to create sync handle. Learn more at https://docs.ray.io/en/"
"master/serve/http-servehandle.html#sync-and-async-handles"
)
if sync:
handle = RayServeSyncHandle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
)
else:
handle = RayServeHandle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
)
self.handle_cache[cache_key] = handle
if cache_key in self._evicted_handle_keys:
logger.warning(
"You just got a ServeHandle that was evicted from internal "
"cache. This means you are getting too many ServeHandles in "
"the same process, this will bring down Serve's performance. "
"Please post a github issue at "
"https://github.com/ray-project/ray/issues to let the Serve "
"team to find workaround for your use case."
)
if len(self.handle_cache) > MAX_CACHED_HANDLES:
# Perform random eviction to keep the handle cache from growing
# infinitely. We used to use WeakValueDictionary but hit
# https://github.com/ray-project/ray/issues/18980.
evict_key = random.choice(list(self.handle_cache.keys()))
self._evicted_handle_keys.add(evict_key)
self.handle_cache.pop(evict_key)
return handle
@_ensure_connected
def get_deploy_args(
self,
name: str,
deployment_def: Union[Callable, Type[Callable], str],
init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
prev_version: Optional[str] = None,
route_prefix: Optional[str] = None,
) -> Dict:
"""
Takes a deployment's configuration, and returns the arguments needed
for the controller to deploy it.
"""
if config is None:
config = {}
if ray_actor_options is None:
ray_actor_options = {}
curr_job_env = ray.get_runtime_context().runtime_env
if "runtime_env" in ray_actor_options:
# It is illegal to set field working_dir to None.
if curr_job_env.get("working_dir") is not None:
ray_actor_options["runtime_env"].setdefault(
"working_dir", curr_job_env.get("working_dir")
)
else:
ray_actor_options["runtime_env"] = curr_job_env
replica_config = ReplicaConfig(
deployment_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options,
)
if isinstance(config, dict):
deployment_config = DeploymentConfig.parse_obj(config)
elif isinstance(config, DeploymentConfig):
deployment_config = config
else:
raise TypeError("config must be a DeploymentConfig or a dictionary.")
deployment_config.version = version
deployment_config.prev_version = prev_version
if (
deployment_config.autoscaling_config is not None
and deployment_config.max_concurrent_queries
< deployment_config.autoscaling_config.target_num_ongoing_requests_per_replica # noqa: E501
):
logger.warning(
"Autoscaling will never happen, "
"because 'max_concurrent_queries' is less than "
"'target_num_ongoing_requests_per_replica' now."
)
controller_deploy_args = {
"name": name,
"deployment_config_proto_bytes": deployment_config.to_proto_bytes(),
"replica_config_proto_bytes": replica_config.to_proto_bytes(),
"route_prefix": route_prefix,
"deployer_job_id": ray.get_runtime_context().job_id,
}
return controller_deploy_args
@_ensure_connected
def log_deployment_update_status(
self, name: str, version: str, updating: bool
) -> str:
tag = f"component=serve deployment={name}"
if updating:
msg = f"Updating deployment '{name}'"
if version is not None:
msg += f" to version '{version}'"
logger.info(f"{msg}. {tag}")
else:
logger.info(
f"Deployment '{name}' is already at version "
f"'{version}', not updating. {tag}"
)
return tag
@_ensure_connected
def log_deployment_ready(self, name: str, version: str, url: str, tag: str) -> None:
if url is not None:
url_part = f" at `{url}`"
else:
url_part = ""
logger.info(
f"Deployment '{name}{':'+version if version else ''}' is ready"
f"{url_part}. {tag}"
)
def _check_http_and_checkpoint_options(
client: Client,
http_options: Union[dict, HTTPOptions],
checkpoint_path: str,
) -> None:
if checkpoint_path and checkpoint_path != client.checkpoint_path:
logger.warning(
f"The new client checkpoint path '{checkpoint_path}' "
f"is different from the existing one '{client.checkpoint_path}'. "
"The new checkpoint path is ignored."
)
if http_options:
client_http_options = client.http_config
new_http_options = (
http_options
if isinstance(http_options, HTTPOptions)
else HTTPOptions.parse_obj(http_options)
)
different_fields = []
all_http_option_fields = new_http_options.__dict__
for field in all_http_option_fields:
if getattr(new_http_options, field) != getattr(client_http_options, field):
different_fields.append(field)
if len(different_fields):
logger.warning(
"The new client HTTP config differs from the existing one "
f"in the following fields: {different_fields}. "
"The new HTTP config is ignored."
)
logger = logging.getLogger(__file__)
@PublicAPI(stability="beta")
@ -698,7 +69,7 @@ def start(
_checkpoint_path: str = DEFAULT_CHECKPOINT_PATH,
_override_controller_namespace: Optional[str] = None,
**kwargs,
) -> Client:
) -> ServeControllerClient:
"""Initialize a serve instance.
By default, the instance will be scoped to the lifetime of the returned
@ -749,12 +120,12 @@ def start(
if not ray.is_initialized():
ray.init(namespace="serve")
controller_namespace = _get_controller_namespace(
controller_namespace = get_controller_namespace(
detached, _override_controller_namespace=_override_controller_namespace
)
try:
client = internal_get_global_client(
client = get_global_client(
_override_controller_namespace=_override_controller_namespace,
_health_check_controller=True,
)
@ -808,13 +179,13 @@ def start(
"HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s."
)
client = Client(
client = ServeControllerClient(
controller,
controller_name,
detached=detached,
_override_controller_namespace=_override_controller_namespace,
)
_set_global_client(client)
set_global_client(client)
logger.info(
f"Started{' detached ' if detached else ' '}Serve instance in "
f"namespace '{controller_namespace}'."
@ -822,62 +193,6 @@ def start(
return client
def _connect(_override_controller_namespace: Optional[str] = None) -> Client:
"""Connect to an existing Serve instance on this Ray cluster.
If calling from the driver program, the Serve instance on this Ray cluster
must first have been initialized using `serve.start(detached=True)`.
If called from within a replica, this will connect to the same Serve
instance that the replica is running in.
Args:
_override_controller_namespace (Optional[str]): The namespace to use
when looking for the controller. If None, Serve recalculates the
controller's namespace using _get_controller_namespace().
Raises:
RayServeException: if there is no Serve controller actor in the
expected namespace.
"""
# Initialize ray if needed.
ray.worker.global_worker.filter_logs_by_job = False
if not ray.is_initialized():
ray.init(namespace="serve")
# When running inside of a replica, _INTERNAL_REPLICA_CONTEXT is set to
# ensure that the correct instance is connected to.
if _INTERNAL_REPLICA_CONTEXT is None:
controller_name = SERVE_CONTROLLER_NAME
controller_namespace = _get_controller_namespace(
detached=True, _override_controller_namespace=_override_controller_namespace
)
else:
controller_name = _INTERNAL_REPLICA_CONTEXT._internal_controller_name
controller_namespace = _INTERNAL_REPLICA_CONTEXT._internal_controller_namespace
# Try to get serve controller if it exists
try:
controller = ray.get_actor(controller_name, namespace=controller_namespace)
except ValueError:
raise RayServeException(
"There is no "
"instance running on this Ray cluster. Please "
"call `serve.start(detached=True) to start "
"one."
)
client = Client(
controller,
controller_name,
detached=True,
_override_controller_namespace=_override_controller_namespace,
)
_set_global_client(client)
return client
@PublicAPI
def shutdown() -> None:
"""Completely shut down the connected Serve instance.
@ -887,7 +202,7 @@ def shutdown() -> None:
"""
try:
client = internal_get_global_client()
client = get_global_client()
except RayServeException:
logger.info(
"Nothing to shut down. There's no Serve application "
@ -896,7 +211,7 @@ def shutdown() -> None:
return
client.shutdown()
_set_global_client(None)
set_global_client(None)
@PublicAPI
@ -917,13 +232,14 @@ def get_replica_context() -> ReplicaContext:
>>> # deployment_name#krcwoa
>>> serve.get_replica_context().replica_tag # doctest: +SKIP
"""
if _INTERNAL_REPLICA_CONTEXT is None:
internal_replica_context = get_internal_replica_context()
if internal_replica_context is None:
raise RayServeException(
"`serve.get_replica_context()` "
"may only be called from within a "
"Ray Serve deployment."
)
return _INTERNAL_REPLICA_CONTEXT
return internal_replica_context
@PublicAPI(stability="beta")
@ -1180,7 +496,7 @@ def get_deployment(name: str) -> Deployment:
(
deployment_info,
route_prefix,
) = internal_get_global_client().get_deployment_info(name)
) = get_global_client().get_deployment_info(name)
except KeyError:
raise KeyError(
f"Deployment {name} was not found. Did you call Deployment.deploy()?"
@ -1204,7 +520,7 @@ def list_deployments() -> Dict[str, Deployment]:
Dictionary maps deployment name to Deployment objects.
"""
infos = internal_get_global_client().list_deployments()
infos = get_global_client().list_deployments()
deployments = {}
for name, (deployment_info, route_prefix) in infos.items():
@ -1241,7 +557,7 @@ def get_deployment_statuses() -> Dict[str, DeploymentStatusInfo]:
status and a message explaining the status.
"""
return internal_get_global_client().get_deployment_statuses()
return get_global_client().get_deployment_statuses()
@PublicAPI(stability="alpha")
@ -1359,3 +675,36 @@ def build(target: Union[ClassNode, FunctionNode]) -> Application:
# TODO(edoakes): this should accept host and port, but we don't
# currently support them in the REST API.
return Application(pipeline_build(target))
def _check_http_and_checkpoint_options(
client: ServeControllerClient,
http_options: Union[dict, HTTPOptions],
checkpoint_path: str,
) -> None:
if checkpoint_path and checkpoint_path != client.checkpoint_path:
logger.warning(
f"The new client checkpoint path '{checkpoint_path}' "
f"is different from the existing one '{client.checkpoint_path}'. "
"The new checkpoint path is ignored."
)
if http_options:
client_http_options = client.http_config
new_http_options = (
http_options
if isinstance(http_options, HTTPOptions)
else HTTPOptions.parse_obj(http_options)
)
different_fields = []
all_http_option_fields = new_http_options.__dict__
for field in all_http_option_fields:
if getattr(new_http_options, field) != getattr(client_http_options, field):
different_fields.append(field)
if len(different_fields):
logger.warning(
"The new client HTTP config differs from the existing one "
f"in the following fields: {different_fields}. "
"The new HTTP config is ignored."
)

python/ray/serve/client.py (new file, 555 lines)

@ -0,0 +1,555 @@
import asyncio
import atexit
import random
import logging
import time
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
List,
Iterable,
)
import ray
from ray.actor import ActorHandle
from ray.serve.common import (
DeploymentInfo,
DeploymentStatus,
DeploymentStatusInfo,
)
from ray.serve.config import (
DeploymentConfig,
HTTPOptions,
ReplicaConfig,
)
from ray.serve.constants import (
MAX_CACHED_HANDLES,
CLIENT_POLLING_INTERVAL_S,
ANONYMOUS_NAMESPACE_PATTERN,
)
from ray.serve.controller import ServeController
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import (
DeploymentRoute,
DeploymentRouteList,
DeploymentStatusInfoList,
)
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
logger = logging.getLogger(__file__)
def _ensure_connected(f: Callable) -> Callable:
@wraps(f)
def check(self, *args, **kwargs):
if self._shutdown:
raise RayServeException("Client has already been shut down.")
return f(self, *args, **kwargs)
return check
class ServeControllerClient:
def __init__(
self,
controller: ActorHandle,
controller_name: str,
detached: bool = False,
_override_controller_namespace: Optional[str] = None,
):
self._controller: ServeController = controller
self._controller_name = controller_name
self._detached = detached
self._override_controller_namespace = _override_controller_namespace
self._shutdown = False
self._http_config: HTTPOptions = ray.get(controller.get_http_config.remote())
self._root_url = ray.get(controller.get_root_url.remote())
self._checkpoint_path = ray.get(controller.get_checkpoint_path.remote())
# Each handle has the overhead of long poll client, therefore cached.
self.handle_cache = dict()
self._evicted_handle_keys = set()
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
# throws a nasty stacktrace).
if not self._detached:
def shutdown_serve_client():
self.shutdown()
atexit.register(shutdown_serve_client)
@property
def root_url(self):
return self._root_url
@property
def http_config(self):
return self._http_config
@property
def checkpoint_path(self):
return self._checkpoint_path
def __del__(self):
if not self._detached:
logger.debug(
"Shutting down Ray Serve because client went out of "
"scope. To prevent this, either keep a reference to "
"the client or use serve.start(detached=True)."
)
self.shutdown()
def __reduce__(self):
raise RayServeException(("Ray Serve client cannot be serialized."))
def shutdown(self) -> None:
"""Completely shut down the connected Serve instance.
Shuts down all processes and deletes all state associated with the
instance.
"""
if ray.is_initialized() and not self._shutdown:
ray.get(self._controller.shutdown.remote())
self._wait_for_deployments_shutdown()
ray.kill(self._controller, no_restart=True)
# Wait for the named actor entry to be removed as well.
started = time.time()
while True:
try:
controller_namespace = get_controller_namespace(
self._detached,
self._override_controller_namespace,
)
ray.get_actor(self._controller_name, namespace=controller_namespace)
if time.time() - started > 5:
logger.warning(
"Waited 5s for Serve to shutdown gracefully but "
"the controller is still not cleaned up. "
"You can ignore this warning if you are shutting "
"down the Ray cluster."
)
break
except ValueError: # actor name is removed
break
self._shutdown = True
def _wait_for_deployments_shutdown(self, timeout_s: int = 60):
"""Waits for all deployments to be shut down and deleted.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = self.get_deployment_statuses()
if len(statuses) == 0:
break
else:
logger.debug(
f"Waiting for shutdown, {len(statuses)} deployments still alive."
)
time.sleep(CLIENT_POLLING_INTERVAL_S)
else:
live_names = list(statuses.keys())
raise TimeoutError(
f"Shutdown didn't complete after {timeout_s}s. "
f"Deployments still alive: {live_names}."
)
def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1):
"""Waits for the named deployment to enter "HEALTHY" status.
Raises RuntimeError if the deployment enters the "UNHEALTHY" status
instead.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s or timeout_s < 0:
statuses = self.get_deployment_statuses()
try:
status = statuses[name]
except KeyError:
raise RuntimeError(
f"Waiting for deployment {name} to be HEALTHY, "
"but deployment doesn't exist."
) from None
if status.status == DeploymentStatus.HEALTHY:
break
elif status.status == DeploymentStatus.UNHEALTHY:
raise RuntimeError(f"Deployment {name} is UNHEALTHY: {status.message}")
else:
# Guard against new unhandled statuses being added.
assert status.status == DeploymentStatus.UPDATING
logger.debug(
f"Waiting for {name} to be healthy, current status: {status.status}."
)
time.sleep(CLIENT_POLLING_INTERVAL_S)
else:
raise TimeoutError(
f"Deployment {name} did not become HEALTHY after {timeout_s}s."
)
def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60):
"""Waits for the named deployment to be shut down and deleted.
Raises TimeoutError if this doesn't happen before timeout_s.
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = self.get_deployment_statuses()
if name not in statuses:
break
else:
curr_status = statuses[name].status
logger.debug(
f"Waiting for {name} to be deleted, current status: {curr_status}."
)
time.sleep(CLIENT_POLLING_INTERVAL_S)
else:
raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.")
@_ensure_connected
def deploy(
self,
name: str,
deployment_def: Union[Callable, Type[Callable], str],
init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
prev_version: Optional[str] = None,
route_prefix: Optional[str] = None,
url: Optional[str] = None,
_blocking: Optional[bool] = True,
):
controller_deploy_args = self.get_deploy_args(
name=name,
deployment_def=deployment_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options,
config=config,
version=version,
prev_version=prev_version,
route_prefix=route_prefix,
)
updating = ray.get(self._controller.deploy.remote(**controller_deploy_args))
tag = self.log_deployment_update_status(name, version, updating)
if _blocking:
self._wait_for_deployment_healthy(name)
self.log_deployment_ready(name, version, url, tag)
@_ensure_connected
def deploy_group(self, deployments: List[Dict], _blocking: bool = True):
deployment_args_list = []
for deployment in deployments:
deployment_args_list.append(
self.get_deploy_args(
deployment["name"],
deployment["func_or_class"],
deployment["init_args"],
deployment["init_kwargs"],
ray_actor_options=deployment["ray_actor_options"],
config=deployment["config"],
version=deployment["version"],
prev_version=deployment["prev_version"],
route_prefix=deployment["route_prefix"],
)
)
updating_list = ray.get(
self._controller.deploy_group.remote(deployment_args_list)
)
tags = []
for i, updating in enumerate(updating_list):
deployment = deployments[i]
name, version = deployment["name"], deployment["version"]
tags.append(self.log_deployment_update_status(name, version, updating))
for i, deployment in enumerate(deployments):
name = deployment["name"]
url = deployment["url"]
if _blocking:
self._wait_for_deployment_healthy(name)
self.log_deployment_ready(name, version, url, tags[i])
@_ensure_connected
def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None:
ray.get(self._controller.delete_deployments.remote(names))
if blocking:
for name in names:
self._wait_for_deployment_deleted(name)
@_ensure_connected
def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]:
deployment_route = DeploymentRoute.FromString(
ray.get(self._controller.get_deployment_info.remote(name))
)
return (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)
@_ensure_connected
def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]:
deployment_route_list = DeploymentRouteList.FromString(
ray.get(self._controller.list_deployments.remote())
)
return {
deployment_route.deployment_info.name: (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)
for deployment_route in deployment_route_list.deployment_routes
}
@_ensure_connected
def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]:
proto = DeploymentStatusInfoList.FromString(
ray.get(self._controller.get_deployment_statuses.remote())
)
return {
deployment_status_info.name: DeploymentStatusInfo.from_proto(
deployment_status_info
)
for deployment_status_info in proto.deployment_status_infos
}
@_ensure_connected
def get_handle(
self,
deployment_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True,
_internal_pickled_http_request: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.
Args:
deployment_name (str): A registered service deployment.
missing_ok (bool): If true, then Serve won't check the deployment
is registered. False by default.
sync (bool): If true, then Serve will return a ServeHandle that
works everywhere. Otherwise, Serve will return a ServeHandle
that's only usable in asyncio loop.
Returns:
RayServeHandle
"""
cache_key = (deployment_name, missing_ok, sync)
if cache_key in self.handle_cache:
cached_handle = self.handle_cache[cache_key]
if cached_handle.is_polling and cached_handle.is_same_loop:
return cached_handle
all_endpoints = ray.get(self._controller.get_all_endpoints.remote())
if not missing_ok and deployment_name not in all_endpoints:
raise KeyError(f"Deployment '{deployment_name}' does not exist.")
try:
asyncio_loop_running = asyncio.get_event_loop().is_running()
except RuntimeError as ex:
if "There is no current event loop in thread" in str(ex):
asyncio_loop_running = False
else:
raise ex
if asyncio_loop_running and sync:
logger.warning(
"You are retrieving a sync handle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance. Learn more at https://docs.ray.io/en/master/"
"serve/http-servehandle.html#sync-and-async-handles"
)
if not asyncio_loop_running and not sync:
logger.warning(
"You are retrieving an async handle outside an asyncio loop. "
"You should make sure client.get_handle is called inside a "
"running event loop. Or call client.get_handle(.., sync=True) "
"to create sync handle. Learn more at https://docs.ray.io/en/"
"master/serve/http-servehandle.html#sync-and-async-handles"
)
if sync:
handle = RayServeSyncHandle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
)
else:
handle = RayServeHandle(
self._controller,
deployment_name,
_internal_pickled_http_request=_internal_pickled_http_request,
)
self.handle_cache[cache_key] = handle
if cache_key in self._evicted_handle_keys:
logger.warning(
"You just got a ServeHandle that was evicted from internal "
"cache. This means you are getting too many ServeHandles in "
"the same process, this will bring down Serve's performance. "
"Please post a github issue at "
"https://github.com/ray-project/ray/issues to let the Serve "
"team to find workaround for your use case."
)
if len(self.handle_cache) > MAX_CACHED_HANDLES:
# Perform random eviction to keep the handle cache from growing
# infinitely. We used to use WeakValueDictionary but hit
# https://github.com/ray-project/ray/issues/18980.
evict_key = random.choice(list(self.handle_cache.keys()))
self._evicted_handle_keys.add(evict_key)
self.handle_cache.pop(evict_key)
return handle
@_ensure_connected
def get_deploy_args(
self,
name: str,
deployment_def: Union[Callable, Type[Callable], str],
init_args: Tuple[Any],
init_kwargs: Dict[Any, Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
prev_version: Optional[str] = None,
route_prefix: Optional[str] = None,
) -> Dict:
"""
Takes a deployment's configuration, and returns the arguments needed
for the controller to deploy it.
"""
if config is None:
config = {}
if ray_actor_options is None:
ray_actor_options = {}
curr_job_env = ray.get_runtime_context().runtime_env
if "runtime_env" in ray_actor_options:
# It is illegal to set field working_dir to None.
if curr_job_env.get("working_dir") is not None:
ray_actor_options["runtime_env"].setdefault(
"working_dir", curr_job_env.get("working_dir")
)
else:
ray_actor_options["runtime_env"] = curr_job_env
replica_config = ReplicaConfig(
deployment_def,
init_args=init_args,
init_kwargs=init_kwargs,
ray_actor_options=ray_actor_options,
)
if isinstance(config, dict):
deployment_config = DeploymentConfig.parse_obj(config)
elif isinstance(config, DeploymentConfig):
deployment_config = config
else:
raise TypeError("config must be a DeploymentConfig or a dictionary.")
deployment_config.version = version
deployment_config.prev_version = prev_version
if (
deployment_config.autoscaling_config is not None
and deployment_config.max_concurrent_queries
< deployment_config.autoscaling_config.target_num_ongoing_requests_per_replica # noqa: E501
):
logger.warning(
"Autoscaling will never happen, "
"because 'max_concurrent_queries' is less than "
"'target_num_ongoing_requests_per_replica' now."
)
controller_deploy_args = {
"name": name,
"deployment_config_proto_bytes": deployment_config.to_proto_bytes(),
"replica_config_proto_bytes": replica_config.to_proto_bytes(),
"route_prefix": route_prefix,
"deployer_job_id": ray.get_runtime_context().job_id,
}
return controller_deploy_args
@_ensure_connected
def log_deployment_update_status(
self, name: str, version: str, updating: bool
) -> str:
tag = f"component=serve deployment={name}"
if updating:
msg = f"Updating deployment '{name}'"
if version is not None:
msg += f" to version '{version}'"
logger.info(f"{msg}. {tag}")
else:
logger.info(
f"Deployment '{name}' is already at version "
f"'{version}', not updating. {tag}"
)
return tag
@_ensure_connected
def log_deployment_ready(self, name: str, version: str, url: str, tag: str) -> None:
if url is not None:
url_part = f" at `{url}`"
else:
url_part = ""
logger.info(
f"Deployment '{name}{':'+version if version else ''}' is ready"
f"{url_part}. {tag}"
)
def get_controller_namespace(
detached: bool, _override_controller_namespace: Optional[str] = None
):
"""Gets the controller's namespace.
Args:
detached (bool): Whether serve.start() was called with detached=True
_override_controller_namespace (Optional[str]): When set, this is the
controller's namespace
"""
if _override_controller_namespace is not None:
return _override_controller_namespace
controller_namespace = ray.get_runtime_context().namespace
if not detached:
return controller_namespace
# Start controller in "serve" namespace if detached and currently
# in anonymous namespace.
if ANONYMOUS_NAMESPACE_PATTERN.fullmatch(controller_namespace) is not None:
controller_namespace = "serve"
return controller_namespace


@ -1,4 +1,5 @@
from enum import Enum
import re
#: Used for debugging to turn on DEBUG-level logs
DEBUG_LOG_ENV_VAR = "SERVE_DEBUG_LOG"
@ -86,6 +87,15 @@ REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD = 3
# Key used to identify that a given JSON represents a serialized RayServeHandle
SERVE_HANDLE_JSON_KEY = "__SerializedServeHandle__"
# The time in seconds that the Serve client waits before rechecking deployment state
CLIENT_POLLING_INTERVAL_S: float = 1
# Regex pattern for anonymous namespace. Should match the pattern used in
# src/ray/gcs/gcs_server/gcs_actor_manager.cc's is_uuid() method.
ANONYMOUS_NAMESPACE_PATTERN = re.compile(
"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12}"
)
class ServeHandleType(str, Enum):
SYNC = "SYNC"

python/ray/serve/context.py (new file, 141 lines)

@ -0,0 +1,141 @@
"""
This file stores global state for a Serve application. Deployment replicas
can use this state to access metadata or the Serve controller.
"""
import logging
from dataclasses import dataclass
from typing import Callable, Optional
import ray
from ray.exceptions import RayActorError
from ray.serve.common import ReplicaTag
from ray.serve.constants import SERVE_CONTROLLER_NAME
from ray.serve.exceptions import RayServeException
from ray.serve.client import ServeControllerClient, get_controller_namespace
logger = logging.getLogger(__file__)
_INTERNAL_REPLICA_CONTEXT: "ReplicaContext" = None
_global_client: ServeControllerClient = None
@dataclass
class ReplicaContext:
"""Stores data for Serve API calls from within deployments."""
deployment: str
replica_tag: ReplicaTag
_internal_controller_name: str
_internal_controller_namespace: str
servable_object: Callable
def get_global_client(
_override_controller_namespace: Optional[str] = None,
_health_check_controller: bool = False,
) -> ServeControllerClient:
"""Gets the global client, which stores the controller's handle.
Args:
_override_controller_namespace (Optional[str]): If None and there's no
cached client, searches for the controller in this namespace.
_health_check_controller (bool): If True, run a health check on the
cached controller if it exists. If the check fails, try reconnecting
to the controller.
Raises:
RayServeException: if there is no Serve controller actor in the
expected namespace.
"""
try:
if _global_client is not None:
if _health_check_controller:
ray.get(_global_client._controller.check_alive.remote())
return _global_client
except RayActorError:
logger.info("The cached controller has died. Reconnecting.")
set_global_client(None)
return _connect(_override_controller_namespace=_override_controller_namespace)
def set_global_client(client):
global _global_client
_global_client = client
def get_internal_replica_context():
return _INTERNAL_REPLICA_CONTEXT
def set_internal_replica_context(
deployment: str,
replica_tag: ReplicaTag,
controller_name: str,
controller_namespace: str,
servable_object: Callable,
):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
deployment, replica_tag, controller_name, controller_namespace, servable_object
)
def _connect(
_override_controller_namespace: Optional[str] = None,
) -> ServeControllerClient:
"""Connect to an existing Serve instance on this Ray cluster.
If calling from the driver program, the Serve instance on this Ray cluster
must first have been initialized using `serve.start(detached=True)`.
If called from within a replica, this will connect to the same Serve
instance that the replica is running in.
Args:
_override_controller_namespace (Optional[str]): The namespace to use
when looking for the controller. If None, Serve recalculates the
controller's namespace using get_controller_namespace().
Raises:
RayServeException: if there is no Serve controller actor in the
expected namespace.
"""
# Initialize ray if needed.
ray.worker.global_worker.filter_logs_by_job = False
if not ray.is_initialized():
ray.init(namespace="serve")
# When running inside of a replica, _INTERNAL_REPLICA_CONTEXT is set to
# ensure that the correct instance is connected to.
if _INTERNAL_REPLICA_CONTEXT is None:
controller_name = SERVE_CONTROLLER_NAME
controller_namespace = get_controller_namespace(
detached=True, _override_controller_namespace=_override_controller_namespace
)
else:
controller_name = _INTERNAL_REPLICA_CONTEXT._internal_controller_name
controller_namespace = _INTERNAL_REPLICA_CONTEXT._internal_controller_namespace
# Try to get serve controller if it exists
try:
controller = ray.get_actor(controller_name, namespace=controller_namespace)
except ValueError:
raise RayServeException(
"There is no "
"instance running on this Ray cluster. Please "
"call `serve.start(detached=True) to start "
"one."
)
client = ServeControllerClient(
controller,
controller_name,
detached=True,
_override_controller_namespace=_override_controller_namespace,
)
set_global_client(client)
return client


@ -8,6 +8,8 @@ from typing import (
Tuple,
Union,
)
from ray.serve.context import get_global_client
from ray.experimental.dag.class_node import ClassNode
from ray.experimental.dag.function_node import FunctionNode
from ray.serve.config import (
@ -23,10 +25,6 @@ from ray.serve.schema import (
)
# TODO (shrekris-anyscale): remove following dependencies on api.py:
# - internal_get_global_client
@PublicAPI
class Deployment:
def __init__(
@ -175,9 +173,7 @@ class Deployment:
# this deployment is not exposed over HTTP
return None
from ray.serve.api import internal_get_global_client
return internal_get_global_client().root_url + self.route_prefix
return get_global_client().root_url + self.route_prefix
def __call__(self):
raise RuntimeError(
@ -237,9 +233,7 @@ class Deployment:
if len(init_kwargs) == 0 and self._init_kwargs is not None:
init_kwargs = self._init_kwargs
from ray.serve.api import internal_get_global_client
return internal_get_global_client().deploy(
return get_global_client().deploy(
self._name,
self._func_or_class,
init_args,
@ -257,9 +251,7 @@ class Deployment:
def delete(self):
"""Delete this deployment."""
from ray.serve.api import internal_get_global_client
return internal_get_global_client().delete_deployments([self._name])
return get_global_client().delete_deployments([self._name])
@PublicAPI
def get_handle(
@ -277,11 +269,7 @@ class Deployment:
ServeHandle
"""
from ray.serve.api import internal_get_global_client
return internal_get_global_client().get_handle(
self._name, missing_ok=True, sync=sync
)
return get_global_client().get_handle(self._name, missing_ok=True, sync=sync)
@PublicAPI
def options(


@ -147,7 +147,7 @@ class ActorReplicaWrapper:
self._placement_group_name = self._actor_name + "_placement_group"
self._detached = detached
self._controller_name = controller_name
self._controller_namespace = ray.serve.api._get_controller_namespace(
self._controller_namespace = ray.serve.client.get_controller_namespace(
detached, _override_controller_namespace=_override_controller_namespace
)


@ -321,7 +321,7 @@ def serve_handle_from_json_dict(d: Dict[str, str]) -> RayServeHandle:
if SERVE_HANDLE_JSON_KEY not in d:
raise ValueError(f"dict must contain {SERVE_HANDLE_JSON_KEY} key.")
return serve.api.internal_get_global_client().get_handle(
return serve.context.get_global_client().get_handle(
d["deployment_name"],
sync=d[SERVE_HANDLE_JSON_KEY] == ServeHandleType.SYNC,
missing_ok=True,


@ -197,7 +197,7 @@ class HTTPProxy:
):
# Set the controller name so that serve will connect to the
# controller instance this proxy is running in.
ray.serve.api._set_internal_replica_context(
ray.serve.context.set_internal_replica_context(
None, None, controller_name, controller_namespace, None
)
@ -205,7 +205,7 @@ class HTTPProxy:
self.route_info: Dict[str, EndpointTag] = dict()
def get_handle(name):
return serve.api.internal_get_global_client().get_handle(
return serve.context.get_global_client().get_handle(
name,
sync=False,
missing_ok=True,


@ -35,7 +35,7 @@ class HTTPState:
_start_proxies_on_init: bool = True,
):
self._controller_name = controller_name
self._controller_namespace = ray.serve.api._get_controller_namespace(
self._controller_namespace = ray.serve.client.get_controller_namespace(
detached, _override_controller_namespace=_override_controller_namespace
)
self._detached = detached


@ -116,7 +116,7 @@ def create_replica_wrapper(
# Set the controller name so that serve.connect() in the user's
# code will connect to the instance that this deployment is running
# in.
ray.serve.api._set_internal_replica_context(
ray.serve.context.set_internal_replica_context(
deployment_name,
replica_tag,
controller_name,
@ -148,7 +148,7 @@ def create_replica_wrapper(
await sync_to_async(_callable.__init__)(*init_args, **init_kwargs)
# Setting the context again to update the servable_object.
ray.serve.api._set_internal_replica_context(
ray.serve.context.set_internal_replica_context(
deployment_name,
replica_tag,
controller_name,


@ -22,7 +22,7 @@ from ray.dashboard.modules.dashboard_sdk import parse_runtime_env_args
from ray.dashboard.modules.serve.sdk import ServeSubmissionClient
from ray.autoscaler._private.cli_logger import cli_logger
from ray.serve.api import build as build_app
from ray.serve.api import Application
from ray.serve.application import Application
from ray.serve.deployment_graph import (
FunctionNode,
ClassNode,
@ -140,7 +140,7 @@ def shutdown(address: str, namespace: str):
address=address,
namespace=namespace,
)
serve.api._connect()
serve.context._connect()
serve.shutdown()


@ -68,7 +68,7 @@ def test_recover_start_from_replica_actor_names(serve_instance):
), "Should have two running replicas fetched from ray API."
# Kill controller and wait for endpoint to be available again
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
for _ in range(10):
response = request_with_retries(
"/recover_start_from_replica_actor_names/", timeout=30
@ -171,7 +171,7 @@ def test_recover_rolling_update_from_replica_actor_names(serve_instance):
responses2, blocking2 = make_nonblocking_calls({"1": 1}, expect_blocking=True)
assert list(responses2["1"])[0] in pids1
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
# Redeploy new version. Since there is one replica blocking, only one new
# replica should be started up.
@ -181,7 +181,7 @@ def test_recover_rolling_update_from_replica_actor_names(serve_instance):
client._wait_for_deployment_healthy(V2.name, timeout_s=0.1)
responses3, blocking3 = make_nonblocking_calls({"1": 1}, expect_blocking=True)
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
# Signal the original call to exit.
ray.get(signal.send.remote())


@ -44,7 +44,7 @@ def test_scale_up(ray_cluster):
return pids
serve.start(detached=True)
client = serve.api._connect()
client = serve.context._connect()
D.deploy()
pids1 = get_pids(1)


@ -10,7 +10,7 @@ from ray.serve.generated.serve_pb2 import DeploymentRoute
def test_redeploy_start_time(serve_instance):
"""Check that redeploying a deployment doesn't reset its start time."""
controller = serve.api._global_client._controller
controller = serve.context._global_client._controller
@serve.deployment
def test(_):


@ -34,7 +34,7 @@ def test_controller_failure(serve_instance):
response = request_with_retries("/controller_failure/", timeout=30)
assert response.text == "hello1"
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
for _ in range(10):
response = request_with_retries("/controller_failure/", timeout=30)
@ -43,7 +43,7 @@ def test_controller_failure(serve_instance):
def function2(_):
return "hello2"
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
function.options(func_or_class=function2).deploy()
@ -57,9 +57,9 @@ def test_controller_failure(serve_instance):
def function3(_):
return "hello3"
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
function3.deploy()
ray.kill(serve.api._global_client._controller, no_restart=False)
ray.kill(serve.context._global_client._controller, no_restart=False)
for _ in range(10):
response = request_with_retries("/controller_failure/", timeout=30)
@ -70,7 +70,7 @@ def test_controller_failure(serve_instance):
def _kill_http_proxies():
http_proxies = ray.get(
serve.api._global_client._controller.get_http_proxies.remote()
serve.context._global_client._controller.get_http_proxies.remote()
)
for http_proxy in http_proxies.values():
ray.kill(http_proxy, no_restart=False)
@ -107,7 +107,7 @@ def test_http_proxy_failure(serve_instance):
def _get_worker_handles(deployment):
controller = serve.api._global_client._controller
controller = serve.context._global_client._controller
deployment_dict = ray.get(controller._all_running_replicas.remote())
return [replica.actor_handle for replica in deployment_dict[deployment]]


@ -8,7 +8,7 @@ from ray.serve.http_state import HTTPState
@pytest.fixture
def patch_get_namespace():
with patch("ray.serve.api._get_controller_namespace") as func:
with patch("ray.serve.client.get_controller_namespace") as func:
func.return_value = "dummy_namespace"
yield


@ -11,7 +11,7 @@ import ray
from ray.exceptions import GetTimeoutError
from ray import serve
from ray._private.test_utils import SignalActor
from ray.serve.api import internal_get_global_client
from ray.serve.context import get_global_client
@pytest.fixture
@ -146,7 +146,7 @@ def test_nested_actors(serve_instance):
def test_handle_cache_out_of_scope(serve_instance):
# https://github.com/ray-project/ray/issues/18980
initial_num_cached = len(internal_get_global_client().handle_cache)
initial_num_cached = len(get_global_client().handle_cache)
@serve.deployment(name="f")
def f():
@ -155,7 +155,7 @@ def test_handle_cache_out_of_scope(serve_instance):
f.deploy()
handle = serve.get_deployment("f").get_handle()
handle_cache = internal_get_global_client().handle_cache
handle_cache = get_global_client().handle_cache
assert len(handle_cache) == initial_num_cached + 1
def sender_where_handle_goes_out_of_scope():


@ -25,7 +25,7 @@ from ray._private.test_utils import (
from ray.cluster_utils import Cluster, cluster_not_supported
from ray import serve
from ray.serve.api import internal_get_global_client
from ray.serve.context import get_global_client
from ray.serve.config import HTTPOptions
from ray.serve.constants import SERVE_ROOT_URL_ENV_KEY, SERVE_PROXY_NAME
from ray.serve.exceptions import RayServeException
@ -58,12 +58,12 @@ def test_shutdown(ray_shutdown):
f.deploy()
serve_controller_name = serve.api._global_client._controller_name
serve_controller_name = serve.context._global_client._controller_name
actor_names = [
serve_controller_name,
format_actor_name(
SERVE_PROXY_NAME,
serve.api._global_client._controller_name,
serve.context._global_client._controller_name,
get_all_node_ids()[0][0],
),
]
@ -124,7 +124,7 @@ def test_detached_deployment(ray_cluster):
assert ray.get(f.get_handle().remote()) == "from_f"
assert requests.get("http://localhost:8000/say_hi_f").text == "from_f"
serve.api._global_client = None
serve.context._global_client = None
ray.shutdown()
# Create the second job, make sure we can still create new deployments.
@ -200,7 +200,9 @@ def test_multiple_routers(ray_cluster):
for node_id, _ in get_all_node_ids():
proxy_names.append(
format_actor_name(
SERVE_PROXY_NAME, serve.api._global_client._controller_name, node_id
SERVE_PROXY_NAME,
serve.context._global_client._controller_name,
node_id,
)
)
return proxy_names
@ -375,7 +377,7 @@ def test_no_http(ray_shutdown):
if actor["State"] == convert_actor_state(gcs_utils.ActorTableData.ALIVE)
]
assert len(live_actors) == 1
controller = serve.api._global_client._controller
controller = serve.context._global_client._controller
assert len(ray.get(controller.get_http_proxies.remote())) == 0
# Test that the handle still works.
@ -446,7 +448,7 @@ def test_fixed_number_proxies(ray_cluster):
)
# Only the controller and two http proxy should be started.
controller_handle = internal_get_global_client()._controller
controller_handle = get_global_client()._controller
node_to_http_actors = ray.get(controller_handle.get_http_proxies.remote())
assert len(node_to_http_actors) == 2
@ -507,7 +509,7 @@ def test_serve_controller_namespace(
ray.init(namespace=namespace)
serve.start(detached=detached)
client = serve.api._global_client
client = serve.context._global_client
if namespace:
controller_namespace = namespace
elif detached:

View file

@ -11,7 +11,7 @@ from ray.exceptions import RayActorError
import ray
import ray.state
from ray import serve
from ray.serve.api import internal_get_global_client
from ray.serve.context import get_global_client
from ray._private.test_utils import wait_for_condition
from ray.tests.conftest import call_ray_stop_only # noqa: F401
@ -76,7 +76,7 @@ def test_override_namespace(shutdown_ray, detached):
ray.init(namespace=ray_namespace)
serve.start(detached=detached, _override_controller_namespace=controller_namespace)
controller_name = internal_get_global_client()._controller_name
controller_name = get_global_client()._controller_name
ray.get_actor(controller_name, namespace=controller_namespace)
serve.shutdown()
@ -148,7 +148,7 @@ def test_refresh_controller_after_death(shutdown_ray, detached):
serve.shutdown() # Ensure serve isn't running before beginning the test
serve.start(detached=detached, _override_controller_namespace=controller_namespace)
old_handle = internal_get_global_client()._controller
old_handle = get_global_client()._controller
ray.kill(old_handle, no_restart=True)
def controller_died(handle):
@ -163,7 +163,7 @@ def test_refresh_controller_after_death(shutdown_ray, detached):
# Call start again to refresh handle
serve.start(detached=detached, _override_controller_namespace=controller_namespace)
new_handle = internal_get_global_client()._controller
new_handle = get_global_client()._controller
assert new_handle is not old_handle
# Health check should not error


@ -68,7 +68,7 @@ class RandomKiller:
self.kill_period_s = kill_period_s
def _get_all_serve_actors(self):
controller = serve.api.internal_get_global_client()._controller
controller = serve.context.get_global_client()._controller
routers = list(ray.get(controller.get_http_proxies.remote()).values())
all_handles = routers + [controller]
worker_handle_dict = ray.get(controller._all_running_replicas.remote())


@ -83,10 +83,10 @@ def main():
# Kill current cluster, recover from remote checkpoint and ensure endpoint
# is still available with expected results
ray.kill(serve.api._global_client._controller, no_restart=True)
ray.kill(serve.context._global_client._controller, no_restart=True)
ray.shutdown()
cluster.shutdown()
serve.api._set_global_client(None)
serve.context.set_global_client(None)
# Start another ray cluster with same namespace to resume from previous
# checkpoints with no new deploy() call.


@ -82,10 +82,10 @@ def main():
# Kill current cluster, recover from remote checkpoint and ensure endpoint
# is still available with expected results
ray.kill(serve.api._global_client._controller, no_restart=True)
ray.kill(serve.context._global_client._controller, no_restart=True)
ray.shutdown()
cluster.shutdown()
serve.api._set_global_client(None)
serve.context.set_global_client(None)
# Start another ray cluster with same namespace to resume from previous
# checkpoints with no new deploy() call.