diff --git a/dashboard/modules/serve/serve_head.py b/dashboard/modules/serve/serve_head.py index 5f33e8330..64ebf1332 100644 --- a/dashboard/modules/serve/serve_head.py +++ b/dashboard/modules/serve/serve_head.py @@ -56,7 +56,7 @@ class ServeHead(dashboard_utils.DashboardHeadModule): @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True) async def put_all_deployments(self, req: Request) -> Response: from ray import serve - from ray.serve.api import internal_get_global_client + from ray.serve.context import get_global_client from ray.serve.application import Application app = Application.from_dict(await req.json()) @@ -69,7 +69,7 @@ class ServeHead(dashboard_utils.DashboardHeadModule): all_deployments = serve.list_deployments() all_names = set(all_deployments.keys()) names_to_delete = all_names.difference(new_names) - internal_get_global_client().delete_deployments(names_to_delete) + get_global_client().delete_deployments(names_to_delete) return Response() diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 4a3693f6a..d9c7f10d1 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -1,50 +1,31 @@ -import asyncio -import atexit import collections import inspect import logging -import random -import re -import time -from dataclasses import dataclass -from functools import wraps from typing import ( Any, Callable, Dict, Optional, Tuple, - Type, Union, - List, - Iterable, overload, ) from fastapi import APIRouter, FastAPI -from ray.exceptions import RayActorError from starlette.requests import Request from uvicorn.config import Config from uvicorn.lifespan.on import LifespanOn -from ray.actor import ActorHandle -from ray.serve.common import ( - DeploymentInfo, - DeploymentStatus, - DeploymentStatusInfo, - ReplicaTag, -) +from ray.serve.common import DeploymentStatusInfo from ray.serve.config import ( AutoscalingConfig, DeploymentConfig, HTTPOptions, - ReplicaConfig, ) from ray.serve.constants import ( DEFAULT_CHECKPOINT_PATH, HTTP_PROXY_TIMEOUT, SERVE_CONTROLLER_NAME, - MAX_CACHED_HANDLES, CONTROLLER_MAX_CONCURRENCY, DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, @@ -52,13 +33,8 @@ from ray.serve.constants import ( from ray.serve.controller import ServeController from ray.serve.deployment import Deployment from ray.serve.exceptions import RayServeException -from ray.serve.generated.serve_pb2 import ( - DeploymentRoute, - DeploymentRouteList, - DeploymentStatusInfoList, -) from ray.experimental.dag import DAGNode -from ray.serve.handle import RayServeHandle, RayServeSyncHandle +from ray.serve.handle import RayServeHandle from ray.serve.http_util import ASGIHTTPSender, make_fastapi_class_based_view from ray.serve.logging_utils import LoggingContext from ray.serve.utils import ( @@ -74,620 +50,15 @@ import ray from ray import cloudpickle from ray.serve.deployment_graph import ClassNode, FunctionNode from ray.serve.application import Application - -logger = logging.getLogger(__file__) - - -_INTERNAL_REPLICA_CONTEXT = None -_global_client: "Client" = None - -_UUID_RE = re.compile( - "[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12}" +from ray.serve.client import ServeControllerClient, get_controller_namespace +from ray.serve.context import ( + set_global_client, + get_global_client, + get_internal_replica_context, + ReplicaContext, ) -# The polling interval for serve client to wait to deployment state -_CLIENT_POLLING_INTERVAL_S: float = 1 - - -def _get_controller_namespace( - detached: bool, _override_controller_namespace: Optional[str] = 
None -): - """Gets the controller's namespace. - - Args: - detached (bool): Whether serve.start() was called with detached=True - _override_controller_namespace (Optional[str]): When set, this is the - controller's namespace - """ - - if _override_controller_namespace is not None: - return _override_controller_namespace - - controller_namespace = ray.get_runtime_context().namespace - - if not detached: - return controller_namespace - - # Start controller in "serve" namespace if detached and currently - # in anonymous namespace. - if _UUID_RE.fullmatch(controller_namespace) is not None: - controller_namespace = "serve" - return controller_namespace - - -def internal_get_global_client( - _override_controller_namespace: Optional[str] = None, - _health_check_controller: bool = False, -) -> "Client": - """Gets the global client, which stores the controller's handle. - - Args: - _override_controller_namespace (Optional[str]): If None and there's no - cached client, searches for the controller in this namespace. - _health_check_controller (bool): If True, run a health check on the - cached controller if it exists. If the check fails, try reconnecting - to the controller. - - Raises: - RayServeException: if there is no Serve controller actor in the - expected namespace. - """ - - try: - if _global_client is not None: - if _health_check_controller: - ray.get(_global_client._controller.check_alive.remote()) - return _global_client - except RayActorError: - logger.info("The cached controller has died. Reconnecting.") - _set_global_client(None) - - return _connect(_override_controller_namespace=_override_controller_namespace) - - -def _set_global_client(client): - global _global_client - _global_client = client - - -@dataclass -class ReplicaContext: - """Stores data for Serve API calls from within deployments.""" - - deployment: str - replica_tag: ReplicaTag - _internal_controller_name: str - _internal_controller_namespace: str - servable_object: Callable - - -def _set_internal_replica_context( - deployment: str, - replica_tag: ReplicaTag, - controller_name: str, - controller_namespace: str, - servable_object: Callable, -): - global _INTERNAL_REPLICA_CONTEXT - _INTERNAL_REPLICA_CONTEXT = ReplicaContext( - deployment, replica_tag, controller_name, controller_namespace, servable_object - ) - - -def _ensure_connected(f: Callable) -> Callable: - @wraps(f) - def check(self, *args, **kwargs): - if self._shutdown: - raise RayServeException("Client has already been shut down.") - return f(self, *args, **kwargs) - - return check - - -class Client: - def __init__( - self, - controller: ActorHandle, - controller_name: str, - detached: bool = False, - _override_controller_namespace: Optional[str] = None, - ): - self._controller: ServeController = controller - self._controller_name = controller_name - self._detached = detached - self._override_controller_namespace = _override_controller_namespace - self._shutdown = False - self._http_config: HTTPOptions = ray.get(controller.get_http_config.remote()) - self._root_url = ray.get(controller.get_root_url.remote()) - self._checkpoint_path = ray.get(controller.get_checkpoint_path.remote()) - - # Each handle has the overhead of long poll client, therefore cached. - self.handle_cache = dict() - self._evicted_handle_keys = set() - - # NOTE(edoakes): Need this because the shutdown order isn't guaranteed - # when the interpreter is exiting so we can't rely on __del__ (it - # throws a nasty stacktrace). 
- if not self._detached: - - def shutdown_serve_client(): - self.shutdown() - - atexit.register(shutdown_serve_client) - - @property - def root_url(self): - return self._root_url - - @property - def http_config(self): - return self._http_config - - @property - def checkpoint_path(self): - return self._checkpoint_path - - def __del__(self): - if not self._detached: - logger.debug( - "Shutting down Ray Serve because client went out of " - "scope. To prevent this, either keep a reference to " - "the client or use serve.start(detached=True)." - ) - self.shutdown() - - def __reduce__(self): - raise RayServeException(("Ray Serve client cannot be serialized.")) - - def shutdown(self) -> None: - """Completely shut down the connected Serve instance. - - Shuts down all processes and deletes all state associated with the - instance. - """ - if ray.is_initialized() and not self._shutdown: - ray.get(self._controller.shutdown.remote()) - self._wait_for_deployments_shutdown() - - ray.kill(self._controller, no_restart=True) - - # Wait for the named actor entry gets removed as well. - started = time.time() - while True: - try: - controller_namespace = _get_controller_namespace( - self._detached, - self._override_controller_namespace, - ) - ray.get_actor(self._controller_name, namespace=controller_namespace) - if time.time() - started > 5: - logger.warning( - "Waited 5s for Serve to shutdown gracefully but " - "the controller is still not cleaned up. " - "You can ignore this warning if you are shutting " - "down the Ray cluster." - ) - break - except ValueError: # actor name is removed - break - - self._shutdown = True - - def _wait_for_deployments_shutdown(self, timeout_s: int = 60): - """Waits for all deployments to be shut down and deleted. - - Raises TimeoutError if this doesn't happen before timeout_s. - """ - start = time.time() - while time.time() - start < timeout_s: - statuses = self.get_deployment_statuses() - if len(statuses) == 0: - break - else: - logger.debug( - f"Waiting for shutdown, {len(statuses)} deployments still alive." - ) - time.sleep(_CLIENT_POLLING_INTERVAL_S) - else: - live_names = list(statuses.keys()) - raise TimeoutError( - f"Shutdown didn't complete after {timeout_s}s. " - f"Deployments still alive: {live_names}." - ) - - def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1): - """Waits for the named deployment to enter "HEALTHY" status. - - Raises RuntimeError if the deployment enters the "UNHEALTHY" status - instead. - - Raises TimeoutError if this doesn't happen before timeout_s. - """ - start = time.time() - while time.time() - start < timeout_s or timeout_s < 0: - statuses = self.get_deployment_statuses() - try: - status = statuses[name] - except KeyError: - raise RuntimeError( - f"Waiting for deployment {name} to be HEALTHY, " - "but deployment doesn't exist." - ) from None - - if status.status == DeploymentStatus.HEALTHY: - break - elif status.status == DeploymentStatus.UNHEALTHY: - raise RuntimeError(f"Deployment {name} is UNHEALTHY: {status.message}") - else: - # Guard against new unhandled statuses being added. - assert status.status == DeploymentStatus.UPDATING - - logger.debug( - f"Waiting for {name} to be healthy, current status: {status.status}." - ) - time.sleep(_CLIENT_POLLING_INTERVAL_S) - else: - raise TimeoutError( - f"Deployment {name} did not become HEALTHY after {timeout_s}s." - ) - - def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60): - """Waits for the named deployment to be shut down and deleted. 
- - Raises TimeoutError if this doesn't happen before timeout_s. - """ - start = time.time() - while time.time() - start < timeout_s: - statuses = self.get_deployment_statuses() - if name not in statuses: - break - else: - curr_status = statuses[name].status - logger.debug( - f"Waiting for {name} to be deleted, current status: {curr_status}." - ) - time.sleep(_CLIENT_POLLING_INTERVAL_S) - else: - raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.") - - @_ensure_connected - def deploy( - self, - name: str, - deployment_def: Union[Callable, Type[Callable], str], - init_args: Tuple[Any], - init_kwargs: Dict[Any, Any], - ray_actor_options: Optional[Dict] = None, - config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - route_prefix: Optional[str] = None, - url: Optional[str] = None, - _blocking: Optional[bool] = True, - ): - - controller_deploy_args = self.get_deploy_args( - name=name, - deployment_def=deployment_def, - init_args=init_args, - init_kwargs=init_kwargs, - ray_actor_options=ray_actor_options, - config=config, - version=version, - prev_version=prev_version, - route_prefix=route_prefix, - ) - - updating = ray.get(self._controller.deploy.remote(**controller_deploy_args)) - - tag = self.log_deployment_update_status(name, version, updating) - - if _blocking: - self._wait_for_deployment_healthy(name) - self.log_deployment_ready(name, version, url, tag) - - @_ensure_connected - def deploy_group(self, deployments: List[Dict], _blocking: bool = True): - deployment_args_list = [] - for deployment in deployments: - deployment_args_list.append( - self.get_deploy_args( - deployment["name"], - deployment["func_or_class"], - deployment["init_args"], - deployment["init_kwargs"], - ray_actor_options=deployment["ray_actor_options"], - config=deployment["config"], - version=deployment["version"], - prev_version=deployment["prev_version"], - route_prefix=deployment["route_prefix"], - ) - ) - - updating_list = ray.get( - self._controller.deploy_group.remote(deployment_args_list) - ) - - tags = [] - for i, updating in enumerate(updating_list): - deployment = deployments[i] - name, version = deployment["name"], deployment["version"] - - tags.append(self.log_deployment_update_status(name, version, updating)) - - for i, deployment in enumerate(deployments): - name = deployment["name"] - url = deployment["url"] - - if _blocking: - self._wait_for_deployment_healthy(name) - self.log_deployment_ready(name, version, url, tags[i]) - - @_ensure_connected - def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None: - ray.get(self._controller.delete_deployments.remote(names)) - if blocking: - for name in names: - self._wait_for_deployment_deleted(name) - - @_ensure_connected - def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]: - deployment_route = DeploymentRoute.FromString( - ray.get(self._controller.get_deployment_info.remote(name)) - ) - return ( - DeploymentInfo.from_proto(deployment_route.deployment_info), - deployment_route.route if deployment_route.route != "" else None, - ) - - @_ensure_connected - def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]: - deployment_route_list = DeploymentRouteList.FromString( - ray.get(self._controller.list_deployments.remote()) - ) - return { - deployment_route.deployment_info.name: ( - DeploymentInfo.from_proto(deployment_route.deployment_info), - deployment_route.route if deployment_route.route != "" else 
None, - ) - for deployment_route in deployment_route_list.deployment_routes - } - - @_ensure_connected - def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]: - proto = DeploymentStatusInfoList.FromString( - ray.get(self._controller.get_deployment_statuses.remote()) - ) - return { - deployment_status_info.name: DeploymentStatusInfo.from_proto( - deployment_status_info - ) - for deployment_status_info in proto.deployment_status_infos - } - - @_ensure_connected - def get_handle( - self, - deployment_name: str, - missing_ok: Optional[bool] = False, - sync: bool = True, - _internal_pickled_http_request: bool = False, - ) -> Union[RayServeHandle, RayServeSyncHandle]: - """Retrieve RayServeHandle for service deployment to invoke it from Python. - - Args: - deployment_name (str): A registered service deployment. - missing_ok (bool): If true, then Serve won't check the deployment - is registered. False by default. - sync (bool): If true, then Serve will return a ServeHandle that - works everywhere. Otherwise, Serve will return a ServeHandle - that's only usable in asyncio loop. - - Returns: - RayServeHandle - """ - cache_key = (deployment_name, missing_ok, sync) - if cache_key in self.handle_cache: - cached_handle = self.handle_cache[cache_key] - if cached_handle.is_polling and cached_handle.is_same_loop: - return cached_handle - - all_endpoints = ray.get(self._controller.get_all_endpoints.remote()) - if not missing_ok and deployment_name not in all_endpoints: - raise KeyError(f"Deployment '{deployment_name}' does not exist.") - - try: - asyncio_loop_running = asyncio.get_event_loop().is_running() - except RuntimeError as ex: - if "There is no current event loop in thread" in str(ex): - asyncio_loop_running = False - else: - raise ex - - if asyncio_loop_running and sync: - logger.warning( - "You are retrieving a sync handle inside an asyncio loop. " - "Try getting client.get_handle(.., sync=False) to get better " - "performance. Learn more at https://docs.ray.io/en/master/" - "serve/http-servehandle.html#sync-and-async-handles" - ) - - if not asyncio_loop_running and not sync: - logger.warning( - "You are retrieving an async handle outside an asyncio loop. " - "You should make sure client.get_handle is called inside a " - "running event loop. Or call client.get_handle(.., sync=True) " - "to create sync handle. Learn more at https://docs.ray.io/en/" - "master/serve/http-servehandle.html#sync-and-async-handles" - ) - - if sync: - handle = RayServeSyncHandle( - self._controller, - deployment_name, - _internal_pickled_http_request=_internal_pickled_http_request, - ) - else: - handle = RayServeHandle( - self._controller, - deployment_name, - _internal_pickled_http_request=_internal_pickled_http_request, - ) - - self.handle_cache[cache_key] = handle - if cache_key in self._evicted_handle_keys: - logger.warning( - "You just got a ServeHandle that was evicted from internal " - "cache. This means you are getting too many ServeHandles in " - "the same process, this will bring down Serve's performance. " - "Please post a github issue at " - "https://github.com/ray-project/ray/issues to let the Serve " - "team to find workaround for your use case." - ) - - if len(self.handle_cache) > MAX_CACHED_HANDLES: - # Perform random eviction to keep the handle cache from growing - # infinitely. We used use WeakValueDictionary but hit - # https://github.com/ray-project/ray/issues/18980. 
- evict_key = random.choice(list(self.handle_cache.keys())) - self._evicted_handle_keys.add(evict_key) - self.handle_cache.pop(evict_key) - - return handle - - @_ensure_connected - def get_deploy_args( - self, - name: str, - deployment_def: Union[Callable, Type[Callable], str], - init_args: Tuple[Any], - init_kwargs: Dict[Any, Any], - ray_actor_options: Optional[Dict] = None, - config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - route_prefix: Optional[str] = None, - ) -> Dict: - """ - Takes a deployment's configuration, and returns the arguments needed - for the controller to deploy it. - """ - - if config is None: - config = {} - if ray_actor_options is None: - ray_actor_options = {} - - curr_job_env = ray.get_runtime_context().runtime_env - if "runtime_env" in ray_actor_options: - # It is illegal to set field working_dir to None. - if curr_job_env.get("working_dir") is not None: - ray_actor_options["runtime_env"].setdefault( - "working_dir", curr_job_env.get("working_dir") - ) - else: - ray_actor_options["runtime_env"] = curr_job_env - - replica_config = ReplicaConfig( - deployment_def, - init_args=init_args, - init_kwargs=init_kwargs, - ray_actor_options=ray_actor_options, - ) - - if isinstance(config, dict): - deployment_config = DeploymentConfig.parse_obj(config) - elif isinstance(config, DeploymentConfig): - deployment_config = config - else: - raise TypeError("config must be a DeploymentConfig or a dictionary.") - - deployment_config.version = version - deployment_config.prev_version = prev_version - - if ( - deployment_config.autoscaling_config is not None - and deployment_config.max_concurrent_queries - < deployment_config.autoscaling_config.target_num_ongoing_requests_per_replica # noqa: E501 - ): - logger.warning( - "Autoscaling will never happen, " - "because 'max_concurrent_queries' is less than " - "'target_num_ongoing_requests_per_replica' now." - ) - - controller_deploy_args = { - "name": name, - "deployment_config_proto_bytes": deployment_config.to_proto_bytes(), - "replica_config_proto_bytes": replica_config.to_proto_bytes(), - "route_prefix": route_prefix, - "deployer_job_id": ray.get_runtime_context().job_id, - } - - return controller_deploy_args - - @_ensure_connected - def log_deployment_update_status( - self, name: str, version: str, updating: bool - ) -> str: - tag = f"component=serve deployment={name}" - - if updating: - msg = f"Updating deployment '{name}'" - if version is not None: - msg += f" to version '{version}'" - logger.info(f"{msg}. {tag}") - else: - logger.info( - f"Deployment '{name}' is already at version " - f"'{version}', not updating. {tag}" - ) - - return tag - - @_ensure_connected - def log_deployment_ready(self, name: str, version: str, url: str, tag: str) -> None: - if url is not None: - url_part = f" at `{url}`" - else: - url_part = "" - logger.info( - f"Deployment '{name}{':'+version if version else ''}' is ready" - f"{url_part}. {tag}" - ) - - -def _check_http_and_checkpoint_options( - client: Client, - http_options: Union[dict, HTTPOptions], - checkpoint_path: str, -) -> None: - if checkpoint_path and checkpoint_path != client.checkpoint_path: - logger.warning( - f"The new client checkpoint path '{checkpoint_path}' " - f"is different from the existing one '{client.checkpoint_path}'. " - "The new checkpoint path is ignored." 
- ) - - if http_options: - client_http_options = client.http_config - new_http_options = ( - http_options - if isinstance(http_options, HTTPOptions) - else HTTPOptions.parse_obj(http_options) - ) - different_fields = [] - all_http_option_fields = new_http_options.__dict__ - for field in all_http_option_fields: - if getattr(new_http_options, field) != getattr(client_http_options, field): - different_fields.append(field) - - if len(different_fields): - logger.warning( - "The new client HTTP config differs from the existing one " - f"in the following fields: {different_fields}. " - "The new HTTP config is ignored." - ) +logger = logging.getLogger(__file__) @PublicAPI(stability="beta") @@ -698,7 +69,7 @@ def start( _checkpoint_path: str = DEFAULT_CHECKPOINT_PATH, _override_controller_namespace: Optional[str] = None, **kwargs, -) -> Client: +) -> ServeControllerClient: """Initialize a serve instance. By default, the instance will be scoped to the lifetime of the returned @@ -749,12 +120,12 @@ def start( if not ray.is_initialized(): ray.init(namespace="serve") - controller_namespace = _get_controller_namespace( + controller_namespace = get_controller_namespace( detached, _override_controller_namespace=_override_controller_namespace ) try: - client = internal_get_global_client( + client = get_global_client( _override_controller_namespace=_override_controller_namespace, _health_check_controller=True, ) @@ -808,13 +179,13 @@ def start( "HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s." ) - client = Client( + client = ServeControllerClient( controller, controller_name, detached=detached, _override_controller_namespace=_override_controller_namespace, ) - _set_global_client(client) + set_global_client(client) logger.info( f"Started{' detached ' if detached else ' '}Serve instance in " f"namespace '{controller_namespace}'." @@ -822,62 +193,6 @@ def start( return client -def _connect(_override_controller_namespace: Optional[str] = None) -> Client: - """Connect to an existing Serve instance on this Ray cluster. - - If calling from the driver program, the Serve instance on this Ray cluster - must first have been initialized using `serve.start(detached=True)`. - - If called from within a replica, this will connect to the same Serve - instance that the replica is running in. - - Args: - _override_controller_namespace (Optional[str]): The namespace to use - when looking for the controller. If None, Serve recalculates the - controller's namespace using _get_controller_namespace(). - - Raises: - RayServeException: if there is no Serve controller actor in the - expected namespace. - """ - - # Initialize ray if needed. - ray.worker.global_worker.filter_logs_by_job = False - if not ray.is_initialized(): - ray.init(namespace="serve") - - # When running inside of a replica, _INTERNAL_REPLICA_CONTEXT is set to - # ensure that the correct instance is connected to. - if _INTERNAL_REPLICA_CONTEXT is None: - controller_name = SERVE_CONTROLLER_NAME - controller_namespace = _get_controller_namespace( - detached=True, _override_controller_namespace=_override_controller_namespace - ) - else: - controller_name = _INTERNAL_REPLICA_CONTEXT._internal_controller_name - controller_namespace = _INTERNAL_REPLICA_CONTEXT._internal_controller_namespace - - # Try to get serve controller if it exists - try: - controller = ray.get_actor(controller_name, namespace=controller_namespace) - except ValueError: - raise RayServeException( - "There is no " - "instance running on this Ray cluster. 
Please " - "call `serve.start(detached=True) to start " - "one." - ) - - client = Client( - controller, - controller_name, - detached=True, - _override_controller_namespace=_override_controller_namespace, - ) - _set_global_client(client) - return client - - @PublicAPI def shutdown() -> None: """Completely shut down the connected Serve instance. @@ -887,7 +202,7 @@ def shutdown() -> None: """ try: - client = internal_get_global_client() + client = get_global_client() except RayServeException: logger.info( "Nothing to shut down. There's no Serve application " @@ -896,7 +211,7 @@ def shutdown() -> None: return client.shutdown() - _set_global_client(None) + set_global_client(None) @PublicAPI @@ -917,13 +232,14 @@ def get_replica_context() -> ReplicaContext: >>> # deployment_name#krcwoa >>> serve.get_replica_context().replica_tag # doctest: +SKIP """ - if _INTERNAL_REPLICA_CONTEXT is None: + internal_replica_context = get_internal_replica_context() + if internal_replica_context is None: raise RayServeException( "`serve.get_replica_context()` " "may only be called from within a " "Ray Serve deployment." ) - return _INTERNAL_REPLICA_CONTEXT + return internal_replica_context @PublicAPI(stability="beta") @@ -1180,7 +496,7 @@ def get_deployment(name: str) -> Deployment: ( deployment_info, route_prefix, - ) = internal_get_global_client().get_deployment_info(name) + ) = get_global_client().get_deployment_info(name) except KeyError: raise KeyError( f"Deployment {name} was not found. Did you call Deployment.deploy()?" @@ -1204,7 +520,7 @@ def list_deployments() -> Dict[str, Deployment]: Dictionary maps deployment name to Deployment objects. """ - infos = internal_get_global_client().list_deployments() + infos = get_global_client().list_deployments() deployments = {} for name, (deployment_info, route_prefix) in infos.items(): @@ -1241,7 +557,7 @@ def get_deployment_statuses() -> Dict[str, DeploymentStatusInfo]: status and a message explaining the status. """ - return internal_get_global_client().get_deployment_statuses() + return get_global_client().get_deployment_statuses() @PublicAPI(stability="alpha") @@ -1359,3 +675,36 @@ def build(target: Union[ClassNode, FunctionNode]) -> Application: # TODO(edoakes): this should accept host and port, but we don't # currently support them in the REST API. return Application(pipeline_build(target)) + + +def _check_http_and_checkpoint_options( + client: ServeControllerClient, + http_options: Union[dict, HTTPOptions], + checkpoint_path: str, +) -> None: + if checkpoint_path and checkpoint_path != client.checkpoint_path: + logger.warning( + f"The new client checkpoint path '{checkpoint_path}' " + f"is different from the existing one '{client.checkpoint_path}'. " + "The new checkpoint path is ignored." + ) + + if http_options: + client_http_options = client.http_config + new_http_options = ( + http_options + if isinstance(http_options, HTTPOptions) + else HTTPOptions.parse_obj(http_options) + ) + different_fields = [] + all_http_option_fields = new_http_options.__dict__ + for field in all_http_option_fields: + if getattr(new_http_options, field) != getattr(client_http_options, field): + different_fields.append(field) + + if len(different_fields): + logger.warning( + "The new client HTTP config differs from the existing one " + f"in the following fields: {different_fields}. " + "The new HTTP config is ignored." 
+ ) diff --git a/python/ray/serve/client.py b/python/ray/serve/client.py new file mode 100644 index 000000000..90374ece7 --- /dev/null +++ b/python/ray/serve/client.py @@ -0,0 +1,555 @@ +import asyncio +import atexit +import random +import logging +import time +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + List, + Iterable, +) + +import ray +from ray.actor import ActorHandle +from ray.serve.common import ( + DeploymentInfo, + DeploymentStatus, + DeploymentStatusInfo, +) +from ray.serve.config import ( + DeploymentConfig, + HTTPOptions, + ReplicaConfig, +) +from ray.serve.constants import ( + MAX_CACHED_HANDLES, + CLIENT_POLLING_INTERVAL_S, + ANONYMOUS_NAMESPACE_PATTERN, +) +from ray.serve.controller import ServeController +from ray.serve.exceptions import RayServeException +from ray.serve.generated.serve_pb2 import ( + DeploymentRoute, + DeploymentRouteList, + DeploymentStatusInfoList, +) +from ray.serve.handle import RayServeHandle, RayServeSyncHandle + + +logger = logging.getLogger(__file__) + + +def _ensure_connected(f: Callable) -> Callable: + @wraps(f) + def check(self, *args, **kwargs): + if self._shutdown: + raise RayServeException("Client has already been shut down.") + return f(self, *args, **kwargs) + + return check + + +class ServeControllerClient: + def __init__( + self, + controller: ActorHandle, + controller_name: str, + detached: bool = False, + _override_controller_namespace: Optional[str] = None, + ): + self._controller: ServeController = controller + self._controller_name = controller_name + self._detached = detached + self._override_controller_namespace = _override_controller_namespace + self._shutdown = False + self._http_config: HTTPOptions = ray.get(controller.get_http_config.remote()) + self._root_url = ray.get(controller.get_root_url.remote()) + self._checkpoint_path = ray.get(controller.get_checkpoint_path.remote()) + + # Each handle has the overhead of long poll client, therefore cached. + self.handle_cache = dict() + self._evicted_handle_keys = set() + + # NOTE(edoakes): Need this because the shutdown order isn't guaranteed + # when the interpreter is exiting so we can't rely on __del__ (it + # throws a nasty stacktrace). + if not self._detached: + + def shutdown_serve_client(): + self.shutdown() + + atexit.register(shutdown_serve_client) + + @property + def root_url(self): + return self._root_url + + @property + def http_config(self): + return self._http_config + + @property + def checkpoint_path(self): + return self._checkpoint_path + + def __del__(self): + if not self._detached: + logger.debug( + "Shutting down Ray Serve because client went out of " + "scope. To prevent this, either keep a reference to " + "the client or use serve.start(detached=True)." + ) + self.shutdown() + + def __reduce__(self): + raise RayServeException(("Ray Serve client cannot be serialized.")) + + def shutdown(self) -> None: + """Completely shut down the connected Serve instance. + + Shuts down all processes and deletes all state associated with the + instance. + """ + if ray.is_initialized() and not self._shutdown: + ray.get(self._controller.shutdown.remote()) + self._wait_for_deployments_shutdown() + + ray.kill(self._controller, no_restart=True) + + # Wait for the named actor entry gets removed as well. 
+ started = time.time() + while True: + try: + controller_namespace = get_controller_namespace( + self._detached, + self._override_controller_namespace, + ) + ray.get_actor(self._controller_name, namespace=controller_namespace) + if time.time() - started > 5: + logger.warning( + "Waited 5s for Serve to shutdown gracefully but " + "the controller is still not cleaned up. " + "You can ignore this warning if you are shutting " + "down the Ray cluster." + ) + break + except ValueError: # actor name is removed + break + + self._shutdown = True + + def _wait_for_deployments_shutdown(self, timeout_s: int = 60): + """Waits for all deployments to be shut down and deleted. + + Raises TimeoutError if this doesn't happen before timeout_s. + """ + start = time.time() + while time.time() - start < timeout_s: + statuses = self.get_deployment_statuses() + if len(statuses) == 0: + break + else: + logger.debug( + f"Waiting for shutdown, {len(statuses)} deployments still alive." + ) + time.sleep(CLIENT_POLLING_INTERVAL_S) + else: + live_names = list(statuses.keys()) + raise TimeoutError( + f"Shutdown didn't complete after {timeout_s}s. " + f"Deployments still alive: {live_names}." + ) + + def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1): + """Waits for the named deployment to enter "HEALTHY" status. + + Raises RuntimeError if the deployment enters the "UNHEALTHY" status + instead. + + Raises TimeoutError if this doesn't happen before timeout_s. + """ + start = time.time() + while time.time() - start < timeout_s or timeout_s < 0: + statuses = self.get_deployment_statuses() + try: + status = statuses[name] + except KeyError: + raise RuntimeError( + f"Waiting for deployment {name} to be HEALTHY, " + "but deployment doesn't exist." + ) from None + + if status.status == DeploymentStatus.HEALTHY: + break + elif status.status == DeploymentStatus.UNHEALTHY: + raise RuntimeError(f"Deployment {name} is UNHEALTHY: {status.message}") + else: + # Guard against new unhandled statuses being added. + assert status.status == DeploymentStatus.UPDATING + + logger.debug( + f"Waiting for {name} to be healthy, current status: {status.status}." + ) + time.sleep(CLIENT_POLLING_INTERVAL_S) + else: + raise TimeoutError( + f"Deployment {name} did not become HEALTHY after {timeout_s}s." + ) + + def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60): + """Waits for the named deployment to be shut down and deleted. + + Raises TimeoutError if this doesn't happen before timeout_s. + """ + start = time.time() + while time.time() - start < timeout_s: + statuses = self.get_deployment_statuses() + if name not in statuses: + break + else: + curr_status = statuses[name].status + logger.debug( + f"Waiting for {name} to be deleted, current status: {curr_status}." 
+ ) + time.sleep(CLIENT_POLLING_INTERVAL_S) + else: + raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.") + + @_ensure_connected + def deploy( + self, + name: str, + deployment_def: Union[Callable, Type[Callable], str], + init_args: Tuple[Any], + init_kwargs: Dict[Any, Any], + ray_actor_options: Optional[Dict] = None, + config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + route_prefix: Optional[str] = None, + url: Optional[str] = None, + _blocking: Optional[bool] = True, + ): + + controller_deploy_args = self.get_deploy_args( + name=name, + deployment_def=deployment_def, + init_args=init_args, + init_kwargs=init_kwargs, + ray_actor_options=ray_actor_options, + config=config, + version=version, + prev_version=prev_version, + route_prefix=route_prefix, + ) + + updating = ray.get(self._controller.deploy.remote(**controller_deploy_args)) + + tag = self.log_deployment_update_status(name, version, updating) + + if _blocking: + self._wait_for_deployment_healthy(name) + self.log_deployment_ready(name, version, url, tag) + + @_ensure_connected + def deploy_group(self, deployments: List[Dict], _blocking: bool = True): + deployment_args_list = [] + for deployment in deployments: + deployment_args_list.append( + self.get_deploy_args( + deployment["name"], + deployment["func_or_class"], + deployment["init_args"], + deployment["init_kwargs"], + ray_actor_options=deployment["ray_actor_options"], + config=deployment["config"], + version=deployment["version"], + prev_version=deployment["prev_version"], + route_prefix=deployment["route_prefix"], + ) + ) + + updating_list = ray.get( + self._controller.deploy_group.remote(deployment_args_list) + ) + + tags = [] + for i, updating in enumerate(updating_list): + deployment = deployments[i] + name, version = deployment["name"], deployment["version"] + + tags.append(self.log_deployment_update_status(name, version, updating)) + + for i, deployment in enumerate(deployments): + name = deployment["name"] + url = deployment["url"] + + if _blocking: + self._wait_for_deployment_healthy(name) + self.log_deployment_ready(name, version, url, tags[i]) + + @_ensure_connected + def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None: + ray.get(self._controller.delete_deployments.remote(names)) + if blocking: + for name in names: + self._wait_for_deployment_deleted(name) + + @_ensure_connected + def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]: + deployment_route = DeploymentRoute.FromString( + ray.get(self._controller.get_deployment_info.remote(name)) + ) + return ( + DeploymentInfo.from_proto(deployment_route.deployment_info), + deployment_route.route if deployment_route.route != "" else None, + ) + + @_ensure_connected + def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]: + deployment_route_list = DeploymentRouteList.FromString( + ray.get(self._controller.list_deployments.remote()) + ) + return { + deployment_route.deployment_info.name: ( + DeploymentInfo.from_proto(deployment_route.deployment_info), + deployment_route.route if deployment_route.route != "" else None, + ) + for deployment_route in deployment_route_list.deployment_routes + } + + @_ensure_connected + def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]: + proto = DeploymentStatusInfoList.FromString( + ray.get(self._controller.get_deployment_statuses.remote()) + ) + return { + deployment_status_info.name: 
DeploymentStatusInfo.from_proto( + deployment_status_info + ) + for deployment_status_info in proto.deployment_status_infos + } + + @_ensure_connected + def get_handle( + self, + deployment_name: str, + missing_ok: Optional[bool] = False, + sync: bool = True, + _internal_pickled_http_request: bool = False, + ) -> Union[RayServeHandle, RayServeSyncHandle]: + """Retrieve RayServeHandle for service deployment to invoke it from Python. + + Args: + deployment_name (str): A registered service deployment. + missing_ok (bool): If true, then Serve won't check the deployment + is registered. False by default. + sync (bool): If true, then Serve will return a ServeHandle that + works everywhere. Otherwise, Serve will return a ServeHandle + that's only usable in asyncio loop. + + Returns: + RayServeHandle + """ + cache_key = (deployment_name, missing_ok, sync) + if cache_key in self.handle_cache: + cached_handle = self.handle_cache[cache_key] + if cached_handle.is_polling and cached_handle.is_same_loop: + return cached_handle + + all_endpoints = ray.get(self._controller.get_all_endpoints.remote()) + if not missing_ok and deployment_name not in all_endpoints: + raise KeyError(f"Deployment '{deployment_name}' does not exist.") + + try: + asyncio_loop_running = asyncio.get_event_loop().is_running() + except RuntimeError as ex: + if "There is no current event loop in thread" in str(ex): + asyncio_loop_running = False + else: + raise ex + + if asyncio_loop_running and sync: + logger.warning( + "You are retrieving a sync handle inside an asyncio loop. " + "Try getting client.get_handle(.., sync=False) to get better " + "performance. Learn more at https://docs.ray.io/en/master/" + "serve/http-servehandle.html#sync-and-async-handles" + ) + + if not asyncio_loop_running and not sync: + logger.warning( + "You are retrieving an async handle outside an asyncio loop. " + "You should make sure client.get_handle is called inside a " + "running event loop. Or call client.get_handle(.., sync=True) " + "to create sync handle. Learn more at https://docs.ray.io/en/" + "master/serve/http-servehandle.html#sync-and-async-handles" + ) + + if sync: + handle = RayServeSyncHandle( + self._controller, + deployment_name, + _internal_pickled_http_request=_internal_pickled_http_request, + ) + else: + handle = RayServeHandle( + self._controller, + deployment_name, + _internal_pickled_http_request=_internal_pickled_http_request, + ) + + self.handle_cache[cache_key] = handle + if cache_key in self._evicted_handle_keys: + logger.warning( + "You just got a ServeHandle that was evicted from internal " + "cache. This means you are getting too many ServeHandles in " + "the same process, this will bring down Serve's performance. " + "Please post a github issue at " + "https://github.com/ray-project/ray/issues to let the Serve " + "team to find workaround for your use case." + ) + + if len(self.handle_cache) > MAX_CACHED_HANDLES: + # Perform random eviction to keep the handle cache from growing + # infinitely. We used use WeakValueDictionary but hit + # https://github.com/ray-project/ray/issues/18980. 
+ evict_key = random.choice(list(self.handle_cache.keys())) + self._evicted_handle_keys.add(evict_key) + self.handle_cache.pop(evict_key) + + return handle + + @_ensure_connected + def get_deploy_args( + self, + name: str, + deployment_def: Union[Callable, Type[Callable], str], + init_args: Tuple[Any], + init_kwargs: Dict[Any, Any], + ray_actor_options: Optional[Dict] = None, + config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + route_prefix: Optional[str] = None, + ) -> Dict: + """ + Takes a deployment's configuration, and returns the arguments needed + for the controller to deploy it. + """ + + if config is None: + config = {} + if ray_actor_options is None: + ray_actor_options = {} + + curr_job_env = ray.get_runtime_context().runtime_env + if "runtime_env" in ray_actor_options: + # It is illegal to set field working_dir to None. + if curr_job_env.get("working_dir") is not None: + ray_actor_options["runtime_env"].setdefault( + "working_dir", curr_job_env.get("working_dir") + ) + else: + ray_actor_options["runtime_env"] = curr_job_env + + replica_config = ReplicaConfig( + deployment_def, + init_args=init_args, + init_kwargs=init_kwargs, + ray_actor_options=ray_actor_options, + ) + + if isinstance(config, dict): + deployment_config = DeploymentConfig.parse_obj(config) + elif isinstance(config, DeploymentConfig): + deployment_config = config + else: + raise TypeError("config must be a DeploymentConfig or a dictionary.") + + deployment_config.version = version + deployment_config.prev_version = prev_version + + if ( + deployment_config.autoscaling_config is not None + and deployment_config.max_concurrent_queries + < deployment_config.autoscaling_config.target_num_ongoing_requests_per_replica # noqa: E501 + ): + logger.warning( + "Autoscaling will never happen, " + "because 'max_concurrent_queries' is less than " + "'target_num_ongoing_requests_per_replica' now." + ) + + controller_deploy_args = { + "name": name, + "deployment_config_proto_bytes": deployment_config.to_proto_bytes(), + "replica_config_proto_bytes": replica_config.to_proto_bytes(), + "route_prefix": route_prefix, + "deployer_job_id": ray.get_runtime_context().job_id, + } + + return controller_deploy_args + + @_ensure_connected + def log_deployment_update_status( + self, name: str, version: str, updating: bool + ) -> str: + tag = f"component=serve deployment={name}" + + if updating: + msg = f"Updating deployment '{name}'" + if version is not None: + msg += f" to version '{version}'" + logger.info(f"{msg}. {tag}") + else: + logger.info( + f"Deployment '{name}' is already at version " + f"'{version}', not updating. {tag}" + ) + + return tag + + @_ensure_connected + def log_deployment_ready(self, name: str, version: str, url: str, tag: str) -> None: + if url is not None: + url_part = f" at `{url}`" + else: + url_part = "" + logger.info( + f"Deployment '{name}{':'+version if version else ''}' is ready" + f"{url_part}. {tag}" + ) + + +def get_controller_namespace( + detached: bool, _override_controller_namespace: Optional[str] = None +): + """Gets the controller's namespace. 
+ + Args: + detached (bool): Whether serve.start() was called with detached=True + _override_controller_namespace (Optional[str]): When set, this is the + controller's namespace + """ + + if _override_controller_namespace is not None: + return _override_controller_namespace + + controller_namespace = ray.get_runtime_context().namespace + + if not detached: + return controller_namespace + + # Start controller in "serve" namespace if detached and currently + # in anonymous namespace. + if ANONYMOUS_NAMESPACE_PATTERN.fullmatch(controller_namespace) is not None: + controller_namespace = "serve" + return controller_namespace diff --git a/python/ray/serve/constants.py b/python/ray/serve/constants.py index 0c4e936b8..ea3ce5372 100644 --- a/python/ray/serve/constants.py +++ b/python/ray/serve/constants.py @@ -1,4 +1,5 @@ from enum import Enum +import re #: Used for debugging to turn on DEBUG-level logs DEBUG_LOG_ENV_VAR = "SERVE_DEBUG_LOG" @@ -86,6 +87,15 @@ REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD = 3 # Key used to idenfity given json represents a serialized RayServeHandle SERVE_HANDLE_JSON_KEY = "__SerializedServeHandle__" +# The time in seconds that the Serve client waits before rechecking deployment state +CLIENT_POLLING_INTERVAL_S: float = 1 + +# Regex pattern for anonymous namespace. Should match the pattern used in +# src/ray/gcs/gcs_server/gcs_actor_manager.cc's is_uuid() method. +ANONYMOUS_NAMESPACE_PATTERN = re.compile( + "[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12}" +) + class ServeHandleType(str, Enum): SYNC = "SYNC" diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py new file mode 100644 index 000000000..25e9b77e0 --- /dev/null +++ b/python/ray/serve/context.py @@ -0,0 +1,141 @@ +""" +This file stores global state for a Serve application. Deployment replicas +can use this state to access metadata or the Serve controller. +""" + +import logging +from dataclasses import dataclass +from typing import Callable, Optional + +import ray +from ray.exceptions import RayActorError +from ray.serve.common import ReplicaTag +from ray.serve.constants import SERVE_CONTROLLER_NAME +from ray.serve.exceptions import RayServeException +from ray.serve.client import ServeControllerClient, get_controller_namespace + +logger = logging.getLogger(__file__) + +_INTERNAL_REPLICA_CONTEXT: "ReplicaContext" = None +_global_client: ServeControllerClient = None + + +@dataclass +class ReplicaContext: + """Stores data for Serve API calls from within deployments.""" + + deployment: str + replica_tag: ReplicaTag + _internal_controller_name: str + _internal_controller_namespace: str + servable_object: Callable + + +def get_global_client( + _override_controller_namespace: Optional[str] = None, + _health_check_controller: bool = False, +) -> ServeControllerClient: + """Gets the global client, which stores the controller's handle. + + Args: + _override_controller_namespace (Optional[str]): If None and there's no + cached client, searches for the controller in this namespace. + _health_check_controller (bool): If True, run a health check on the + cached controller if it exists. If the check fails, try reconnecting + to the controller. + + Raises: + RayServeException: if there is no Serve controller actor in the + expected namespace. + """ + + try: + if _global_client is not None: + if _health_check_controller: + ray.get(_global_client._controller.check_alive.remote()) + return _global_client + except RayActorError: + logger.info("The cached controller has died. 
Reconnecting.") + set_global_client(None) + + return _connect(_override_controller_namespace=_override_controller_namespace) + + +def set_global_client(client): + global _global_client + _global_client = client + + +def get_internal_replica_context(): + return _INTERNAL_REPLICA_CONTEXT + + +def set_internal_replica_context( + deployment: str, + replica_tag: ReplicaTag, + controller_name: str, + controller_namespace: str, + servable_object: Callable, +): + global _INTERNAL_REPLICA_CONTEXT + _INTERNAL_REPLICA_CONTEXT = ReplicaContext( + deployment, replica_tag, controller_name, controller_namespace, servable_object + ) + + +def _connect( + _override_controller_namespace: Optional[str] = None, +) -> ServeControllerClient: + """Connect to an existing Serve instance on this Ray cluster. + + If calling from the driver program, the Serve instance on this Ray cluster + must first have been initialized using `serve.start(detached=True)`. + + If called from within a replica, this will connect to the same Serve + instance that the replica is running in. + + Args: + _override_controller_namespace (Optional[str]): The namespace to use + when looking for the controller. If None, Serve recalculates the + controller's namespace using get_controller_namespace(). + + Raises: + RayServeException: if there is no Serve controller actor in the + expected namespace. + """ + + # Initialize ray if needed. + ray.worker.global_worker.filter_logs_by_job = False + if not ray.is_initialized(): + ray.init(namespace="serve") + + # When running inside of a replica, _INTERNAL_REPLICA_CONTEXT is set to + # ensure that the correct instance is connected to. + if _INTERNAL_REPLICA_CONTEXT is None: + controller_name = SERVE_CONTROLLER_NAME + controller_namespace = get_controller_namespace( + detached=True, _override_controller_namespace=_override_controller_namespace + ) + else: + controller_name = _INTERNAL_REPLICA_CONTEXT._internal_controller_name + controller_namespace = _INTERNAL_REPLICA_CONTEXT._internal_controller_namespace + + # Try to get serve controller if it exists + try: + controller = ray.get_actor(controller_name, namespace=controller_namespace) + except ValueError: + raise RayServeException( + "There is no " + "instance running on this Ray cluster. Please " + "call `serve.start(detached=True) to start " + "one." 
+ ) + + client = ServeControllerClient( + controller, + controller_name, + detached=True, + _override_controller_namespace=_override_controller_namespace, + ) + set_global_client(client) + return client diff --git a/python/ray/serve/deployment.py b/python/ray/serve/deployment.py index c2f83f67f..9bdffbaa9 100644 --- a/python/ray/serve/deployment.py +++ b/python/ray/serve/deployment.py @@ -8,6 +8,8 @@ from typing import ( Tuple, Union, ) + +from ray.serve.context import get_global_client from ray.experimental.dag.class_node import ClassNode from ray.experimental.dag.function_node import FunctionNode from ray.serve.config import ( @@ -23,10 +25,6 @@ from ray.serve.schema import ( ) -# TODO (shrekris-anyscale): remove following dependencies on api.py: -# - internal_get_global_client - - @PublicAPI class Deployment: def __init__( @@ -175,9 +173,7 @@ class Deployment: # this deployment is not exposed over HTTP return None - from ray.serve.api import internal_get_global_client - - return internal_get_global_client().root_url + self.route_prefix + return get_global_client().root_url + self.route_prefix def __call__(self): raise RuntimeError( @@ -237,9 +233,7 @@ class Deployment: if len(init_kwargs) == 0 and self._init_kwargs is not None: init_kwargs = self._init_kwargs - from ray.serve.api import internal_get_global_client - - return internal_get_global_client().deploy( + return get_global_client().deploy( self._name, self._func_or_class, init_args, @@ -257,9 +251,7 @@ class Deployment: def delete(self): """Delete this deployment.""" - from ray.serve.api import internal_get_global_client - - return internal_get_global_client().delete_deployments([self._name]) + return get_global_client().delete_deployments([self._name]) @PublicAPI def get_handle( @@ -277,11 +269,7 @@ class Deployment: ServeHandle """ - from ray.serve.api import internal_get_global_client - - return internal_get_global_client().get_handle( - self._name, missing_ok=True, sync=sync - ) + return get_global_client().get_handle(self._name, missing_ok=True, sync=sync) @PublicAPI def options( diff --git a/python/ray/serve/deployment_state.py b/python/ray/serve/deployment_state.py index fc75973b5..ab1cc8af1 100644 --- a/python/ray/serve/deployment_state.py +++ b/python/ray/serve/deployment_state.py @@ -147,7 +147,7 @@ class ActorReplicaWrapper: self._placement_group_name = self._actor_name + "_placement_group" self._detached = detached self._controller_name = controller_name - self._controller_namespace = ray.serve.api._get_controller_namespace( + self._controller_namespace = ray.serve.client.get_controller_namespace( detached, _override_controller_namespace=_override_controller_namespace ) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 00f517839..96b65997d 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -321,7 +321,7 @@ def serve_handle_from_json_dict(d: Dict[str, str]) -> RayServeHandle: if SERVE_HANDLE_JSON_KEY not in d: raise ValueError(f"dict must contain {SERVE_HANDLE_JSON_KEY} key.") - return serve.api.internal_get_global_client().get_handle( + return serve.context.get_global_client().get_handle( d["deployment_name"], sync=d[SERVE_HANDLE_JSON_KEY] == ServeHandleType.SYNC, missing_ok=True, diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 9a4bcb8f7..171dd131c 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -197,7 +197,7 @@ class HTTPProxy: ): # Set the controller name so that serve will connect to the # 
controller instance this proxy is running in.
-        ray.serve.api._set_internal_replica_context(
+        ray.serve.context.set_internal_replica_context(
             None, None, controller_name, controller_namespace, None
         )
@@ -205,7 +205,7 @@ class HTTPProxy:
         self.route_info: Dict[str, EndpointTag] = dict()
 
         def get_handle(name):
-            return serve.api.internal_get_global_client().get_handle(
+            return serve.context.get_global_client().get_handle(
                 name,
                 sync=False,
                 missing_ok=True,
diff --git a/python/ray/serve/http_state.py b/python/ray/serve/http_state.py
index e1ed47c25..4e72f971f 100644
--- a/python/ray/serve/http_state.py
+++ b/python/ray/serve/http_state.py
@@ -35,7 +35,7 @@ class HTTPState:
         _start_proxies_on_init: bool = True,
     ):
         self._controller_name = controller_name
-        self._controller_namespace = ray.serve.api._get_controller_namespace(
+        self._controller_namespace = ray.serve.client.get_controller_namespace(
             detached, _override_controller_namespace=_override_controller_namespace
         )
         self._detached = detached
diff --git a/python/ray/serve/replica.py b/python/ray/serve/replica.py
index bded5304e..b18ce11ce 100644
--- a/python/ray/serve/replica.py
+++ b/python/ray/serve/replica.py
@@ -116,7 +116,7 @@ def create_replica_wrapper(
             # Set the controller name so that serve.connect() in the user's
             # code will connect to the instance that this deployment is running
             # in.
-            ray.serve.api._set_internal_replica_context(
+            ray.serve.context.set_internal_replica_context(
                 deployment_name,
                 replica_tag,
                 controller_name,
@@ -148,7 +148,7 @@ def create_replica_wrapper(
                 await sync_to_async(_callable.__init__)(*init_args, **init_kwargs)
 
             # Setting the context again to update the servable_object.
-            ray.serve.api._set_internal_replica_context(
+            ray.serve.context.set_internal_replica_context(
                 deployment_name,
                 replica_tag,
                 controller_name,
diff --git a/python/ray/serve/scripts.py b/python/ray/serve/scripts.py
index 004b4c9e5..c2265bbfc 100644
--- a/python/ray/serve/scripts.py
+++ b/python/ray/serve/scripts.py
@@ -22,7 +22,7 @@ from ray.dashboard.modules.dashboard_sdk import parse_runtime_env_args
 from ray.dashboard.modules.serve.sdk import ServeSubmissionClient
 from ray.autoscaler._private.cli_logger import cli_logger
 from ray.serve.api import build as build_app
-from ray.serve.api import Application
+from ray.serve.application import Application
 from ray.serve.deployment_graph import (
     FunctionNode,
     ClassNode,
@@ -140,7 +140,7 @@ def shutdown(address: str, namespace: str):
         address=address,
         namespace=namespace,
     )
-    serve.api._connect()
+    serve.context._connect()
     serve.shutdown()
diff --git a/python/ray/serve/tests/fault_tolerance_tests/test_controller_recovery.py b/python/ray/serve/tests/fault_tolerance_tests/test_controller_recovery.py
index c45e30d14..5282a22ef 100644
--- a/python/ray/serve/tests/fault_tolerance_tests/test_controller_recovery.py
+++ b/python/ray/serve/tests/fault_tolerance_tests/test_controller_recovery.py
@@ -68,7 +68,7 @@ def test_recover_start_from_replica_actor_names(serve_instance):
     ), "Should have two running replicas fetched from ray API."
 
     # Kill controller and wait for endpoint to be available again
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
     for _ in range(10):
         response = request_with_retries(
             "/recover_start_from_replica_actor_names/", timeout=30
@@ -171,7 +171,7 @@ def test_recover_rolling_update_from_replica_actor_names(serve_instance):
     responses2, blocking2 = make_nonblocking_calls({"1": 1}, expect_blocking=True)
     assert list(responses2["1"])[0] in pids1
 
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
 
     # Redeploy new version. Since there is one replica blocking, only one new
     # replica should be started up.
@@ -181,7 +181,7 @@ def test_recover_rolling_update_from_replica_actor_names(serve_instance):
     client._wait_for_deployment_healthy(V2.name, timeout_s=0.1)
     responses3, blocking3 = make_nonblocking_calls({"1": 1}, expect_blocking=True)
 
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
 
     # Signal the original call to exit.
     ray.get(signal.send.remote())
diff --git a/python/ray/serve/tests/test_cluster.py b/python/ray/serve/tests/test_cluster.py
index 3312a1123..86f9b4523 100644
--- a/python/ray/serve/tests/test_cluster.py
+++ b/python/ray/serve/tests/test_cluster.py
@@ -44,7 +44,7 @@ def test_scale_up(ray_cluster):
         return pids
 
     serve.start(detached=True)
-    client = serve.api._connect()
+    client = serve.context._connect()
 
     D.deploy()
     pids1 = get_pids(1)
diff --git a/python/ray/serve/tests/test_controller.py b/python/ray/serve/tests/test_controller.py
index eb46b7d7a..a6bdf1103 100644
--- a/python/ray/serve/tests/test_controller.py
+++ b/python/ray/serve/tests/test_controller.py
@@ -10,7 +10,7 @@ from ray.serve.generated.serve_pb2 import DeploymentRoute
 
 def test_redeploy_start_time(serve_instance):
     """Check that redeploying a deployment doesn't reset its start time."""
-    controller = serve.api._global_client._controller
+    controller = serve.context._global_client._controller
 
     @serve.deployment
     def test(_):
diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py
index 987f46497..aae74f187 100644
--- a/python/ray/serve/tests/test_failure.py
+++ b/python/ray/serve/tests/test_failure.py
@@ -34,7 +34,7 @@ def test_controller_failure(serve_instance):
         response = request_with_retries("/controller_failure/", timeout=30)
         assert response.text == "hello1"
 
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
 
     for _ in range(10):
         response = request_with_retries("/controller_failure/", timeout=30)
@@ -43,7 +43,7 @@ def test_controller_failure(serve_instance):
     def function2(_):
         return "hello2"
 
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
 
     function.options(func_or_class=function2).deploy()
 
@@ -57,9 +57,9 @@ def test_controller_failure(serve_instance):
     def function3(_):
         return "hello3"
 
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
     function3.deploy()
-    ray.kill(serve.api._global_client._controller, no_restart=False)
+    ray.kill(serve.context._global_client._controller, no_restart=False)
 
     for _ in range(10):
         response = request_with_retries("/controller_failure/", timeout=30)
@@ -70,7 +70,7 @@ def test_controller_failure(serve_instance):
 
 def _kill_http_proxies():
     http_proxies = ray.get(
-        serve.api._global_client._controller.get_http_proxies.remote()
+        serve.context._global_client._controller.get_http_proxies.remote()
     )
     for http_proxy in http_proxies.values():
         ray.kill(http_proxy, no_restart=False)
@@ -107,7 +107,7 @@ def test_http_proxy_failure(serve_instance):
 
 def _get_worker_handles(deployment):
-    controller = serve.api._global_client._controller
+    controller = serve.context._global_client._controller
     deployment_dict = ray.get(controller._all_running_replicas.remote())
     return [replica.actor_handle for replica in deployment_dict[deployment]]
diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py
index 905924d63..5801b59e9 100644
--- a/python/ray/serve/tests/test_http_state.py
+++ b/python/ray/serve/tests/test_http_state.py
@@ -8,7 +8,7 @@ from ray.serve.http_state import HTTPState
 
 @pytest.fixture
 def patch_get_namespace():
-    with patch("ray.serve.api._get_controller_namespace") as func:
+    with patch("ray.serve.client.get_controller_namespace") as func:
         func.return_value = "dummy_namespace"
         yield
diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py
index bccdfe9e6..91f405e1d 100644
--- a/python/ray/serve/tests/test_regression.py
+++ b/python/ray/serve/tests/test_regression.py
@@ -11,7 +11,7 @@ import ray
 from ray.exceptions import GetTimeoutError
 from ray import serve
 from ray._private.test_utils import SignalActor
-from ray.serve.api import internal_get_global_client
+from ray.serve.context import get_global_client
 
 
 @pytest.fixture
@@ -146,7 +146,7 @@ def test_nested_actors(serve_instance):
 
 def test_handle_cache_out_of_scope(serve_instance):
     # https://github.com/ray-project/ray/issues/18980
-    initial_num_cached = len(internal_get_global_client().handle_cache)
+    initial_num_cached = len(get_global_client().handle_cache)
 
     @serve.deployment(name="f")
     def f():
@@ -155,7 +155,7 @@ def test_handle_cache_out_of_scope(serve_instance):
 
     f.deploy()
     handle = serve.get_deployment("f").get_handle()
-    handle_cache = internal_get_global_client().handle_cache
+    handle_cache = get_global_client().handle_cache
     assert len(handle_cache) == initial_num_cached + 1
 
     def sender_where_handle_goes_out_of_scope():
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index 6241ef1a5..6581e0e91 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -25,7 +25,7 @@ from ray._private.test_utils import (
 from ray.cluster_utils import Cluster, cluster_not_supported
 
 from ray import serve
-from ray.serve.api import internal_get_global_client
+from ray.serve.context import get_global_client
 from ray.serve.config import HTTPOptions
 from ray.serve.constants import SERVE_ROOT_URL_ENV_KEY, SERVE_PROXY_NAME
 from ray.serve.exceptions import RayServeException
@@ -58,12 +58,12 @@ def test_shutdown(ray_shutdown):
     f.deploy()
 
-    serve_controller_name = serve.api._global_client._controller_name
+    serve_controller_name = serve.context._global_client._controller_name
     actor_names = [
         serve_controller_name,
         format_actor_name(
             SERVE_PROXY_NAME,
-            serve.api._global_client._controller_name,
+            serve.context._global_client._controller_name,
             get_all_node_ids()[0][0],
         ),
     ]
@@ -124,7 +124,7 @@ def test_detached_deployment(ray_cluster):
     assert ray.get(f.get_handle().remote()) == "from_f"
     assert requests.get("http://localhost:8000/say_hi_f").text == "from_f"
 
-    serve.api._global_client = None
+    serve.context._global_client = None
     ray.shutdown()
 
     # Create the second job, make sure we can still create new deployments.
@@ -200,7 +200,9 @@ def test_multiple_routers(ray_cluster):
         for node_id, _ in get_all_node_ids():
             proxy_names.append(
                 format_actor_name(
-                    SERVE_PROXY_NAME, serve.api._global_client._controller_name, node_id
+                    SERVE_PROXY_NAME,
+                    serve.context._global_client._controller_name,
+                    node_id,
                 )
             )
         return proxy_names
@@ -375,7 +377,7 @@ def test_no_http(ray_shutdown):
         if actor["State"] == convert_actor_state(gcs_utils.ActorTableData.ALIVE)
     ]
     assert len(live_actors) == 1
-    controller = serve.api._global_client._controller
+    controller = serve.context._global_client._controller
     assert len(ray.get(controller.get_http_proxies.remote())) == 0
 
     # Test that the handle still works.
@@ -446,7 +448,7 @@ def test_fixed_number_proxies(ray_cluster):
     )
 
     # Only the controller and two http proxy should be started.
-    controller_handle = internal_get_global_client()._controller
+    controller_handle = get_global_client()._controller
     node_to_http_actors = ray.get(controller_handle.get_http_proxies.remote())
     assert len(node_to_http_actors) == 2
 
@@ -507,7 +509,7 @@ def test_serve_controller_namespace(
     ray.init(namespace=namespace)
 
     serve.start(detached=detached)
-    client = serve.api._global_client
+    client = serve.context._global_client
     if namespace:
         controller_namespace = namespace
     elif detached:
diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py
index f055d3f16..667d0d7b9 100644
--- a/python/ray/serve/tests/test_standalone2.py
+++ b/python/ray/serve/tests/test_standalone2.py
@@ -11,7 +11,7 @@ from ray.exceptions import RayActorError
 import ray
 import ray.state
 from ray import serve
-from ray.serve.api import internal_get_global_client
+from ray.serve.context import get_global_client
 from ray._private.test_utils import wait_for_condition
 from ray.tests.conftest import call_ray_stop_only  # noqa: F401
 
@@ -76,7 +76,7 @@ def test_override_namespace(shutdown_ray, detached):
     ray.init(namespace=ray_namespace)
 
     serve.start(detached=detached, _override_controller_namespace=controller_namespace)
-    controller_name = internal_get_global_client()._controller_name
+    controller_name = get_global_client()._controller_name
     ray.get_actor(controller_name, namespace=controller_namespace)
 
     serve.shutdown()
@@ -148,7 +148,7 @@ def test_refresh_controller_after_death(shutdown_ray, detached):
     serve.shutdown()  # Ensure serve isn't running before beginning the test
     serve.start(detached=detached, _override_controller_namespace=controller_namespace)
 
-    old_handle = internal_get_global_client()._controller
+    old_handle = get_global_client()._controller
     ray.kill(old_handle, no_restart=True)
 
     def controller_died(handle):
@@ -163,7 +163,7 @@ def test_refresh_controller_after_death(shutdown_ray, detached):
     # Call start again to refresh handle
     serve.start(detached=detached, _override_controller_namespace=controller_namespace)
 
-    new_handle = internal_get_global_client()._controller
+    new_handle = get_global_client()._controller
     assert new_handle is not old_handle
 
     # Health check should not error
diff --git a/release/long_running_tests/workloads/serve_failure.py b/release/long_running_tests/workloads/serve_failure.py
index f8e344601..40bdcac70 100644
--- a/release/long_running_tests/workloads/serve_failure.py
+++ b/release/long_running_tests/workloads/serve_failure.py
@@ -68,7 +68,7 @@ class RandomKiller:
         self.kill_period_s = kill_period_s
 
     def _get_all_serve_actors(self):
-        controller = serve.api.internal_get_global_client()._controller
+        controller = serve.context.get_global_client()._controller
         routers = list(ray.get(controller.get_http_proxies.remote()).values())
         all_handles = routers + [controller]
         worker_handle_dict = ray.get(controller._all_running_replicas.remote())
diff --git a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
index d5afa31e6..e830d0459 100644
--- a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
+++ b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
@@ -83,10 +83,10 @@ def main():
     # Kill current cluster, recover from remote checkpoint and ensure endpoint
     # is still available with expected results
-    ray.kill(serve.api._global_client._controller, no_restart=True)
+    ray.kill(serve.context._global_client._controller, no_restart=True)
     ray.shutdown()
     cluster.shutdown()
-    serve.api._set_global_client(None)
+    serve.context.set_global_client(None)
 
     # Start another ray cluster with same namespace to resume from previous
     # checkpoints with no new deploy() call.
diff --git a/release/serve_tests/workloads/serve_cluster_fault_tolerance_gcs.py b/release/serve_tests/workloads/serve_cluster_fault_tolerance_gcs.py
index dd296e22e..61bb53e5c 100644
--- a/release/serve_tests/workloads/serve_cluster_fault_tolerance_gcs.py
+++ b/release/serve_tests/workloads/serve_cluster_fault_tolerance_gcs.py
@@ -82,10 +82,10 @@ def main():
     # Kill current cluster, recover from remote checkpoint and ensure endpoint
     # is still available with expected results
-    ray.kill(serve.api._global_client._controller, no_restart=True)
+    ray.kill(serve.context._global_client._controller, no_restart=True)
     ray.shutdown()
     cluster.shutdown()
-    serve.api._set_global_client(None)
+    serve.context.set_global_client(None)
 
     # Start another ray cluster with same namespace to resume from previous
     # checkpoints with no new deploy() call.
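
Reviewer note: the hunks above are mechanical renames, not behavior changes. A minimal before/after sketch of a call site (names taken from this patch; assumes Ray/Serve are running so a global client exists, and that `_global_client` and `_connect` stay private attributes of `ray.serve.context`):

    # Before this patch (internal helpers lived in ray.serve.api):
    #   from ray.serve.api import internal_get_global_client
    #   client = internal_get_global_client()

    # After this patch:
    from ray import serve
    from ray.serve.context import get_global_client, set_global_client
    from ray.serve.client import get_controller_namespace

    serve.start()
    client = get_global_client()  # was ray.serve.api.internal_get_global_client()
    namespace = get_controller_namespace(detached=False)  # was ray.serve.api._get_controller_namespace()
    set_global_client(None)  # was ray.serve.api._set_global_client(None); drops the cached client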