From a09ae6582388b965b047916350f4f26d1b3fcd02 Mon Sep 17 00:00:00 2001
From: Yi Cheng <74173148+iycheng@users.noreply.github.com>
Date: Tue, 9 Aug 2022 01:21:12 +0000
Subject: [PATCH] [serve] Make serve agent not blocking when GCS is down. (#27526) (#27674)

This PR fixes several issues that block the Serve agent when the GCS is down. We need to make sure the Serve agent is always alive so that external requests can still be sent to it to check the cluster's status.

- The internal KV client used in dashboard/agent blocks the agent. We use the async one instead.
- The Serve controller uses ray.nodes, which is a blocking call that can block forever. Change it to use the GCS client with a timeout.
- The agent uses the Serve controller client, which is a blocking call with max retries = -1 and therefore blocks until the controller is back.

To enable Serve HA, we also need to set:

- RAY_gcs_server_request_timeout_seconds=5
- RAY_SERVE_KV_TIMEOUT_S=5

which we should set in KubeRay; a configuration sketch follows the patch.
---
ci/ci.sh | 1 + dashboard/consts.py | 1 + dashboard/modules/event/event_agent.py | 10 ++- dashboard/modules/healthz/healthz_agent.py | 6 +- dashboard/modules/healthz/utils.py | 4 +- dashboard/modules/node/node_head.py | 1 + dashboard/modules/reporter/reporter_agent.py | 10 ++- dashboard/modules/reporter/reporter_head.py | 22 +++-- dashboard/modules/serve/serve_agent.py | 78 +++++++++++++---- .../tests/test_serve_agent_fault_tolerane.py | 68 +++++++++++++++ dashboard/modules/snapshot/snapshot_head.py | 84 +++++++++---------- dashboard/optional_utils.py | 4 + python/ray/_private/worker.py | 4 + python/ray/serve/_private/constants.py | 3 + python/ray/serve/_private/http_state.py | 10 ++- python/ray/serve/_private/storage/kv_store.py | 9 +- python/ray/serve/_private/utils.py | 18 ++-- python/ray/serve/controller.py | 7 +- python/ray/serve/tests/test_http_state.py | 1 + python/ray/serve/tests/test_standalone.py | 7 +- python/ray/tests/test_gcs_fault_tolerance.py | 32 ++++++- .../test_ray_cluster_with_external_redis.py | 48 ----------- 22 files changed, 282 insertions(+), 146 deletions(-) create mode 100644 dashboard/modules/serve/tests/test_serve_agent_fault_tolerane.py delete mode 100644 python/ray/tests/test_ray_cluster_with_external_redis.py diff --git a/ci/ci.sh b/ci/ci.sh index 1a15086b7..669b6bc60 100755 --- a/ci/ci.sh +++ b/ci/ci.sh @@ -135,6 +135,7 @@ test_core() { prepare_docker() { rm "${WORKSPACE_DIR}"/python/dist/* ||: pushd "${WORKSPACE_DIR}/python" + pip install -e .
--verbose python setup.py bdist_wheel tmp_dir="/tmp/prepare_docker_$RANDOM" mkdir -p $tmp_dir diff --git a/dashboard/consts.py b/dashboard/consts.py index f2d27abd7..015df9e04 100644 --- a/dashboard/consts.py +++ b/dashboard/consts.py @@ -27,6 +27,7 @@ GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10) GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer( "GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2 ) +GCS_RPC_TIMEOUT_SECONDS = 3 # aiohttp_cache AIOHTTP_CACHE_TTL_SECONDS = 2 AIOHTTP_CACHE_MAX_SIZE = 128 diff --git a/dashboard/modules/event/event_agent.py b/dashboard/modules/event/event_agent.py index a3520e0c2..d79460ab8 100644 --- a/dashboard/modules/event/event_agent.py +++ b/dashboard/modules/event/event_agent.py @@ -7,7 +7,6 @@ import ray._private.ray_constants as ray_constants import ray._private.utils as utils import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils -import ray.experimental.internal_kv as internal_kv from ray.core.generated import event_pb2, event_pb2_grpc from ray.dashboard.modules.event import event_consts from ray.dashboard.modules.event.event_utils import monitor_events @@ -24,6 +23,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule): self._monitor: Union[asyncio.Task, None] = None self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE) + self._gcs_aio_client = dashboard_agent.gcs_aio_client + logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize) async def _connect_to_dashboard(self): @@ -35,11 +36,12 @@ class EventAgent(dashboard_utils.DashboardAgentModule): """ while True: try: - # TODO: Use async version if performance is an issue - dashboard_rpc_address = internal_kv._internal_kv_get( - dashboard_consts.DASHBOARD_RPC_ADDRESS, + dashboard_rpc_address = await self._gcs_aio_client.internal_kv_get( + dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(), namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=1, ) + dashboard_rpc_address = dashboard_rpc_address.decode() if dashboard_rpc_address: logger.info("Report events to %s", dashboard_rpc_address) options = ray_constants.GLOBAL_GRPC_OPTIONS diff --git a/dashboard/modules/healthz/healthz_agent.py b/dashboard/modules/healthz/healthz_agent.py index cc9bc2dd2..1e5070b62 100644 --- a/dashboard/modules/healthz/healthz_agent.py +++ b/dashboard/modules/healthz/healthz_agent.py @@ -1,7 +1,7 @@ import ray.dashboard.utils as dashboard_utils import ray.dashboard.optional_utils as optional_utils from ray.dashboard.modules.healthz.utils import HealthChecker -from aiohttp.web import Request, Response, HTTPServiceUnavailable +from aiohttp.web import Request, Response import grpc routes = optional_utils.ClassMethodRouteTable @@ -26,7 +26,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule): try: alive = await self._health_checker.check_local_raylet_liveness() if alive is False: - return HTTPServiceUnavailable(reason="Local Raylet failed") + return Response(status=503, text="Local Raylet failed") except grpc.RpcError as e: # We only consider the error other than GCS unreachable as raylet failure # to avoid false positive. 
@@ -38,7 +38,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule): grpc.StatusCode.UNKNOWN, grpc.StatusCode.DEADLINE_EXCEEDED, ): - return HTTPServiceUnavailable(reason=e.message()) + return Response(status=503, text=f"Health check failed due to: {e}") return Response( text="success", diff --git a/dashboard/modules/healthz/utils.py b/dashboard/modules/healthz/utils.py index e3127e74e..e29d4bef9 100644 --- a/dashboard/modules/healthz/utils.py +++ b/dashboard/modules/healthz/utils.py @@ -14,10 +14,10 @@ class HealthChecker: return False liveness = await self._gcs_aio_client.check_alive( - [self._local_node_address.encode()], 1 + [self._local_node_address.encode()], 0.1 ) return liveness[0] async def check_gcs_liveness(self) -> bool: - await self._gcs_aio_client.check_alive([], 1) + await self._gcs_aio_client.check_alive([], 0.1) return True diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index a65cd6459..6d1531611 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -80,6 +80,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule): # The time it takes until the head node is registered. None means # head node hasn't been registered. self._head_node_registration_time_s = None + self._gcs_aio_client = dashboard_head.gcs_aio_client async def _update_stubs(self, change): if change.old: diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py index 32eb9c4d3..8f21952cd 100644 --- a/dashboard/modules/reporter/reporter_agent.py +++ b/dashboard/modules/reporter/reporter_agent.py @@ -14,9 +14,9 @@ import psutil import ray import ray._private.services import ray._private.utils +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.modules.reporter.reporter_consts as reporter_consts import ray.dashboard.utils as dashboard_utils -import ray.experimental.internal_kv as internal_kv from ray._private.metrics_agent import Gauge, MetricsAgent, Record from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS from ray.core.generated import reporter_pb2, reporter_pb2_grpc @@ -227,7 +227,7 @@ class ReporterAgent( logical_cpu_count = psutil.cpu_count() physical_cpu_count = psutil.cpu_count(logical=False) self._cpu_counts = (logical_cpu_count, physical_cpu_count) - + self._gcs_aio_client = dashboard_agent.gcs_aio_client self._ip = dashboard_agent.ip self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0] self._hostname = socket.gethostname() @@ -787,8 +787,10 @@ class ReporterAgent( """Get any changes to the log files and push updates to kv.""" while True: try: - formatted_status_string = internal_kv._internal_kv_get( - DEBUG_AUTOSCALING_STATUS + formatted_status_string = await self._gcs_aio_client.internal_kv_get( + DEBUG_AUTOSCALING_STATUS.encode(), + None, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) stats = self._get_all_stats() diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 013d08438..eb671dcd7 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -1,7 +1,7 @@ import json import logging import os - +import asyncio import aiohttp.web import yaml @@ -9,8 +9,8 @@ import ray import ray._private.services import ray._private.utils import ray.dashboard.optional_utils as dashboard_optional_utils +from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.utils as dashboard_utils -import 
ray.experimental.internal_kv as internal_kv from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter from ray._private.ray_constants import ( @@ -40,6 +40,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule): gcs_address = dashboard_head.gcs_address temp_dir = dashboard_head.temp_dir self.service_discovery = PrometheusServiceDiscoveryWriter(gcs_address, temp_dir) + self._gcs_aio_client = dashboard_head.gcs_aio_client async def _update_stubs(self, change): if change.old: @@ -127,15 +128,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule): autoscaler writes them there. """ - assert ray.experimental.internal_kv._internal_kv_initialized() - legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY) - formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS) + (legacy_status, formatted_status_string, error) = await asyncio.gather( + *[ + self._gcs_aio_client.internal_kv_get( + key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS + ) + for key in [ + DEBUG_AUTOSCALING_STATUS_LEGACY, + DEBUG_AUTOSCALING_STATUS, + DEBUG_AUTOSCALING_ERROR, + ] + ] + ) + formatted_status = ( json.loads(formatted_status_string.decode()) if formatted_status_string else {} ) - error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR) return dashboard_optional_utils.rest_response( success=True, message="Got cluster status.", diff --git a/dashboard/modules/serve/serve_agent.py b/dashboard/modules/serve/serve_agent.py index 3d2f1765b..b345766b6 100644 --- a/dashboard/modules/serve/serve_agent.py +++ b/dashboard/modules/serve/serve_agent.py @@ -2,9 +2,9 @@ import json import logging from aiohttp.web import Request, Response - import dataclasses import ray +import asyncio import aiohttp.web import ray.dashboard.optional_utils as optional_utils import ray.dashboard.utils as dashboard_utils @@ -25,6 +25,8 @@ routes = optional_utils.ClassMethodRouteTable class ServeAgent(dashboard_utils.DashboardAgentModule): def __init__(self, dashboard_agent): super().__init__(dashboard_agent) + self._controller = None + self._controller_lock = asyncio.Lock() # TODO: It's better to use `/api/version`. # It requires a refactor of ClassMethodRouteTable to differentiate the server. @@ -48,12 +50,24 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def get_all_deployments(self, req: Request) -> Response: from ray.serve.schema import ServeApplicationSchema - client = self.get_serve_client() + controller = await self.get_serve_controller() - if client is None: + if controller is None: config = ServeApplicationSchema.get_empty_schema_dict() else: - config = client.get_app_config() + try: + config = await controller.get_app_config.remote() + except ray.exceptions.RayTaskError as e: + # Task failure sometimes are due to GCS + # failure. When GCS failed, we expect a longer time + # to recover. + return Response( + status=503, + text=( + "Fail to get the response from the controller. 
" + f"Potentially the GCS is down: {e}" + ), + ) return Response( text=json.dumps(config), @@ -65,13 +79,20 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def get_all_deployment_statuses(self, req: Request) -> Response: from ray.serve.schema import serve_status_to_schema, ServeStatusSchema - client = self.get_serve_client() + controller = await self.get_serve_controller() - if client is None: + if controller is None: status_json = ServeStatusSchema.get_empty_schema_dict() status_json_str = json.dumps(status_json) else: - status = client.get_serve_status() + from ray.serve._private.common import StatusOverview + from ray.serve.generated.serve_pb2 import ( + StatusOverview as StatusOverviewProto, + ) + + serve_status = await controller.get_serve_status.remote() + proto = StatusOverviewProto.FromString(serve_status) + status = StatusOverview.from_proto(proto) status_json_str = serve_status_to_schema(status).json() return Response( @@ -84,7 +105,7 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): async def delete_serve_application(self, req: Request) -> Response: from ray import serve - if self.get_serve_client() is not None: + if await self.get_serve_controller() is not None: serve.shutdown() return Response() @@ -105,21 +126,44 @@ class ServeAgent(dashboard_utils.DashboardAgentModule): return Response() - def get_serve_client(self): - """Gets the ServeControllerClient to the this cluster's Serve app. + async def get_serve_controller(self): + """Gets the ServeController to the this cluster's Serve app. return: If Serve is running on this Ray cluster, returns a client to the Serve controller. If Serve is not running, returns None. """ + async with self._controller_lock: + if self._controller is not None: + try: + await self._controller.check_alive.remote() + return self._controller + except ray.exceptions.RayActorError: + logger.info("Controller is dead") + self._controller = None - from ray.serve.context import get_global_client - from ray.serve.exceptions import RayServeException + # Try to connect to serve even when we detect the actor is dead + # because the user might have started a new + # serve cluter. + from ray.serve._private.constants import ( + SERVE_CONTROLLER_NAME, + SERVE_NAMESPACE, + ) - try: - return get_global_client(_health_check_controller=True) - except RayServeException: - logger.debug("There's no Serve app running on this Ray cluster.") - return None + try: + # get_actor is a sync call but it'll timeout after + # ray.dashboard.consts.GCS_RPC_TIMEOUT_SECONDS + self._controller = ray.get_actor( + SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE + ) + except Exception as e: + logger.debug( + "There is no " + "instance running on this Ray cluster. 
Please " + "call `serve.start(detached=True) to start " + f"one: {e}" + ) + + return self._controller async def run(self, server): pass diff --git a/dashboard/modules/serve/tests/test_serve_agent_fault_tolerane.py b/dashboard/modules/serve/tests/test_serve_agent_fault_tolerane.py new file mode 100644 index 000000000..b6a34c29d --- /dev/null +++ b/dashboard/modules/serve/tests/test_serve_agent_fault_tolerane.py @@ -0,0 +1,68 @@ +import requests +import pytest +import ray +import sys +from ray import serve +from ray.tests.conftest import * # noqa: F401,F403 +from ray._private.test_utils import generate_system_config_map + +DEPLOYMENTS_URL = "http://localhost:52365/api/serve/deployments/" +STATUS_URL = "http://localhost:52365/api/serve/deployments/status" + + +@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.") +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + { + **generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=3600, + gcs_server_request_timeout_seconds=3, + ), + } + ], + indirect=True, +) +def test_deployments_get_tolerane(monkeypatch, ray_start_regular_with_external_redis): + # test serve agent's availability when gcs is down + monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3") + serve.start(detached=True) + + get_response = requests.get(DEPLOYMENTS_URL, timeout=15) + assert get_response.status_code == 200 + ray._private.worker._global_node.kill_gcs_server() + + get_response = requests.get(DEPLOYMENTS_URL, timeout=30) + assert get_response.status_code == 503 + + +@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.") +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + { + **generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=3600, + gcs_server_request_timeout_seconds=1, + ), + } + ], + indirect=True, +) +def test_status_url_get_tolerane(monkeypatch, ray_start_regular_with_external_redis): + # test serve agent's availability when gcs is down + monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3") + serve.start(detached=True) + get_response = requests.get(STATUS_URL, timeout=15) + assert get_response.status_code == 200 + + ray._private.worker._global_node.kill_gcs_server() + + get_response = requests.get(STATUS_URL, timeout=30) + assert get_response.status_code == 200 + + +if __name__ == "__main__": + sys.exit(pytest.main(["-vs", "--forked", __file__])) diff --git a/dashboard/modules/snapshot/snapshot_head.py b/dashboard/modules/snapshot/snapshot_head.py index 61659b4c8..3af83908a 100644 --- a/dashboard/modules/snapshot/snapshot_head.py +++ b/dashboard/modules/snapshot/snapshot_head.py @@ -12,18 +12,14 @@ import aiohttp.web from pydantic import BaseModel, Extra, Field, validator import ray -from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK +from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK, GCS_RPC_TIMEOUT_SECONDS import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils from ray._private import ray_constants from ray._private.storage import _load_class from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JobInfoStorageClient -from ray.experimental.internal_kv import ( - _internal_kv_get, - _internal_kv_initialized, - _internal_kv_list, -) + from ray.job_submission import JobInfo from ray.runtime_env import RuntimeEnv @@ -87,7 +83,7 @@ 
class APIHead(dashboard_utils.DashboardHeadModule): self._gcs_job_info_stub = None self._gcs_actor_info_stub = None self._dashboard_head = dashboard_head - assert _internal_kv_initialized() + self._gcs_aio_client = dashboard_head.gcs_aio_client self._job_info_client = None # For offloading CPU intensive work. self._thread_pool = concurrent.futures.ThreadPoolExecutor( @@ -382,47 +378,47 @@ class APIHead(dashboard_utils.DashboardHeadModule): # Serve wraps Ray's internal KV store and specially formats the keys. # These are the keys we are interested in: # SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY - # TODO: Convert to async GRPC, if CPU usage is not a concern. - def get_deployments(): - serve_keys = _internal_kv_list( - SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE - ) - serve_snapshot_keys = filter( - lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys - ) - - deployments_per_controller: List[Dict[str, Any]] = [] - for key in serve_snapshot_keys: - val_bytes = _internal_kv_get( - key, namespace=ray_constants.KV_NAMESPACE_SERVE - ) or "{}".encode("utf-8") - deployments_per_controller.append(json.loads(val_bytes.decode("utf-8"))) - # Merge the deployments dicts of all controllers. - deployments: Dict[str, Any] = { - k: v for d in deployments_per_controller for k, v in d.items() - } - # Replace the keys (deployment names) with their hashes to prevent - # collisions caused by the automatic conversion to camelcase by the - # dashboard agent. - return { - hashlib.sha1(name.encode()).hexdigest(): info - for name, info in deployments.items() - } - - return await asyncio.get_event_loop().run_in_executor( - executor=self._thread_pool, func=get_deployments + serve_keys = await self._gcs_aio_client.internal_kv_keys( + SERVE_CONTROLLER_NAME.encode(), + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) + tasks = [ + self._gcs_aio_client.internal_kv_get( + key, + namespace=ray_constants.KV_NAMESPACE_SERVE, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + for key in serve_keys + if SERVE_SNAPSHOT_KEY in key.decode() + ] + + serve_snapshot_vals = await asyncio.gather(*tasks) + + deployments_per_controller: List[Dict[str, Any]] = [ + json.loads(val.decode()) for val in serve_snapshot_vals + ] + + # Merge the deployments dicts of all controllers. + deployments: Dict[str, Any] = { + k: v for d in deployments_per_controller for k, v in d.items() + } + # Replace the keys (deployment names) with their hashes to prevent + # collisions caused by the automatic conversion to camelcase by the + # dashboard agent. + return { + hashlib.sha1(name.encode()).hexdigest(): info + for name, info in deployments.items() + } + async def get_session_name(self): - # TODO(yic): Convert to async GRPC. 
- def get_session(): - return ray.experimental.internal_kv._internal_kv_get( - "session_name", namespace=ray_constants.KV_NAMESPACE_SESSION - ).decode() - - return await asyncio.get_event_loop().run_in_executor( - executor=self._thread_pool, func=get_session + session_name = await self._gcs_aio_client.internal_kv_get( + b"session_name", + namespace=ray_constants.KV_NAMESPACE_SESSION, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) + return session_name.decode() async def run(self, server): self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub( diff --git a/dashboard/optional_utils.py b/dashboard/optional_utils.py index 3865c0073..7a3a8d1e0 100644 --- a/dashboard/optional_utils.py +++ b/dashboard/optional_utils.py @@ -261,6 +261,10 @@ def init_ray_and_catch_exceptions() -> Callable: try: address = self.get_gcs_address() logger.info(f"Connecting to ray with address={address}") + # Set the gcs rpc timeout to shorter + os.environ["RAY_gcs_server_request_timeout_seconds"] = str( + dashboard_consts.GCS_RPC_TIMEOUT_SECONDS + ) # Init ray without logging to driver # to avoid infinite logging issue. ray.init( diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index df527dd3c..8538e2e33 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2447,6 +2447,10 @@ def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHan have been created with Actor.options(name="name").remote(). This works for both detached & non-detached actors. + This method is a sync call and it'll timeout after 60s. This can be modified + by setting OS env RAY_gcs_server_request_timeout_seconds before starting + the cluster. + Args: name: The name of the actor. namespace: The namespace of the actor, or None to specify the current diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index c3773652f..3343f85de 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -96,6 +96,9 @@ HANDLE_METRIC_PUSH_INTERVAL_S = 10 # Timeout for GCS internal KV service RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None +# Timeout for GCS RPC request +RAY_GCS_RPC_TIMEOUT_S = 3.0 + # Env var to control legacy sync deployment handle behavior in DAG. SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY = "SERVE_DEPLOYMENT_HANDLE_IS_SYNC" diff --git a/python/ray/serve/_private/http_state.py b/python/ray/serve/_private/http_state.py index 5dc338a38..e293162e5 100644 --- a/python/ray/serve/_private/http_state.py +++ b/python/ray/serve/_private/http_state.py @@ -7,6 +7,7 @@ import ray from ray.actor import ActorHandle from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from ray._private.gcs_utils import GcsClient from ray.serve.config import HTTPOptions, DeploymentMode from ray.serve._private.constants import ( ASYNC_CONCURRENCY, @@ -37,6 +38,7 @@ class HTTPState: detached: bool, config: HTTPOptions, head_node_id: str, + gcs_client: GcsClient, # Used by unit testing _start_proxies_on_init: bool = True, ): @@ -49,6 +51,9 @@ class HTTPState: self._proxy_actors: Dict[NodeId, ActorHandle] = dict() self._proxy_actor_names: Dict[NodeId, str] = dict() self._head_node_id: str = head_node_id + + self._gcs_client = gcs_client + assert isinstance(head_node_id, str) # Will populate self.proxy_actors with existing actors. 
@@ -75,7 +80,7 @@ class HTTPState: def _get_target_nodes(self) -> List[Tuple[str, str]]: """Return the list of (node_id, ip_address) to deploy HTTP servers on.""" location = self._config.location - target_nodes = get_all_node_ids() + target_nodes = get_all_node_ids(self._gcs_client) if location == DeploymentMode.NoServer: return [] @@ -112,6 +117,7 @@ class HTTPState: def _start_proxies_if_needed(self) -> None: """Start a proxy on every node if it doesn't already exist.""" + for node_id, node_ip_address in self._get_target_nodes(): if node_id in self._proxy_actors: continue @@ -151,7 +157,7 @@ class HTTPState: def _stop_proxies_if_needed(self) -> bool: """Removes proxy actors from any nodes that no longer exist.""" - all_node_ids = {node_id for node_id, _ in get_all_node_ids()} + all_node_ids = {node_id for node_id, _ in get_all_node_ids(self._gcs_client)} to_stop = [] for node_id in self._proxy_actors: if node_id not in all_node_ids: diff --git a/python/ray/serve/_private/storage/kv_store.py b/python/ray/serve/_private/storage/kv_store.py index 55994845d..8ac175ad2 100644 --- a/python/ray/serve/_private/storage/kv_store.py +++ b/python/ray/serve/_private/storage/kv_store.py @@ -28,12 +28,15 @@ class RayInternalKVStore(KVStoreBase): def __init__( self, - namespace: str = None, + namespace: Optional[str] = None, + gcs_client: Optional[GcsClient] = None, ): if namespace is not None and not isinstance(namespace, str): raise TypeError("namespace must a string, got: {}.".format(type(namespace))) - - self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) + if gcs_client is not None: + self.gcs_client = gcs_client + else: + self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) self.timeout = RAY_SERVE_KV_TIMEOUT_S self.namespace = namespace or "" diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 8082ac86f..0b51c9ad6 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -21,7 +21,7 @@ import ray import ray.util.serialization_addons from ray.actor import ActorHandle from ray.exceptions import RayTaskError -from ray.serve._private.constants import HTTP_PROXY_TIMEOUT +from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request from ray.util.serialization import StandaloneSerializationContext @@ -145,19 +145,21 @@ def format_actor_name(actor_name, controller_name=None, *modifiers): return name -def get_all_node_ids() -> List[Tuple[str, str]]: +def get_all_node_ids(gcs_client) -> List[Tuple[str, str]]: """Get IDs for all live nodes in the cluster. Returns a list of (node_id: str, ip_address: str). The node_id can be passed into the Ray SchedulingPolicy API. """ - node_ids = [] - # Sort on NodeID to ensure the ordering is deterministic across the cluster. - for node in sorted(ray.nodes(), key=lambda entry: entry["NodeID"]): - # print(node) - if node["Alive"]: - node_ids.append((node["NodeID"], node["NodeName"])) + nodes = gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S) + node_ids = [ + (ray.NodeID.from_binary(node.node_id).hex(), node.node_name) + for node in nodes.node_info_list + if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE + ] + # Sort on NodeID to ensure the ordering is deterministic across the cluster. 
+ node_ids.sort() return node_ids diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 96745544e..ad698fad3 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -13,6 +13,7 @@ from ray._private.utils import import_attr from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray.actor import ActorHandle from ray.exceptions import RayTaskError +from ray._private.gcs_utils import GcsClient from ray.serve._private.autoscaling_policy import BasicAutoscalingPolicy from ray.serve._private.common import ( ApplicationStatus, @@ -96,9 +97,10 @@ class ServeController: # Used to read/write checkpoints. self.ray_worker_namespace = ray.get_runtime_context().namespace self.controller_name = controller_name + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) kv_store_namespace = f"{self.controller_name}-{self.ray_worker_namespace}" - self.kv_store = RayInternalKVStore(kv_store_namespace) - self.snapshot_store = RayInternalKVStore(namespace=kv_store_namespace) + self.kv_store = RayInternalKVStore(kv_store_namespace, gcs_client) + self.snapshot_store = RayInternalKVStore(kv_store_namespace, gcs_client) # Dictionary of deployment_name -> proxy_name -> queue length. self.deployment_stats = defaultdict(lambda: defaultdict(dict)) @@ -114,6 +116,7 @@ class ServeController: detached, http_config, head_node_id, + gcs_client, ) self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host) diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py index f72dcc312..957a79f95 100644 --- a/python/ray/serve/tests/test_http_state.py +++ b/python/ray/serve/tests/test_http_state.py @@ -15,6 +15,7 @@ def test_node_selection(): detached=True, config=http_options, head_node_id=head_node_id, + gcs_client=None, _start_proxies_on_init=False, ) diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 58f660fd3..ff7fef025 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -28,6 +28,7 @@ from ray.serve._private.constants import ( SERVE_PROXY_NAME, SERVE_ROOT_URL_ENV_KEY, ) +from ray._private.gcs_utils import GcsClient from ray.serve.context import get_global_client from ray.serve.exceptions import RayServeException from ray.serve.generated.serve_pb2 import ActorNameList @@ -87,6 +88,7 @@ def lower_slow_startup_threshold_and_reset(): def test_shutdown(ray_shutdown): ray.init(num_cpus=16) serve.start(http_options=dict(port=8003)) + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) @serve.deployment def f(): @@ -100,7 +102,7 @@ def test_shutdown(ray_shutdown): format_actor_name( SERVE_PROXY_NAME, serve.context._global_client._controller_name, - get_all_node_ids()[0][0], + get_all_node_ids(gcs_client)[0][0], ), ] @@ -239,10 +241,11 @@ def test_multiple_routers(ray_cluster): node_ids = ray._private.state.node_ids() assert len(node_ids) == 2 serve.start(http_options=dict(port=8005, location="EveryNode")) + gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address) def get_proxy_names(): proxy_names = [] - for node_id, _ in get_all_node_ids(): + for node_id, _ in get_all_node_ids(gcs_client): proxy_names.append( format_actor_name( SERVE_PROXY_NAME, diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index ce03c1c65..38d9b0916 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ 
b/python/ray/tests/test_gcs_fault_tolerance.py @@ -1,4 +1,5 @@ import sys +import os import threading from time import sleep @@ -609,8 +610,37 @@ def test_pg_actor_workloads(ray_start_regular_with_external_redis): assert pid == ray.get(c.pid.remote()) +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + generate_system_config_map( + gcs_failover_worker_reconnect_timeout=20, + gcs_rpc_server_reconnect_timeout_s=60, + gcs_server_request_timeout_seconds=10, + ) + ], + indirect=True, +) +def test_get_actor_when_gcs_is_down(ray_start_regular_with_external_redis): + @ray.remote + def create_actor(): + @ray.remote + class A: + def pid(self): + return os.getpid() + + a = A.options(lifetime="detached", name="A").remote() + ray.get(a.pid.remote()) + + ray.get(create_actor.remote()) + + ray._private.worker._global_node.kill_gcs_server() + + with pytest.raises(ray.exceptions.GetTimeoutError): + ray.get_actor("A") + + if __name__ == "__main__": - import os import pytest diff --git a/python/ray/tests/test_ray_cluster_with_external_redis.py b/python/ray/tests/test_ray_cluster_with_external_redis.py deleted file mode 100644 index 8a1bb8dee..000000000 --- a/python/ray/tests/test_ray_cluster_with_external_redis.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import pytest -import sys - -import ray - - -@pytest.mark.parametrize( - "call_ray_start_with_external_redis", - [ - "6379", - "6379,6380", - "6379,6380,6381", - ], - indirect=True, -) -def test_using_hostnames(call_ray_start_with_external_redis): - ray.init(address="127.0.0.1:6379", _redis_password="123") - - @ray.remote - def f(): - return 1 - - assert ray.get(f.remote()) == 1 - - @ray.remote - class Counter: - def __init__(self): - self.count = 0 - - def inc_and_get(self): - self.count += 1 - return self.count - - counter = Counter.remote() - assert ray.get(counter.inc_and_get.remote()) == 1 - - -if __name__ == "__main__": - import pytest - - # Make subprocess happy in bazel. - os.environ["LC_ALL"] = "en_US.UTF-8" - os.environ["LANG"] = "en_US.UTF-8" - if os.environ.get("PARALLEL_CI"): - sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) - else: - sys.exit(pytest.main(["-sv", __file__]))
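Configuration sketch (illustrative, not part of the diff): the two settings named in the commit message are plain environment variables. In KubeRay they would go into the container env; exporting them in-process before starting Ray is shown here only to illustrate the values, and assumes they are visible to the Ray processes being launched.

    import os

    # Bound GCS RPCs and Serve's internal KV operations, per the commit message.
    os.environ["RAY_gcs_server_request_timeout_seconds"] = "5"
    os.environ["RAY_SERVE_KV_TIMEOUT_S"] = "5"

    import ray
    from ray import serve

    ray.init()
    serve.start(detached=True)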
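A rough sketch of how the new behavior can be checked by hand, mirroring the added test_serve_agent_fault_tolerane.py: after the GCS dies, the deployments endpoint should answer quickly with 503 instead of hanging, while the status endpoint keeps returning 200. The URLs and default agent port (52365) come from the test; the requests package is assumed to be available.

    import requests

    DEPLOYMENTS_URL = "http://localhost:52365/api/serve/deployments/"
    STATUS_URL = "http://localhost:52365/api/serve/deployments/status"

    def probe(url: str) -> int:
        # A client-side timeout bounds the call even if the agent misbehaves.
        return requests.get(url, timeout=30).status_code

    # Expected with the GCS down (per the new tests):
    #   probe(DEPLOYMENTS_URL) -> 503, reported quickly instead of hanging
    #   probe(STATUS_URL)      -> 200, still served by the agent
    print(probe(DEPLOYMENTS_URL), probe(STATUS_URL))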
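The common pattern behind the dashboard changes is that every internal-KV read goes through the async GCS client with an explicit timeout, so a dead GCS surfaces as a fast failure instead of blocking the agent's event loop. A minimal sketch, where gcs_aio_client stands in for the GcsAioClient the dashboard agent/head already owns and the key list is illustrative:

    import asyncio

    GCS_RPC_TIMEOUT_SECONDS = 3  # mirrors the new value in dashboard/consts.py

    async def read_keys(gcs_aio_client, keys):
        # Fetch several keys concurrently; each RPC is individually bounded,
        # so the worst case is a timeout error rather than an indefinite hang.
        return await asyncio.gather(
            *[
                gcs_aio_client.internal_kv_get(
                    key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
                )
                for key in keys
            ]
        )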
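The agent-side controller lookup, in sketch form: rather than building a ServeControllerClient (which retries with max retries = -1), the agent asks the GCS for the named controller actor via ray.get_actor, which now times out according to RAY_gcs_server_request_timeout_seconds as noted in the worker.py docstring change. The broad except mirrors the new serve_agent.py code, since both "Serve not started" and "GCS unreachable" should simply yield None.

    import ray
    from ray.serve._private.constants import SERVE_CONTROLLER_NAME, SERVE_NAMESPACE

    def try_get_controller():
        try:
            # Sync call, bounded by RAY_gcs_server_request_timeout_seconds.
            return ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE)
        except Exception:
            # Either no Serve instance is running or the GCS is unreachable.
            return None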
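Because get_all_node_ids now takes a GcsClient (so it uses a bounded GCS RPC instead of ray.nodes), callers must construct and pass one, as the updated test_standalone.py does. A usage sketch under the assumption that a driver is already connected to the cluster; note these are private Ray/Serve modules, shown only to illustrate the new call signature.

    import ray
    from ray._private.gcs_utils import GcsClient
    from ray.serve._private.utils import get_all_node_ids

    ray.init(address="auto")
    gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)

    # Each entry is (node_id_hex, node_name); only ALIVE nodes are returned,
    # and the underlying GCS call is bounded by RAY_GCS_RPC_TIMEOUT_S.
    for node_id, node_name in get_all_node_ids(gcs_client):
        print(node_id, node_name)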