diff --git a/ci/ci.sh b/ci/ci.sh index a57622dea..79cb8f546 100755 --- a/ci/ci.sh +++ b/ci/ci.sh @@ -135,6 +135,7 @@ test_core() { prepare_docker() { rm "${WORKSPACE_DIR}"/python/dist/* ||: pushd "${WORKSPACE_DIR}/python" + pip install -e . --verbose python setup.py bdist_wheel tmp_dir="/tmp/prepare_docker_$RANDOM" mkdir -p $tmp_dir diff --git a/dashboard/consts.py b/dashboard/consts.py index f2d27abd7..015df9e04 100644 --- a/dashboard/consts.py +++ b/dashboard/consts.py @@ -27,6 +27,7 @@ GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10) GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer( "GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2 ) +GCS_RPC_TIMEOUT_SECONDS = 3 # aiohttp_cache AIOHTTP_CACHE_TTL_SECONDS = 2 AIOHTTP_CACHE_MAX_SIZE = 128 diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index a3520e0c2..d79460ab8 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -7,7 +7,6 @@ import ray._private.ray_constants as ray_constants import ray._private.utils as utils import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils -import ray.experimental.internal_kv as internal_kv from ray.core.generated import event_pb2, event_pb2_grpc from ray.dashboard.modules.event import event_consts from ray.dashboard.modules.event.event_utils import monitor_events @@ -24,6 +23,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule): self._monitor: Union[asyncio.Task, None] = None self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE) + self._gcs_aio_client = dashboard_agent.gcs_aio_client + logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize) async def _connect_to_dashboard(self): @@ -35,11 +36,12 @@ class EventAgent(dashboard_utils.DashboardAgentModule): """ while True: try: - # TODO: Use async version if performance is an issue - dashboard_rpc_address = internal_kv._internal_kv_get( - dashboard_consts.DASHBOARD_RPC_ADDRESS, + dashboard_rpc_address = await self._gcs_aio_client.internal_kv_get( + dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(), namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=1, ) + dashboard_rpc_address = dashboard_rpc_address.decode() if dashboard_rpc_address: logger.info("Report events to %s", dashboard_rpc_address) options = ray_constants.GLOBAL_GRPC_OPTIONS diff --git a/dashboard/modules/healthz/healthz_agent.py b/dashboard/modules/healthz/healthz_agent.py index cc9bc2dd2..1e5070b62 100644 --- a/dashboard/modules/healthz/healthz_agent.py +++ b/dashboard/modules/healthz/healthz_agent.py @@ -1,7 +1,7 @@ import ray.dashboard.utils as dashboard_utils import ray.dashboard.optional_utils as optional_utils from ray.dashboard.modules.healthz.utils import HealthChecker -from aiohttp.web import Request, Response, HTTPServiceUnavailable +from aiohttp.web import Request, Response import grpc routes = optional_utils.ClassMethodRouteTable @@ -26,7 +26,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule): try: alive = await self._health_checker.check_local_raylet_liveness() if alive is False: - return HTTPServiceUnavailable(reason="Local Raylet failed") + return Response(status=503, text="Local Raylet failed") except grpc.RpcError as e: # We only consider the error other than GCS unreachable as raylet failure # to avoid false positive. 
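Note: the EventAgent change above swaps the blocking internal_kv lookup for the awaitable GcsAioClient call, and dashboard/consts.py now defines GCS_RPC_TIMEOUT_SECONDS for that purpose. A minimal sketch of the same lookup pattern, assuming a module that already holds dashboard_agent.gcs_aio_client (the helper name is illustrative; the EventAgent itself passes timeout=1):

import ray._private.ray_constants as ray_constants
import ray.dashboard.consts as dashboard_consts


async def fetch_dashboard_rpc_address(gcs_aio_client):
    # Read the dashboard RPC address published to the GCS internal KV;
    # returns None if it has not been published yet.
    value = await gcs_aio_client.internal_kv_get(
        dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(),
        namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
    )
    return value.decode() if value else None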
@@ -38,7 +38,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule): grpc.StatusCode.UNKNOWN, grpc.StatusCode.DEADLINE_EXCEEDED, ): - return HTTPServiceUnavailable(reason=e.message()) + return Response(status=503, text=f"Health check failed due to: {e}") return Response( text="success", diff --git a/dashboard/modules/healthz/utils.py b/dashboard/modules/healthz/utils.py index e3127e74e..e29d4bef9 100644 --- a/dashboard/modules/healthz/utils.py +++ b/dashboard/modules/healthz/utils.py @@ -14,10 +14,10 @@ class HealthChecker: return False liveness = await self._gcs_aio_client.check_alive( - [self._local_node_address.encode()], 1 + [self._local_node_address.encode()], 0.1 ) return liveness[0] async def check_gcs_liveness(self) -> bool: - await self._gcs_aio_client.check_alive([], 1) + await self._gcs_aio_client.check_alive([], 0.1) return True diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index a65cd6459..6d1531611 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -80,6 +80,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule): # The time it takes until the head node is registered. None means # head node hasn't been registered. self._head_node_registration_time_s = None + self._gcs_aio_client = dashboard_head.gcs_aio_client async def _update_stubs(self, change): if change.old: diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py index 32eb9c4d3..8f21952cd 100644 --- a/dashboard/modules/reporter/reporter_agent.py +++ b/dashboard/modules/reporter/reporter_agent.py @@ -14,9 +14,9 @@ import psutil import ray import ray._private.services import ray._private.utils +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.modules.reporter.reporter_consts as reporter_consts import ray.dashboard.utils as dashboard_utils -import ray.experimental.internal_kv as internal_kv from ray._private.metrics_agent import Gauge, MetricsAgent, Record from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS from ray.core.generated import reporter_pb2, reporter_pb2_grpc @@ -227,7 +227,7 @@ class ReporterAgent( logical_cpu_count = psutil.cpu_count() physical_cpu_count = psutil.cpu_count(logical=False) self._cpu_counts = (logical_cpu_count, physical_cpu_count) - + self._gcs_aio_client = dashboard_agent.gcs_aio_client self._ip = dashboard_agent.ip self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0] self._hostname = socket.gethostname() @@ -787,8 +787,10 @@ class ReporterAgent( """Get any changes to the log files and push updates to kv.""" while True: try: - formatted_status_string = internal_kv._internal_kv_get( - DEBUG_AUTOSCALING_STATUS + formatted_status_string = await self._gcs_aio_client.internal_kv_get( + DEBUG_AUTOSCALING_STATUS.encode(), + None, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) stats = self._get_all_stats() diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 013d08438..eb671dcd7 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -1,7 +1,7 @@ import json import logging import os - +import asyncio import aiohttp.web import yaml @@ -9,8 +9,8 @@ import ray import ray._private.services import ray._private.utils import ray.dashboard.optional_utils as dashboard_optional_utils +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.utils as dashboard_utils -import 
ray.experimental.internal_kv as internal_kv from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter from ray._private.ray_constants import ( @@ -40,6 +40,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule): gcs_address = dashboard_head.gcs_address temp_dir = dashboard_head.temp_dir self.service_discovery = PrometheusServiceDiscoveryWriter(gcs_address, temp_dir) + self._gcs_aio_client = dashboard_head.gcs_aio_client async def _update_stubs(self, change): if change.old: @@ -127,15 +128,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule): autoscaler writes them there. """ - assert ray.experimental.internal_kv._internal_kv_initialized() - legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY) - formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS) + (legacy_status, formatted_status_string, error) = await asyncio.gather( + *[ + self._gcs_aio_client.internal_kv_get( + key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS + ) + for key in [ + DEBUG_AUTOSCALING_STATUS_LEGACY, + DEBUG_AUTOSCALING_STATUS, + DEBUG_AUTOSCALING_ERROR, + ] + ] + ) + formatted_status = ( json.loads(formatted_status_string.decode()) if formatted_status_string else {} ) - error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR) return dashboard_optional_utils.rest_response( success=True, message="Got cluster status.", diff --git a/dashboard/modules/serve/serve_agent.py b/dashboard/modules/serve/serve_agent.py index b549d3c9a..e1ed51680 100644 --- a/dashboard/modules/serve/serve_agent.py +++ b/dashboard/modules/serve/serve_agent.py @@ -2,9 +2,9 @@ import json import logging from aiohttp.web import Request, Response - import dataclasses import ray +import asyncio import aiohttp.web import ray.dashboard.optional_utils as optional_utils import ray.dashboard.utils as dashboard_utils @@ -25,6 +25,8 @@ routes = optional_utils.ClassMethodRouteTable class ServeAgent(dashboard_utils.DashboardAgentModule): def __init__(self, dashboard_agent): super().__init__(dashboard_agent) + self._controller = None + self._controller_lock = asyncio.Lock() # TODO: It's better to use `/api/version`. # It requires a refactor of ClassMethodRouteTable to differentiate the server. @@ -48,12 +50,24 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def get_all_deployments(self, req: Request) -> Response: from ray.serve.schema import ServeApplicationSchema - client = self.get_serve_client() + controller = await self.get_serve_controller() - if client is None: + if controller is None: config = ServeApplicationSchema.get_empty_schema_dict() else: - config = client.get_app_config() + try: + config = await controller.get_app_config.remote() + except ray.exceptions.RayTaskError as e: + # Task failures are sometimes due to GCS + # failure. When the GCS fails, it can take + # longer to recover. + return Response( + status=503, + text=( + "Failed to get a response from the controller. 
" f"The GCS may be down: {e}" + ), + ) return Response( text=json.dumps(config), @@ -65,13 +79,20 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def get_all_deployment_statuses(self, req: Request) -> Response: from ray.serve.schema import serve_status_to_schema, ServeStatusSchema - client = self.get_serve_client() + controller = await self.get_serve_controller() - if client is None: + if controller is None: status_json = ServeStatusSchema.get_empty_schema_dict() status_json_str = json.dumps(status_json) else: - status = client.get_serve_status() + from ray.serve._private.common import StatusOverview + from ray.serve.generated.serve_pb2 import ( + StatusOverview as StatusOverviewProto, + ) + + serve_status = await controller.get_serve_status.remote() + proto = StatusOverviewProto.FromString(serve_status) + status = StatusOverview.from_proto(proto) status_json_str = serve_status_to_schema(status).json() return Response( @@ -84,7 +105,7 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def delete_serve_application(self, req: Request) -> Response: from ray import serve - if self.get_serve_client() is not None: + if await self.get_serve_controller() is not None: serve.shutdown() return Response() @@ -144,21 +165,44 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): return Response() - def get_serve_client(self): - """Gets the ServeControllerClient to the this cluster's Serve app. + async def get_serve_controller(self): + """Gets the ServeController of this cluster's Serve app. return: If Serve is running on this Ray cluster, returns a client to the Serve controller. If Serve is not running, returns None. """ + async with self._controller_lock: + if self._controller is not None: + try: + await self._controller.check_alive.remote() + return self._controller + except ray.exceptions.RayActorError: + logger.info("Controller is dead") + self._controller = None - from ray.serve.context import get_global_client - from ray.serve.exceptions import RayServeException + # Try to connect to Serve even when we detect the actor is dead + # because the user might have started a new + # Serve cluster. + from ray.serve._private.constants import ( + SERVE_CONTROLLER_NAME, + SERVE_NAMESPACE, + ) - try: - return get_global_client(_health_check_controller=True) - except RayServeException: - logger.debug("There's no Serve app running on this Ray cluster.") - return None + try: + # get_actor is a sync call but it will time out after + # ray.dashboard.consts.GCS_RPC_TIMEOUT_SECONDS + self._controller = ray.get_actor( + SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE + ) + except Exception as e: + logger.debug( + "There is no Serve controller " + "running on this Ray cluster. 
Please " + "call `serve.start(detached=True)` to start " + f"one: {e}" + ) + + return self._controller async def run(self, server): pass diff --git a/dashboard/modules/serve/tests/test_serve_agent_fault_tolerance.py b/dashboard/modules/serve/tests/test_serve_agent_fault_tolerance.py new file mode 100644 index 000000000..b6a34c29d --- /dev/null +++ b/dashboard/modules/serve/tests/test_serve_agent_fault_tolerance.py @@ -0,0 +1,68 @@ +import requests +import pytest +import ray +import sys +from ray import serve +from ray.tests.conftest import * # noqa: F401,F403 +from ray._private.test_utils import generate_system_config_map + +DEPLOYMENTS_URL = "http://localhost:52365/api/serve/deployments/" +STATUS_URL = "http://localhost:52365/api/serve/deployments/status" + + +@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.") +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + { + **generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=3600, + gcs_server_request_timeout_seconds=3, + ), + } + ], + indirect=True, +) +def test_deployments_get_tolerance(monkeypatch, ray_start_regular_with_external_redis): + # Test the Serve agent's availability when the GCS is down. + monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3") + serve.start(detached=True) + + get_response = requests.get(DEPLOYMENTS_URL, timeout=15) + assert get_response.status_code == 200 + ray._private.worker._global_node.kill_gcs_server() + + get_response = requests.get(DEPLOYMENTS_URL, timeout=30) + assert get_response.status_code == 503 + + +@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.") +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + { + **generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=3600, + gcs_server_request_timeout_seconds=1, + ), + } + ], + indirect=True, +) +def test_status_url_get_tolerance(monkeypatch, ray_start_regular_with_external_redis): + # Test the Serve agent's availability when the GCS is down. + monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3") + serve.start(detached=True) + get_response = requests.get(STATUS_URL, timeout=15) + assert get_response.status_code == 200 + + ray._private.worker._global_node.kill_gcs_server() + + get_response = requests.get(STATUS_URL, timeout=30) + assert get_response.status_code == 200 + + +if __name__ == "__main__": + sys.exit(pytest.main(["-vs", "--forked", __file__])) diff --git a/dashboard/modules/snapshot/snapshot_head.py b/dashboard/modules/snapshot/snapshot_head.py index 61659b4c8..3af83908a 100644 --- a/dashboard/modules/snapshot/snapshot_head.py +++ b/dashboard/modules/snapshot/snapshot_head.py @@ -12,18 +12,14 @@ import aiohttp.web from pydantic import BaseModel, Extra, Field, validator import ray -from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK +from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK, GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils from ray._private import ray_constants from ray._private.storage import _load_class from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JobInfoStorageClient -from ray.experimental.internal_kv import ( - _internal_kv_get, - _internal_kv_initialized, - _internal_kv_list, -) + from ray.job_submission import JobInfo from ray.runtime_env import RuntimeEnv @@ -87,7 +83,7 @@ 
class APIHead(dashboard_utils.DashboardHeadModule): self._gcs_job_info_stub = None self._gcs_actor_info_stub = None self._dashboard_head = dashboard_head - assert _internal_kv_initialized() + self._gcs_aio_client = dashboard_head.gcs_aio_client self._job_info_client = None # For offloading CPU intensive work. self._thread_pool = concurrent.futures.ThreadPoolExecutor( @@ -382,47 +378,47 @@ class APIHead(dashboard_utils.DashboardHeadModule): # Serve wraps Ray's internal KV store and specially formats the keys. # These are the keys we are interested in: # SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY - # TODO: Convert to async GRPC, if CPU usage is not a concern. - def get_deployments(): - serve_keys = _internal_kv_list( - SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE - ) - serve_snapshot_keys = filter( - lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys - ) - - deployments_per_controller: List[Dict[str, Any]] = [] - for key in serve_snapshot_keys: - val_bytes = _internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_SERVE - ) or "{}".encode("utf-8") - deployments_per_controller.append(json.loads(val_bytes.decode("utf-8"))) - # Merge the deployments dicts of all controllers. - deployments: Dict[str, Any] = { - k: v for d in deployments_per_controller for k, v in d.items() - } - # Replace the keys (deployment names) with their hashes to prevent - # collisions caused by the automatic conversion to camelcase by the - # dashboard agent. - return { - hashlib.sha1(name.encode()).hexdigest(): info - for name, info in deployments.items() - } - - return await asyncio.get_event_loop().run_in_executor( - executor=self._thread_pool, func=get_deployments + serve_keys = await self._gcs_aio_client.internal_kv_keys( + SERVE_CONTROLLER_NAME.encode(), + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) + tasks = [ + self._gcs_aio_client.internal_kv_get( + key, + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + for key in serve_keys + if SERVE_SNAPSHOT_KEY in key.decode() + ] + + serve_snapshot_vals = await asyncio.gather(*tasks) + + deployments_per_controller: List[Dict[str, Any]] = [ + json.loads(val.decode()) for val in serve_snapshot_vals + ] + + # Merge the deployments dicts of all controllers. + deployments: Dict[str, Any] = { + k: v for d in deployments_per_controller for k, v in d.items() + } + # Replace the keys (deployment names) with their hashes to prevent + # collisions caused by the automatic conversion to camelcase by the + # dashboard agent. + return { + hashlib.sha1(name.encode()).hexdigest(): info + for name, info in deployments.items() + } + async def get_session_name(self): - # TODO(yic): Convert to async GRPC. 
- def get_session(): - return ray.experimental.internal_kv._internal_kv_get( - "session_name", namespace=ray_constants.KV_NAMESPACE_SESSION - ).decode() - - return await asyncio.get_event_loop().run_in_executor( - executor=self._thread_pool, func=get_session + session_name = await self._gcs_aio_client.internal_kv_get( + b"session_name", + namespace=ray_constants.KV_NAMESPACE_SESSION, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) + return session_name.decode() async def run(self, server): self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub( diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 3865c0073..7a3a8d1e0 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -261,6 +261,10 @@ def init_ray_and_catch_exceptions() -> Callable: try: address = self.get_gcs_address() logger.info(f"Connecting to ray with address={address}") + # Use a shorter GCS RPC timeout for the dashboard + os.environ["RAY_gcs_server_request_timeout_seconds"] = str( + dashboard_consts.GCS_RPC_TIMEOUT_SECONDS + ) # Init ray without logging to driver # to avoid infinite logging issue. ray.init( diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 0922c546f..2dddd89c3 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2447,6 +2447,10 @@ def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHan have been created with Actor.options(name="name").remote(). This works for both detached & non-detached actors. + This method is a sync call and it will time out after 60 seconds. The timeout + can be changed by setting the environment variable + RAY_gcs_server_request_timeout_seconds before starting the cluster. + Args: name: The name of the actor. namespace: The namespace of the actor, or None to specify the current diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index c3773652f..3343f85de 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -96,6 +96,9 @@ HANDLE_METRIC_PUSH_INTERVAL_S = 10 # Timeout for GCS internal KV service RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None +# Timeout for GCS RPC request +RAY_GCS_RPC_TIMEOUT_S = 3.0 + # Env var to control legacy sync deployment handle behavior in DAG. SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY = "SERVE_DEPLOYMENT_HANDLE_IS_SYNC" diff --git a/python/ray/serve/_private/http_state.py b/python/ray/serve/_private/http_state.py index 5dc338a38..e293162e5 100644 --- a/python/ray/serve/_private/http_state.py +++ b/python/ray/serve/_private/http_state.py @@ -7,6 +7,7 @@ import ray from ray.actor import ActorHandle from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from ray._private.gcs_utils import GcsClient from ray.serve.config import HTTPOptions, DeploymentMode from ray.serve._private.constants import ( ASYNC_CONCURRENCY, @@ -37,6 +38,7 @@ class HTTPState: detached: bool, config: HTTPOptions, head_node_id: str, + gcs_client: GcsClient, # Used by unit testing _start_proxies_on_init: bool = True, ): @@ -49,6 +51,9 @@ class HTTPState: self._proxy_actors: Dict[NodeId, ActorHandle] = dict() self._proxy_actor_names: Dict[NodeId, str] = dict() self._head_node_id: str = head_node_id + + self._gcs_client = gcs_client + assert isinstance(head_node_id, str) # Will populate self.proxy_actors with existing actors. 
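Note: HTTPState above (and RayInternalKVStore and get_all_node_ids below) now take an explicit GcsClient instead of reaching for global state. A minimal sketch of how a caller builds and passes that client, mirroring the ServeController change further down in this diff; the namespace string is illustrative, and the code assumes it runs inside a Ray worker so the runtime context has a GCS address:

import ray
from ray._private.gcs_utils import GcsClient
from ray.serve._private.storage.kv_store import RayInternalKVStore

# One GcsClient can be shared by the KV store, snapshot store, and HTTPState.
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
kv_store = RayInternalKVStore(namespace="my-controller-ns", gcs_client=gcs_client)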
@@ -75,7 +80,7 @@ class HTTPState: def _get_target_nodes(self) -> List[Tuple[str, str]]: """Return the list of (node_id, ip_address) to deploy HTTP servers on.""" location = self._config.location - target_nodes = get_all_node_ids() + target_nodes = get_all_node_ids(self._gcs_client) if location == DeploymentMode.NoServer: return [] @@ -112,6 +117,7 @@ class HTTPState: def _start_proxies_if_needed(self) -> None: """Start a proxy on every node if it doesn't already exist.""" + for node_id, node_ip_address in self._get_target_nodes(): if node_id in self._proxy_actors: continue @@ -151,7 +157,7 @@ class HTTPState: def _stop_proxies_if_needed(self) -> bool: """Removes proxy actors from any nodes that no longer exist.""" - all_node_ids = {node_id for node_id, _ in get_all_node_ids()} + all_node_ids = {node_id for node_id, _ in get_all_node_ids(self._gcs_client)} to_stop = [] for node_id in self._proxy_actors: if node_id not in all_node_ids: diff --git a/python/ray/serve/_private/storage/kv_store.py b/python/ray/serve/_private/storage/kv_store.py index 55994845d..8ac175ad2 100644 --- a/python/ray/serve/_private/storage/kv_store.py +++ b/python/ray/serve/_private/storage/kv_store.py @@ -28,12 +28,15 @@ class RayInternalKVStore(KVStoreBase): def __init__( self, - namespace: str = None, + namespace: Optional[str] = None, + gcs_client: Optional[GcsClient] = None, ): if namespace is not None and not isinstance(namespace, str): raise TypeError("namespace must a string, got: {}.".format(type(namespace))) - - self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) + if gcs_client is not None: + self.gcs_client = gcs_client + else: + self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) self.timeout = RAY_SERVE_KV_TIMEOUT_S self.namespace = namespace or "" diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 8082ac86f..0b51c9ad6 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -21,7 +21,7 @@ import ray import ray.util.serialization_addons from ray.actor import ActorHandle from ray.exceptions import RayTaskError -from ray.serve._private.constants import HTTP_PROXY_TIMEOUT +from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request from ray.util.serialization import StandaloneSerializationContext @@ -145,19 +145,21 @@ def format_actor_name(actor_name, controller_name=None, *modifiers): return name -def get_all_node_ids() -> List[Tuple[str, str]]: +def get_all_node_ids(gcs_client) -> List[Tuple[str, str]]: """Get IDs for all live nodes in the cluster. Returns a list of (node_id: str, ip_address: str). The node_id can be passed into the Ray SchedulingPolicy API. """ - node_ids = [] - # Sort on NodeID to ensure the ordering is deterministic across the cluster. - for node in sorted(ray.nodes(), key=lambda entry: entry["NodeID"]): - # print(node) - if node["Alive"]: - node_ids.append((node["NodeID"], node["NodeName"])) + nodes = gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S) + node_ids = [ + (ray.NodeID.from_binary(node.node_id).hex(), node.node_name) + for node in nodes.node_info_list + if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE + ] + # Sort on NodeID to ensure the ordering is deterministic across the cluster. 
+ node_ids.sort() return node_ids diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index f2ec3a7bc..6adf6ff04 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -13,6 +13,7 @@ from ray._private.utils import import_attr from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray.actor import ActorHandle from ray.exceptions import RayTaskError +from ray._private.gcs_utils import GcsClient from ray.serve._private.autoscaling_policy import BasicAutoscalingPolicy from ray.serve._private.common import ( ApplicationStatus, @@ -96,9 +97,10 @@ class ServeController: # Used to read/write checkpoints. self.ray_worker_namespace = ray.get_runtime_context().namespace self.controller_name = controller_name + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) kv_store_namespace = f"{self.controller_name}-{self.ray_worker_namespace}" - self.kv_store = RayInternalKVStore(kv_store_namespace) - self.snapshot_store = RayInternalKVStore(namespace=kv_store_namespace) + self.kv_store = RayInternalKVStore(kv_store_namespace, gcs_client) + self.snapshot_store = RayInternalKVStore(kv_store_namespace, gcs_client) # Dictionary of deployment_name -> proxy_name -> queue length. self.deployment_stats = defaultdict(lambda: defaultdict(dict)) @@ -114,6 +116,7 @@ class ServeController: detached, http_config, head_node_id, + gcs_client, ) self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host) diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py index f72dcc312..957a79f95 100644 --- a/python/ray/serve/tests/test_http_state.py +++ b/python/ray/serve/tests/test_http_state.py @@ -15,6 +15,7 @@ def test_node_selection(): detached=True, config=http_options, head_node_id=head_node_id, + gcs_client=None, _start_proxies_on_init=False, ) diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 58f660fd3..ff7fef025 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -28,6 +28,7 @@ from ray.serve._private.constants import ( SERVE_PROXY_NAME, SERVE_ROOT_URL_ENV_KEY, ) +from ray._private.gcs_utils import GcsClient from ray.serve.context import get_global_client from ray.serve.exceptions import RayServeException from ray.serve.generated.serve_pb2 import ActorNameList @@ -87,6 +88,7 @@ def lower_slow_startup_threshold_and_reset(): def test_shutdown(ray_shutdown): ray.init(num_cpus=16) serve.start(http_options=dict(port=8003)) + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) @serve.deployment def f(): @@ -100,7 +102,7 @@ def test_shutdown(ray_shutdown): format_actor_name( SERVE_PROXY_NAME, serve.context._global_client._controller_name, - get_all_node_ids()[0][0], + get_all_node_ids(gcs_client)[0][0], ), ] @@ -239,10 +241,11 @@ def test_multiple_routers(ray_cluster): node_ids = ray._private.state.node_ids() assert len(node_ids) == 2 serve.start(http_options=dict(port=8005, location="EveryNode")) + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) def get_proxy_names(): proxy_names = [] - for node_id, _ in get_all_node_ids(): + for node_id, _ in get_all_node_ids(gcs_client): proxy_names.append( format_actor_name( SERVE_PROXY_NAME, diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index ce03c1c65..38d9b0916 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ 
b/python/ray/tests/test_gcs_fault_tolerance.py @@ -1,4 +1,5 @@ import sys +import os import threading from time import sleep @@ -609,8 +610,37 @@ def test_pg_actor_workloads(ray_start_regular_with_external_redis): assert pid == ray.get(c.pid.remote()) +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=60, + gcs_server_request_timeout_seconds=10, + ) + ], + indirect=True, +) +def test_get_actor_when_gcs_is_down(ray_start_regular_with_external_redis): + @ray.remote + def create_actor(): + @ray.remote + class A: + def pid(self): + return os.getpid() + + a = A.options(lifetime="detached", name="A").remote() + ray.get(a.pid.remote()) + + ray.get(create_actor.remote()) + + ray._private.worker._global_node.kill_gcs_server() + + with pytest.raises(ray.exceptions.GetTimeoutError): + ray.get_actor("A") + + if __name__ == "__main__": - import os import pytest diff --git a/python/ray/tests/test_ray_cluster_with_external_redis.py b/python/ray/tests/test_ray_cluster_with_external_redis.py deleted file mode 100644 index 8a1bb8dee..000000000 --- a/python/ray/tests/test_ray_cluster_with_external_redis.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import pytest -import sys - -import ray - - -@pytest.mark.parametrize( - "call_ray_start_with_external_redis", - [ - "6379", - "6379,6380", - "6379,6380,6381", - ], - indirect=True, -) -def test_using_hostnames(call_ray_start_with_external_redis): - ray.init(address="127.0.0.1:6379", _redis_password="123") - - @ray.remote - def f(): - return 1 - - assert ray.get(f.remote()) == 1 - - @ray.remote - class Counter: - def __init__(self): - self.count = 0 - - def inc_and_get(self): - self.count += 1 - return self.count - - counter = Counter.remote() - assert ray.get(counter.inc_and_get.remote()) == 1 - - -if __name__ == "__main__": - import pytest - - # Make subprocess happy in bazel. - os.environ["LC_ALL"] = "en_US.UTF-8" - os.environ["LANG"] = "en_US.UTF-8" - if os.environ.get("PARALLEL_CI"): - sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) - else: - sys.exit(pytest.main(["-sv", __file__]))
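Note: with a bounded GCS request timeout (gcs_server_request_timeout_seconds, or the RAY_gcs_server_request_timeout_seconds environment variable), ray.get_actor surfaces a timeout error when the GCS is unreachable, which is what the new test_get_actor_when_gcs_is_down test above exercises. A hedged sketch of defensive caller-side usage (the actor name and fallback behavior are illustrative):

import ray
from ray.exceptions import GetTimeoutError

try:
    handle = ray.get_actor("A")
except GetTimeoutError:
    # The GCS did not respond within the configured request timeout;
    # treat the actor as unreachable for now and retry later.
    handle = None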