mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[serve] Make serve agent not blocking when GCS is down. (#27526)
This PR fixes several issues that block the Serve agent when the GCS is down. We need to make sure the Serve agent is always alive so that external requests can still reach it and check the status.

- The internal KV used in dashboard/agent blocks the agent; we use the async one instead.
- The Serve controller uses ray.nodes, which is a blocking call and can block forever; it now uses the GCS client with a timeout.
- The agent uses the Serve controller client, which is a blocking call with max retries = -1; this blocks until the controller is back.

To enable Serve HA, we also need to set up:

- RAY_gcs_server_request_timeout_seconds=5
- RAY_SERVE_KV_TIMEOUT_S=5

which we should set in KubeRay.
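The change is easiest to read as a pattern: every GCS access on the agent's request path becomes an awaitable call with an explicit timeout, so a dead GCS turns into a fast, handled failure instead of a hang. Below is a minimal, hypothetical sketch of that pattern, not code from this commit; the helper name, key, and error handling are illustrative, while the `internal_kv_get(key, namespace, timeout)` signature matches the async GCS client calls used throughout the diff.

```python
# Illustrative sketch only, not code from this commit.
GCS_RPC_TIMEOUT_SECONDS = 3  # mirrors the constant added to ray.dashboard.consts below


async def fetch_status(gcs_aio_client, key: str):
    """Read a status entry from the GCS internal KV without blocking the agent.

    `gcs_aio_client` is assumed to be the async GCS client the dashboard agent
    already holds; the key name is a placeholder.
    """
    try:
        # The await fails once the timeout expires, so the agent's event loop
        # keeps serving health checks even while the GCS is down.
        value = await gcs_aio_client.internal_kv_get(
            key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
        )
        return value.decode() if value is not None else None
    except Exception:
        # Surface GCS unavailability as "no data" so the HTTP handler can
        # return a 503 instead of hanging (the real handlers do this explicitly).
        return None
```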
This commit is contained in:
parent 87ff765647
commit dac7bf17d9
22 changed files with 282 additions and 146 deletions
ci/ci.sh

@@ -135,6 +135,7 @@ test_core() {
prepare_docker() {
rm "${WORKSPACE_DIR}"/python/dist/* ||:
pushd "${WORKSPACE_DIR}/python"
pip install -e . --verbose
python setup.py bdist_wheel
tmp_dir="/tmp/prepare_docker_$RANDOM"
mkdir -p $tmp_dir

@@ -27,6 +27,7 @@ GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
)
GCS_RPC_TIMEOUT_SECONDS = 3
# aiohttp_cache
AIOHTTP_CACHE_TTL_SECONDS = 2
AIOHTTP_CACHE_MAX_SIZE = 128

@@ -7,7 +7,6 @@ import ray._private.ray_constants as ray_constants
import ray._private.utils as utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray.core.generated import event_pb2, event_pb2_grpc
from ray.dashboard.modules.event import event_consts
from ray.dashboard.modules.event.event_utils import monitor_events

@@ -24,6 +23,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
self._monitor: Union[asyncio.Task, None] = None
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
self._gcs_aio_client = dashboard_agent.gcs_aio_client

logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)

async def _connect_to_dashboard(self):

@@ -35,11 +36,12 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
"""
while True:
try:
# TODO: Use async version if performance is an issue
dashboard_rpc_address = internal_kv._internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
dashboard_rpc_address = await self._gcs_aio_client.internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
timeout=1,
)
dashboard_rpc_address = dashboard_rpc_address.decode()
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = ray_constants.GLOBAL_GRPC_OPTIONS

@@ -1,7 +1,7 @@
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as optional_utils
from ray.dashboard.modules.healthz.utils import HealthChecker
from aiohttp.web import Request, Response, HTTPServiceUnavailable
from aiohttp.web import Request, Response
import grpc

routes = optional_utils.ClassMethodRouteTable

@@ -26,7 +26,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule):
try:
alive = await self._health_checker.check_local_raylet_liveness()
if alive is False:
return HTTPServiceUnavailable(reason="Local Raylet failed")
return Response(status=503, text="Local Raylet failed")
except grpc.RpcError as e:
# We only consider the error other than GCS unreachable as raylet failure
# to avoid false positive.

@@ -38,7 +38,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule):
grpc.StatusCode.UNKNOWN,
grpc.StatusCode.DEADLINE_EXCEEDED,
):
return HTTPServiceUnavailable(reason=e.message())
return Response(status=503, text=f"Health check failed due to: {e}")

return Response(
text="success",

@@ -14,10 +14,10 @@ class HealthChecker:
return False

liveness = await self._gcs_aio_client.check_alive(
[self._local_node_address.encode()], 1
[self._local_node_address.encode()], 0.1
)
return liveness[0]

async def check_gcs_liveness(self) -> bool:
await self._gcs_aio_client.check_alive([], 1)
await self._gcs_aio_client.check_alive([], 0.1)
return True

@@ -80,6 +80,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
# The time it takes until the head node is registered. None means
# head node hasn't been registered.
self._head_node_registration_time_s = None
self._gcs_aio_client = dashboard_head.gcs_aio_client

async def _update_stubs(self, change):
if change.old:

@@ -14,9 +14,9 @@ import psutil
import ray
import ray._private.services
import ray._private.utils
from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private.metrics_agent import Gauge, MetricsAgent, Record
from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS
from ray.core.generated import reporter_pb2, reporter_pb2_grpc

@@ -227,7 +227,7 @@ class ReporterAgent(
logical_cpu_count = psutil.cpu_count()
physical_cpu_count = psutil.cpu_count(logical=False)
self._cpu_counts = (logical_cpu_count, physical_cpu_count)

self._gcs_aio_client = dashboard_agent.gcs_aio_client
self._ip = dashboard_agent.ip
self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
self._hostname = socket.gethostname()

@@ -787,8 +787,10 @@ class ReporterAgent(
"""Get any changes to the log files and push updates to kv."""
while True:
try:
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS
formatted_status_string = await self._gcs_aio_client.internal_kv_get(
DEBUG_AUTOSCALING_STATUS.encode(),
None,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)

stats = self._get_all_stats()

@@ -1,7 +1,7 @@
import json
import logging
import os

import asyncio
import aiohttp.web
import yaml

@@ -9,8 +9,8 @@ import ray
import ray._private.services
import ray._private.utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
from ray._private.ray_constants import (

@@ -40,6 +40,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
gcs_address = dashboard_head.gcs_address
temp_dir = dashboard_head.temp_dir
self.service_discovery = PrometheusServiceDiscoveryWriter(gcs_address, temp_dir)
self._gcs_aio_client = dashboard_head.gcs_aio_client

async def _update_stubs(self, change):
if change.old:

@@ -127,15 +128,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
autoscaler writes them there.
"""

assert ray.experimental.internal_kv._internal_kv_initialized()
legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
(legacy_status, formatted_status_string, error) = await asyncio.gather(
*[
self._gcs_aio_client.internal_kv_get(
key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
)
for key in [
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_ERROR,
]
]
)

formatted_status = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
return dashboard_optional_utils.rest_response(
success=True,
message="Got cluster status.",

@@ -2,9 +2,9 @@ import json
import logging

from aiohttp.web import Request, Response

import dataclasses
import ray
import asyncio
import aiohttp.web
import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils

@@ -25,6 +25,8 @@ routes = optional_utils.ClassMethodRouteTable
class ServeAgent(dashboard_utils.DashboardAgentModule):
def __init__(self, dashboard_agent):
super().__init__(dashboard_agent)
self._controller = None
self._controller_lock = asyncio.Lock()

# TODO: It's better to use `/api/version`.
# It requires a refactor of ClassMethodRouteTable to differentiate the server.

@@ -48,12 +50,24 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def get_all_deployments(self, req: Request) -> Response:
from ray.serve.schema import ServeApplicationSchema

client = self.get_serve_client()
controller = await self.get_serve_controller()

if client is None:
if controller is None:
config = ServeApplicationSchema.get_empty_schema_dict()
else:
config = client.get_app_config()
try:
config = await controller.get_app_config.remote()
except ray.exceptions.RayTaskError as e:
# Task failure sometimes are due to GCS
# failure. When GCS failed, we expect a longer time
# to recover.
return Response(
status=503,
text=(
"Fail to get the response from the controller. "
f"Potentially the GCS is down: {e}"
),
)

return Response(
text=json.dumps(config),

@@ -65,13 +79,20 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def get_all_deployment_statuses(self, req: Request) -> Response:
from ray.serve.schema import serve_status_to_schema, ServeStatusSchema

client = self.get_serve_client()
controller = await self.get_serve_controller()

if client is None:
if controller is None:
status_json = ServeStatusSchema.get_empty_schema_dict()
status_json_str = json.dumps(status_json)
else:
status = client.get_serve_status()
from ray.serve._private.common import StatusOverview
from ray.serve.generated.serve_pb2 import (
StatusOverview as StatusOverviewProto,
)

serve_status = await controller.get_serve_status.remote()
proto = StatusOverviewProto.FromString(serve_status)
status = StatusOverview.from_proto(proto)
status_json_str = serve_status_to_schema(status).json()

return Response(

@@ -84,7 +105,7 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def delete_serve_application(self, req: Request) -> Response:
from ray import serve

if self.get_serve_client() is not None:
if await self.get_serve_controller() is not None:
serve.shutdown()

return Response()

@@ -144,21 +165,44 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):

return Response()

def get_serve_client(self):
"""Gets the ServeControllerClient to the this cluster's Serve app.
async def get_serve_controller(self):
"""Gets the ServeController to the this cluster's Serve app.

return: If Serve is running on this Ray cluster, returns a client to
the Serve controller. If Serve is not running, returns None.
"""
async with self._controller_lock:
if self._controller is not None:
try:
await self._controller.check_alive.remote()
return self._controller
except ray.exceptions.RayActorError:
logger.info("Controller is dead")
self._controller = None

from ray.serve.context import get_global_client
from ray.serve.exceptions import RayServeException
# Try to connect to serve even when we detect the actor is dead
# because the user might have started a new
# serve cluter.
from ray.serve._private.constants import (
SERVE_CONTROLLER_NAME,
SERVE_NAMESPACE,
)

try:
return get_global_client(_health_check_controller=True)
except RayServeException:
logger.debug("There's no Serve app running on this Ray cluster.")
return None
try:
# get_actor is a sync call but it'll timeout after
# ray.dashboard.consts.GCS_RPC_TIMEOUT_SECONDS
self._controller = ray.get_actor(
SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
)
except Exception as e:
logger.debug(
"There is no "
"instance running on this Ray cluster. Please "
"call `serve.start(detached=True) to start "
f"one: {e}"
)

return self._controller

async def run(self, server):
pass

@@ -0,0 +1,68 @@
import requests
import pytest
import ray
import sys
from ray import serve
from ray.tests.conftest import *  # noqa: F401,F403
from ray._private.test_utils import generate_system_config_map

DEPLOYMENTS_URL = "http://localhost:52365/api/serve/deployments/"
STATUS_URL = "http://localhost:52365/api/serve/deployments/status"


@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.")
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
{
**generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=3600,
gcs_server_request_timeout_seconds=3,
),
}
],
indirect=True,
)
def test_deployments_get_tolerane(monkeypatch, ray_start_regular_with_external_redis):
# test serve agent's availability when gcs is down
monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3")
serve.start(detached=True)

get_response = requests.get(DEPLOYMENTS_URL, timeout=15)
assert get_response.status_code == 200
ray._private.worker._global_node.kill_gcs_server()

get_response = requests.get(DEPLOYMENTS_URL, timeout=30)
assert get_response.status_code == 503


@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.")
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
{
**generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=3600,
gcs_server_request_timeout_seconds=1,
),
}
],
indirect=True,
)
def test_status_url_get_tolerane(monkeypatch, ray_start_regular_with_external_redis):
# test serve agent's availability when gcs is down
monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3")
serve.start(detached=True)
get_response = requests.get(STATUS_URL, timeout=15)
assert get_response.status_code == 200

ray._private.worker._global_node.kill_gcs_server()

get_response = requests.get(STATUS_URL, timeout=30)
assert get_response.status_code == 200


if __name__ == "__main__":
sys.exit(pytest.main(["-vs", "--forked", __file__]))

@@ -12,18 +12,14 @@ import aiohttp.web
from pydantic import BaseModel, Extra, Field, validator

import ray
from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK
from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK, GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private import ray_constants
from ray._private.storage import _load_class
from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JobInfoStorageClient
from ray.experimental.internal_kv import (
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_list,
)

from ray.job_submission import JobInfo
from ray.runtime_env import RuntimeEnv

@@ -87,7 +83,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_job_info_stub = None
self._gcs_actor_info_stub = None
self._dashboard_head = dashboard_head
assert _internal_kv_initialized()
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None
# For offloading CPU intensive work.
self._thread_pool = concurrent.futures.ThreadPoolExecutor(

@@ -382,47 +378,47 @@ class APIHead(dashboard_utils.DashboardHeadModule):
# Serve wraps Ray's internal KV store and specially formats the keys.
# These are the keys we are interested in:
# SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY
# TODO: Convert to async GRPC, if CPU usage is not a concern.
def get_deployments():
serve_keys = _internal_kv_list(
SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
)
serve_snapshot_keys = filter(
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
)

deployments_per_controller: List[Dict[str, Any]] = []
for key in serve_snapshot_keys:
val_bytes = _internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_SERVE
) or "{}".encode("utf-8")
deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
# Merge the deployments dicts of all controllers.
deployments: Dict[str, Any] = {
k: v for d in deployments_per_controller for k, v in d.items()
}
# Replace the keys (deployment names) with their hashes to prevent
# collisions caused by the automatic conversion to camelcase by the
# dashboard agent.
return {
hashlib.sha1(name.encode()).hexdigest(): info
for name, info in deployments.items()
}

return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_deployments
serve_keys = await self._gcs_aio_client.internal_kv_keys(
SERVE_CONTROLLER_NAME.encode(),
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)

tasks = [
self._gcs_aio_client.internal_kv_get(
key,
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
for key in serve_keys
if SERVE_SNAPSHOT_KEY in key.decode()
]

serve_snapshot_vals = await asyncio.gather(*tasks)

deployments_per_controller: List[Dict[str, Any]] = [
json.loads(val.decode()) for val in serve_snapshot_vals
]

# Merge the deployments dicts of all controllers.
deployments: Dict[str, Any] = {
k: v for d in deployments_per_controller for k, v in d.items()
}
# Replace the keys (deployment names) with their hashes to prevent
# collisions caused by the automatic conversion to camelcase by the
# dashboard agent.
return {
hashlib.sha1(name.encode()).hexdigest(): info
for name, info in deployments.items()
}

async def get_session_name(self):
# TODO(yic): Convert to async GRPC.
def get_session():
return ray.experimental.internal_kv._internal_kv_get(
"session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
).decode()

return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_session
session_name = await self._gcs_aio_client.internal_kv_get(
b"session_name",
namespace=ray_constants.KV_NAMESPACE_SESSION,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
return session_name.decode()

async def run(self, server):
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(

@@ -261,6 +261,10 @@ def init_ray_and_catch_exceptions() -> Callable:
try:
address = self.get_gcs_address()
logger.info(f"Connecting to ray with address={address}")
# Set the gcs rpc timeout to shorter
os.environ["RAY_gcs_server_request_timeout_seconds"] = str(
dashboard_consts.GCS_RPC_TIMEOUT_SECONDS
)
# Init ray without logging to driver
# to avoid infinite logging issue.
ray.init(

@@ -2447,6 +2447,10 @@ def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle":
have been created with Actor.options(name="name").remote(). This
works for both detached & non-detached actors.

This method is a sync call and it'll timeout after 60s. This can be modified
by setting OS env RAY_gcs_server_request_timeout_seconds before starting
the cluster.

Args:
name: The name of the actor.
namespace: The namespace of the actor, or None to specify the current

@@ -96,6 +96,9 @@ HANDLE_METRIC_PUSH_INTERVAL_S = 10
# Timeout for GCS internal KV service
RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None

# Timeout for GCS RPC request
RAY_GCS_RPC_TIMEOUT_S = 3.0

# Env var to control legacy sync deployment handle behavior in DAG.
SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY = "SERVE_DEPLOYMENT_HANDLE_IS_SYNC"

@@ -7,6 +7,7 @@ import ray
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from ray._private.gcs_utils import GcsClient
from ray.serve.config import HTTPOptions, DeploymentMode
from ray.serve._private.constants import (
ASYNC_CONCURRENCY,

@@ -37,6 +38,7 @@ class HTTPState:
detached: bool,
config: HTTPOptions,
head_node_id: str,
gcs_client: GcsClient,
# Used by unit testing
_start_proxies_on_init: bool = True,
):

@@ -49,6 +51,9 @@ class HTTPState:
self._proxy_actors: Dict[NodeId, ActorHandle] = dict()
self._proxy_actor_names: Dict[NodeId, str] = dict()
self._head_node_id: str = head_node_id

self._gcs_client = gcs_client

assert isinstance(head_node_id, str)

# Will populate self.proxy_actors with existing actors.

@@ -75,7 +80,7 @@ class HTTPState:
def _get_target_nodes(self) -> List[Tuple[str, str]]:
"""Return the list of (node_id, ip_address) to deploy HTTP servers on."""
location = self._config.location
target_nodes = get_all_node_ids()
target_nodes = get_all_node_ids(self._gcs_client)

if location == DeploymentMode.NoServer:
return []

@@ -112,6 +117,7 @@ class HTTPState:

def _start_proxies_if_needed(self) -> None:
"""Start a proxy on every node if it doesn't already exist."""

for node_id, node_ip_address in self._get_target_nodes():
if node_id in self._proxy_actors:
continue

@@ -151,7 +157,7 @@ class HTTPState:

def _stop_proxies_if_needed(self) -> bool:
"""Removes proxy actors from any nodes that no longer exist."""
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
all_node_ids = {node_id for node_id, _ in get_all_node_ids(self._gcs_client)}
to_stop = []
for node_id in self._proxy_actors:
if node_id not in all_node_ids:

@@ -28,12 +28,15 @@ class RayInternalKVStore(KVStoreBase):

def __init__(
self,
namespace: str = None,
namespace: Optional[str] = None,
gcs_client: Optional[GcsClient] = None,
):
if namespace is not None and not isinstance(namespace, str):
raise TypeError("namespace must a string, got: {}.".format(type(namespace)))

self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
if gcs_client is not None:
self.gcs_client = gcs_client
else:
self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
self.timeout = RAY_SERVE_KV_TIMEOUT_S
self.namespace = namespace or ""

@@ -21,7 +21,7 @@ import ray
import ray.util.serialization_addons
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S
from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request
from ray.util.serialization import StandaloneSerializationContext

@@ -145,19 +145,21 @@ def format_actor_name(actor_name, controller_name=None, *modifiers):
return name


def get_all_node_ids() -> List[Tuple[str, str]]:
def get_all_node_ids(gcs_client) -> List[Tuple[str, str]]:
"""Get IDs for all live nodes in the cluster.

Returns a list of (node_id: str, ip_address: str). The node_id can be
passed into the Ray SchedulingPolicy API.
"""
node_ids = []
# Sort on NodeID to ensure the ordering is deterministic across the cluster.
for node in sorted(ray.nodes(), key=lambda entry: entry["NodeID"]):
# print(node)
if node["Alive"]:
node_ids.append((node["NodeID"], node["NodeName"]))
nodes = gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S)
node_ids = [
(ray.NodeID.from_binary(node.node_id).hex(), node.node_name)
for node in nodes.node_info_list
if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE
]

# Sort on NodeID to ensure the ordering is deterministic across the cluster.
sorted(node_ids)
return node_ids

@@ -13,6 +13,7 @@ from ray._private.utils import import_attr
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError
from ray._private.gcs_utils import GcsClient
from ray.serve._private.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve._private.common import (
ApplicationStatus,

@@ -96,9 +97,10 @@ class ServeController:
# Used to read/write checkpoints.
self.ray_worker_namespace = ray.get_runtime_context().namespace
self.controller_name = controller_name
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
kv_store_namespace = f"{self.controller_name}-{self.ray_worker_namespace}"
self.kv_store = RayInternalKVStore(kv_store_namespace)
self.snapshot_store = RayInternalKVStore(namespace=kv_store_namespace)
self.kv_store = RayInternalKVStore(kv_store_namespace, gcs_client)
self.snapshot_store = RayInternalKVStore(kv_store_namespace, gcs_client)

# Dictionary of deployment_name -> proxy_name -> queue length.
self.deployment_stats = defaultdict(lambda: defaultdict(dict))

@@ -114,6 +116,7 @@ class ServeController:
detached,
http_config,
head_node_id,
gcs_client,
)
self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)

@@ -15,6 +15,7 @@ def test_node_selection():
detached=True,
config=http_options,
head_node_id=head_node_id,
gcs_client=None,
_start_proxies_on_init=False,
)

@@ -28,6 +28,7 @@ from ray.serve._private.constants import (
SERVE_PROXY_NAME,
SERVE_ROOT_URL_ENV_KEY,
)
from ray._private.gcs_utils import GcsClient
from ray.serve.context import get_global_client
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import ActorNameList

@@ -87,6 +88,7 @@ def lower_slow_startup_threshold_and_reset():
def test_shutdown(ray_shutdown):
ray.init(num_cpus=16)
serve.start(http_options=dict(port=8003))
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)

@serve.deployment
def f():

@@ -100,7 +102,7 @@ def test_shutdown(ray_shutdown):
format_actor_name(
SERVE_PROXY_NAME,
serve.context._global_client._controller_name,
get_all_node_ids()[0][0],
get_all_node_ids(gcs_client)[0][0],
),
]

@@ -239,10 +241,11 @@ def test_multiple_routers(ray_cluster):
node_ids = ray._private.state.node_ids()
assert len(node_ids) == 2
serve.start(http_options=dict(port=8005, location="EveryNode"))
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)

def get_proxy_names():
proxy_names = []
for node_id, _ in get_all_node_ids():
for node_id, _ in get_all_node_ids(gcs_client):
proxy_names.append(
format_actor_name(
SERVE_PROXY_NAME,

@@ -1,4 +1,5 @@
import sys
import os
import threading
from time import sleep

@@ -609,8 +610,37 @@ def test_pg_actor_workloads(ray_start_regular_with_external_redis):
assert pid == ray.get(c.pid.remote())


@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=60,
gcs_server_request_timeout_seconds=10,
)
],
indirect=True,
)
def test_get_actor_when_gcs_is_down(ray_start_regular_with_external_redis):
@ray.remote
def create_actor():
@ray.remote
class A:
def pid(self):
return os.getpid()

a = A.options(lifetime="detached", name="A").remote()
ray.get(a.pid.remote())

ray.get(create_actor.remote())

ray._private.worker._global_node.kill_gcs_server()

with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get_actor("A")


if __name__ == "__main__":
import os

import pytest

@@ -1,48 +0,0 @@
import os
import pytest
import sys

import ray


@pytest.mark.parametrize(
"call_ray_start_with_external_redis",
[
"6379",
"6379,6380",
"6379,6380,6381",
],
indirect=True,
)
def test_using_hostnames(call_ray_start_with_external_redis):
ray.init(address="127.0.0.1:6379", _redis_password="123")

@ray.remote
def f():
return 1

assert ray.get(f.remote()) == 1

@ray.remote
class Counter:
def __init__(self):
self.count = 0

def inc_and_get(self):
self.count += 1
return self.count

counter = Counter.remote()
assert ray.get(counter.inc_and_get.remote()) == 1


if __name__ == "__main__":
import pytest

# Make subprocess happy in bazel.
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LANG"] = "en_US.UTF-8"
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))