[serve] Make serve agent not blocking when GCS is down. (#27526) (#27674)

This PR fixes several issues that block the Serve agent when the GCS is down. We need to make sure the Serve agent is always alive, so that external requests can still reach it and check the status.

- The internal KV client used in the dashboard agent blocks the agent; use the async client instead (a short sketch of this pattern follows the list).
- The Serve controller uses `ray.nodes()`, which is a blocking call that can hang forever; change it to use the GCS client with a timeout.
- The agent uses the Serve controller client, which is a blocking call with max retries = -1, so it blocks until the controller comes back.
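As a rough illustration of the first point, here is a minimal sketch of the async KV pattern the dashboard agent modules switch to in this diff. The helper name is made up for illustration; `gcs_aio_client` is assumed to be the async GCS client the agent already holds (`dashboard_agent.gcs_aio_client`), and the constants are the ones referenced in the diff below.

```python
import ray._private.ray_constants as ray_constants
import ray.dashboard.consts as dashboard_consts


async def get_dashboard_rpc_address(gcs_aio_client):
    # Async GCS KV lookup with an explicit timeout: if the GCS is down,
    # the call fails fast instead of blocking the agent's event loop the
    # way the synchronous internal_kv helpers did.
    value = await gcs_aio_client.internal_kv_get(
        dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(),
        namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
    )
    # The None check is an assumption here; the diff decodes the value directly.
    return value.decode() if value is not None else None
```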

To enable Serve HA, we also need to set:

- RAY_gcs_server_request_timeout_seconds=5
- RAY_SERVE_KV_TIMEOUT_S=5

Both of these should be set in KubeRay (a minimal local sketch follows).
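A minimal sketch, assuming a single-node local run where the driver process launches Ray (in KubeRay these would instead be injected as container environment variables); the values are the ones suggested above:

```python
import os

# Must be set before the Ray processes start and before `ray.serve` is
# imported, since Serve reads RAY_SERVE_KV_TIMEOUT_S at import time.
os.environ["RAY_gcs_server_request_timeout_seconds"] = "5"  # GCS RPC timeout
os.environ["RAY_SERVE_KV_TIMEOUT_S"] = "5"  # Serve internal KV timeout

import ray
from ray import serve

ray.init()
serve.start(detached=True)
```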
Yi Cheng 2022-08-09 01:21:12 +00:00 committed by GitHub
parent faceb1a0e3
commit a09ae65823
22 changed files with 282 additions and 146 deletions


@ -135,6 +135,7 @@ test_core() {
prepare_docker() {
rm "${WORKSPACE_DIR}"/python/dist/* ||:
pushd "${WORKSPACE_DIR}/python"
pip install -e . --verbose
python setup.py bdist_wheel
tmp_dir="/tmp/prepare_docker_$RANDOM"
mkdir -p $tmp_dir


@ -27,6 +27,7 @@ GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
)
GCS_RPC_TIMEOUT_SECONDS = 3
# aiohttp_cache
AIOHTTP_CACHE_TTL_SECONDS = 2
AIOHTTP_CACHE_MAX_SIZE = 128


@ -7,7 +7,6 @@ import ray._private.ray_constants as ray_constants
import ray._private.utils as utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray.core.generated import event_pb2, event_pb2_grpc
from ray.dashboard.modules.event import event_consts
from ray.dashboard.modules.event.event_utils import monitor_events
@ -24,6 +23,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
self._monitor: Union[asyncio.Task, None] = None
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
self._gcs_aio_client = dashboard_agent.gcs_aio_client
logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)
async def _connect_to_dashboard(self):
@ -35,11 +36,12 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
"""
while True:
try:
# TODO: Use async version if performance is an issue
dashboard_rpc_address = internal_kv._internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
dashboard_rpc_address = await self._gcs_aio_client.internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
timeout=1,
)
dashboard_rpc_address = dashboard_rpc_address.decode()
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = ray_constants.GLOBAL_GRPC_OPTIONS


@ -1,7 +1,7 @@
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as optional_utils
from ray.dashboard.modules.healthz.utils import HealthChecker
from aiohttp.web import Request, Response, HTTPServiceUnavailable
from aiohttp.web import Request, Response
import grpc
routes = optional_utils.ClassMethodRouteTable
@ -26,7 +26,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule):
try:
alive = await self._health_checker.check_local_raylet_liveness()
if alive is False:
return HTTPServiceUnavailable(reason="Local Raylet failed")
return Response(status=503, text="Local Raylet failed")
except grpc.RpcError as e:
# Only treat errors other than GCS being unreachable as raylet failure,
# to avoid false positives.
@ -38,7 +38,7 @@ class HealthzAgent(dashboard_utils.DashboardAgentModule):
grpc.StatusCode.UNKNOWN,
grpc.StatusCode.DEADLINE_EXCEEDED,
):
return HTTPServiceUnavailable(reason=e.message())
return Response(status=503, text=f"Health check failed due to: {e}")
return Response(
text="success",


@ -14,10 +14,10 @@ class HealthChecker:
return False
liveness = await self._gcs_aio_client.check_alive(
[self._local_node_address.encode()], 1
[self._local_node_address.encode()], 0.1
)
return liveness[0]
async def check_gcs_liveness(self) -> bool:
await self._gcs_aio_client.check_alive([], 1)
await self._gcs_aio_client.check_alive([], 0.1)
return True


@ -80,6 +80,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
# The time it takes until the head node is registered. None means
# head node hasn't been registered.
self._head_node_registration_time_s = None
self._gcs_aio_client = dashboard_head.gcs_aio_client
async def _update_stubs(self, change):
if change.old:


@ -14,9 +14,9 @@ import psutil
import ray
import ray._private.services
import ray._private.utils
from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private.metrics_agent import Gauge, MetricsAgent, Record
from ray._private.ray_constants import DEBUG_AUTOSCALING_STATUS
from ray.core.generated import reporter_pb2, reporter_pb2_grpc
@ -227,7 +227,7 @@ class ReporterAgent(
logical_cpu_count = psutil.cpu_count()
physical_cpu_count = psutil.cpu_count(logical=False)
self._cpu_counts = (logical_cpu_count, physical_cpu_count)
self._gcs_aio_client = dashboard_agent.gcs_aio_client
self._ip = dashboard_agent.ip
self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
self._hostname = socket.gethostname()
@ -787,8 +787,10 @@ class ReporterAgent(
"""Get any changes to the log files and push updates to kv."""
while True:
try:
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS
formatted_status_string = await self._gcs_aio_client.internal_kv_get(
DEBUG_AUTOSCALING_STATUS.encode(),
None,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
stats = self._get_all_stats()


@ -1,7 +1,7 @@
import json
import logging
import os
import asyncio
import aiohttp.web
import yaml
@ -9,8 +9,8 @@ import ray
import ray._private.services
import ray._private.utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.utils as dashboard_utils
import ray.experimental.internal_kv as internal_kv
from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
from ray._private.ray_constants import (
@ -40,6 +40,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
gcs_address = dashboard_head.gcs_address
temp_dir = dashboard_head.temp_dir
self.service_discovery = PrometheusServiceDiscoveryWriter(gcs_address, temp_dir)
self._gcs_aio_client = dashboard_head.gcs_aio_client
async def _update_stubs(self, change):
if change.old:
@ -127,15 +128,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
autoscaler writes them there.
"""
assert ray.experimental.internal_kv._internal_kv_initialized()
legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
(legacy_status, formatted_status_string, error) = await asyncio.gather(
*[
self._gcs_aio_client.internal_kv_get(
key.encode(), namespace=None, timeout=GCS_RPC_TIMEOUT_SECONDS
)
for key in [
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_ERROR,
]
]
)
formatted_status = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
return dashboard_optional_utils.rest_response(
success=True,
message="Got cluster status.",


@ -2,9 +2,9 @@ import json
import logging
from aiohttp.web import Request, Response
import dataclasses
import ray
import asyncio
import aiohttp.web
import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
@ -25,6 +25,8 @@ routes = optional_utils.ClassMethodRouteTable
class ServeAgent(dashboard_utils.DashboardAgentModule):
def __init__(self, dashboard_agent):
super().__init__(dashboard_agent)
self._controller = None
self._controller_lock = asyncio.Lock()
# TODO: It's better to use `/api/version`.
# It requires a refactor of ClassMethodRouteTable to differentiate the server.
@ -48,12 +50,24 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def get_all_deployments(self, req: Request) -> Response:
from ray.serve.schema import ServeApplicationSchema
client = self.get_serve_client()
controller = await self.get_serve_controller()
if client is None:
if controller is None:
config = ServeApplicationSchema.get_empty_schema_dict()
else:
config = client.get_app_config()
try:
config = await controller.get_app_config.remote()
except ray.exceptions.RayTaskError as e:
# Task failures are sometimes due to GCS failure.
# When the GCS has failed, we expect recovery to take
# a longer time.
return Response(
status=503,
text=(
"Failed to get a response from the controller. "
f"The GCS is potentially down: {e}"
),
)
return Response(
text=json.dumps(config),
@ -65,13 +79,20 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def get_all_deployment_statuses(self, req: Request) -> Response:
from ray.serve.schema import serve_status_to_schema, ServeStatusSchema
client = self.get_serve_client()
controller = await self.get_serve_controller()
if client is None:
if controller is None:
status_json = ServeStatusSchema.get_empty_schema_dict()
status_json_str = json.dumps(status_json)
else:
status = client.get_serve_status()
from ray.serve._private.common import StatusOverview
from ray.serve.generated.serve_pb2 import (
StatusOverview as StatusOverviewProto,
)
serve_status = await controller.get_serve_status.remote()
proto = StatusOverviewProto.FromString(serve_status)
status = StatusOverview.from_proto(proto)
status_json_str = serve_status_to_schema(status).json()
return Response(
@ -84,7 +105,7 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
async def delete_serve_application(self, req: Request) -> Response:
from ray import serve
if self.get_serve_client() is not None:
if await self.get_serve_controller() is not None:
serve.shutdown()
return Response()
@ -105,21 +126,44 @@ class ServeAgent(dashboard_utils.DashboardAgentModule):
return Response()
def get_serve_client(self):
"""Gets the ServeControllerClient to the this cluster's Serve app.
async def get_serve_controller(self):
"""Gets the ServeController for this cluster's Serve app.
return: If Serve is running on this Ray cluster, returns a handle to
the Serve controller actor. If Serve is not running, returns None.
"""
async with self._controller_lock:
if self._controller is not None:
try:
await self._controller.check_alive.remote()
return self._controller
except ray.exceptions.RayActorError:
logger.info("Controller is dead")
self._controller = None
from ray.serve.context import get_global_client
from ray.serve.exceptions import RayServeException
# Try to connect to Serve even when we detect the actor is dead,
# because the user might have started a new
# Serve cluster.
from ray.serve._private.constants import (
SERVE_CONTROLLER_NAME,
SERVE_NAMESPACE,
)
try:
return get_global_client(_health_check_controller=True)
except RayServeException:
logger.debug("There's no Serve app running on this Ray cluster.")
return None
# get_actor is a sync call but it'll timeout after
# ray.dashboard.consts.GCS_RPC_TIMEOUT_SECONDS
self._controller = ray.get_actor(
SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
)
except Exception as e:
logger.debug(
"There is no "
"instance running on this Ray cluster. Please "
"call `serve.start(detached=True)` to start "
f"one: {e}"
)
return self._controller
async def run(self, server):
pass


@ -0,0 +1,68 @@
import requests
import pytest
import ray
import sys
from ray import serve
from ray.tests.conftest import * # noqa: F401,F403
from ray._private.test_utils import generate_system_config_map
DEPLOYMENTS_URL = "http://localhost:52365/api/serve/deployments/"
STATUS_URL = "http://localhost:52365/api/serve/deployments/status"
@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.")
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
{
**generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=3600,
gcs_server_request_timeout_seconds=3,
),
}
],
indirect=True,
)
def test_deployments_get_tolerane(monkeypatch, ray_start_regular_with_external_redis):
# test serve agent's availability when gcs is down
monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3")
serve.start(detached=True)
get_response = requests.get(DEPLOYMENTS_URL, timeout=15)
assert get_response.status_code == 200
ray._private.worker._global_node.kill_gcs_server()
get_response = requests.get(DEPLOYMENTS_URL, timeout=30)
assert get_response.status_code == 503
@pytest.mark.skipif(sys.platform == "darwin", reason="Flaky on OSX.")
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
{
**generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=3600,
gcs_server_request_timeout_seconds=1,
),
}
],
indirect=True,
)
def test_status_url_get_tolerane(monkeypatch, ray_start_regular_with_external_redis):
# test serve agent's availability when gcs is down
monkeypatch.setenv("RAY_SERVE_KV_TIMEOUT_S", "3")
serve.start(detached=True)
get_response = requests.get(STATUS_URL, timeout=15)
assert get_response.status_code == 200
ray._private.worker._global_node.kill_gcs_server()
get_response = requests.get(STATUS_URL, timeout=30)
assert get_response.status_code == 200
if __name__ == "__main__":
sys.exit(pytest.main(["-vs", "--forked", __file__]))


@ -12,18 +12,14 @@ import aiohttp.web
from pydantic import BaseModel, Extra, Field, validator
import ray
from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK
from ray.dashboard.consts import RAY_CLUSTER_ACTIVITY_HOOK, GCS_RPC_TIMEOUT_SECONDS
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private import ray_constants
from ray._private.storage import _load_class
from ray.core.generated import gcs_pb2, gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JobInfoStorageClient
from ray.experimental.internal_kv import (
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_list,
)
from ray.job_submission import JobInfo
from ray.runtime_env import RuntimeEnv
@ -87,7 +83,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_job_info_stub = None
self._gcs_actor_info_stub = None
self._dashboard_head = dashboard_head
assert _internal_kv_initialized()
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None
# For offloading CPU intensive work.
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
@ -382,21 +378,28 @@ class APIHead(dashboard_utils.DashboardHeadModule):
# Serve wraps Ray's internal KV store and specially formats the keys.
# These are the keys we are interested in:
# SERVE_CONTROLLER_NAME(+ optional random letters):SERVE_SNAPSHOT_KEY
# TODO: Convert to async GRPC, if CPU usage is not a concern.
def get_deployments():
serve_keys = _internal_kv_list(
SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
)
serve_snapshot_keys = filter(
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
serve_keys = await self._gcs_aio_client.internal_kv_keys(
SERVE_CONTROLLER_NAME.encode(),
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
deployments_per_controller: List[Dict[str, Any]] = []
for key in serve_snapshot_keys:
val_bytes = _internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_SERVE
) or "{}".encode("utf-8")
deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
tasks = [
self._gcs_aio_client.internal_kv_get(
key,
namespace=ray_constants.KV_NAMESPACE_SERVE,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
for key in serve_keys
if SERVE_SNAPSHOT_KEY in key.decode()
]
serve_snapshot_vals = await asyncio.gather(*tasks)
deployments_per_controller: List[Dict[str, Any]] = [
json.loads(val.decode()) for val in serve_snapshot_vals
]
# Merge the deployments dicts of all controllers.
deployments: Dict[str, Any] = {
k: v for d in deployments_per_controller for k, v in d.items()
@ -409,20 +412,13 @@ class APIHead(dashboard_utils.DashboardHeadModule):
for name, info in deployments.items()
}
return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_deployments
)
async def get_session_name(self):
# TODO(yic): Convert to async GRPC.
def get_session():
return ray.experimental.internal_kv._internal_kv_get(
"session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
).decode()
return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_session
session_name = await self._gcs_aio_client.internal_kv_get(
b"session_name",
namespace=ray_constants.KV_NAMESPACE_SESSION,
timeout=GCS_RPC_TIMEOUT_SECONDS,
)
return session_name.decode()
async def run(self, server):
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(


@ -261,6 +261,10 @@ def init_ray_and_catch_exceptions() -> Callable:
try:
address = self.get_gcs_address()
logger.info(f"Connecting to ray with address={address}")
# Set the gcs rpc timeout to shorter
os.environ["RAY_gcs_server_request_timeout_seconds"] = str(
dashboard_consts.GCS_RPC_TIMEOUT_SECONDS
)
# Init ray without logging to driver
# to avoid infinite logging issue.
ray.init(


@ -2447,6 +2447,10 @@ def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHan
have been created with Actor.options(name="name").remote(). This
works for both detached & non-detached actors.
This method is a sync call and it will time out after 60s. This can be modified
by setting the OS environment variable RAY_gcs_server_request_timeout_seconds
before starting the cluster.
Args:
name: The name of the actor.
namespace: The namespace of the actor, or None to specify the current


@ -96,6 +96,9 @@ HANDLE_METRIC_PUSH_INTERVAL_S = 10
# Timeout for GCS internal KV service
RAY_SERVE_KV_TIMEOUT_S = float(os.environ.get("RAY_SERVE_KV_TIMEOUT_S", "0")) or None
# Timeout for GCS RPC request
RAY_GCS_RPC_TIMEOUT_S = 3.0
# Env var to control legacy sync deployment handle behavior in DAG.
SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY = "SERVE_DEPLOYMENT_HANDLE_IS_SYNC"


@ -7,6 +7,7 @@ import ray
from ray.actor import ActorHandle
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray._private.gcs_utils import GcsClient
from ray.serve.config import HTTPOptions, DeploymentMode
from ray.serve._private.constants import (
ASYNC_CONCURRENCY,
@ -37,6 +38,7 @@ class HTTPState:
detached: bool,
config: HTTPOptions,
head_node_id: str,
gcs_client: GcsClient,
# Used by unit testing
_start_proxies_on_init: bool = True,
):
@ -49,6 +51,9 @@ class HTTPState:
self._proxy_actors: Dict[NodeId, ActorHandle] = dict()
self._proxy_actor_names: Dict[NodeId, str] = dict()
self._head_node_id: str = head_node_id
self._gcs_client = gcs_client
assert isinstance(head_node_id, str)
# Will populate self.proxy_actors with existing actors.
@ -75,7 +80,7 @@ class HTTPState:
def _get_target_nodes(self) -> List[Tuple[str, str]]:
"""Return the list of (node_id, ip_address) to deploy HTTP servers on."""
location = self._config.location
target_nodes = get_all_node_ids()
target_nodes = get_all_node_ids(self._gcs_client)
if location == DeploymentMode.NoServer:
return []
@ -112,6 +117,7 @@ class HTTPState:
def _start_proxies_if_needed(self) -> None:
"""Start a proxy on every node if it doesn't already exist."""
for node_id, node_ip_address in self._get_target_nodes():
if node_id in self._proxy_actors:
continue
@ -151,7 +157,7 @@ class HTTPState:
def _stop_proxies_if_needed(self) -> bool:
"""Removes proxy actors from any nodes that no longer exist."""
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
all_node_ids = {node_id for node_id, _ in get_all_node_ids(self._gcs_client)}
to_stop = []
for node_id in self._proxy_actors:
if node_id not in all_node_ids:

View file

@ -28,11 +28,14 @@ class RayInternalKVStore(KVStoreBase):
def __init__(
self,
namespace: str = None,
namespace: Optional[str] = None,
gcs_client: Optional[GcsClient] = None,
):
if namespace is not None and not isinstance(namespace, str):
raise TypeError("namespace must be a string, got: {}.".format(type(namespace)))
if gcs_client is not None:
self.gcs_client = gcs_client
else:
self.gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
self.timeout = RAY_SERVE_KV_TIMEOUT_S
self.namespace = namespace or ""


@ -21,7 +21,7 @@ import ray
import ray.util.serialization_addons
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S
from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request
from ray.util.serialization import StandaloneSerializationContext
@ -145,19 +145,21 @@ def format_actor_name(actor_name, controller_name=None, *modifiers):
return name
def get_all_node_ids() -> List[Tuple[str, str]]:
def get_all_node_ids(gcs_client) -> List[Tuple[str, str]]:
"""Get IDs for all live nodes in the cluster.
Returns a list of (node_id: str, ip_address: str). The node_id can be
passed into the Ray SchedulingPolicy API.
"""
node_ids = []
# Sort on NodeID to ensure the ordering is deterministic across the cluster.
for node in sorted(ray.nodes(), key=lambda entry: entry["NodeID"]):
# print(node)
if node["Alive"]:
node_ids.append((node["NodeID"], node["NodeName"]))
nodes = gcs_client.get_all_node_info(timeout=RAY_GCS_RPC_TIMEOUT_S)
node_ids = [
(ray.NodeID.from_binary(node.node_id).hex(), node.node_name)
for node in nodes.node_info_list
if node.state == ray.core.generated.gcs_pb2.GcsNodeInfo.ALIVE
]
# Sort on NodeID to ensure the ordering is deterministic across the cluster.
node_ids.sort()
return node_ids


@ -13,6 +13,7 @@ from ray._private.utils import import_attr
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.actor import ActorHandle
from ray.exceptions import RayTaskError
from ray._private.gcs_utils import GcsClient
from ray.serve._private.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve._private.common import (
ApplicationStatus,
@ -96,9 +97,10 @@ class ServeController:
# Used to read/write checkpoints.
self.ray_worker_namespace = ray.get_runtime_context().namespace
self.controller_name = controller_name
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
kv_store_namespace = f"{self.controller_name}-{self.ray_worker_namespace}"
self.kv_store = RayInternalKVStore(kv_store_namespace)
self.snapshot_store = RayInternalKVStore(namespace=kv_store_namespace)
self.kv_store = RayInternalKVStore(kv_store_namespace, gcs_client)
self.snapshot_store = RayInternalKVStore(kv_store_namespace, gcs_client)
# Dictionary of deployment_name -> proxy_name -> queue length.
self.deployment_stats = defaultdict(lambda: defaultdict(dict))
@ -114,6 +116,7 @@ class ServeController:
detached,
http_config,
head_node_id,
gcs_client,
)
self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)


@ -15,6 +15,7 @@ def test_node_selection():
detached=True,
config=http_options,
head_node_id=head_node_id,
gcs_client=None,
_start_proxies_on_init=False,
)


@ -28,6 +28,7 @@ from ray.serve._private.constants import (
SERVE_PROXY_NAME,
SERVE_ROOT_URL_ENV_KEY,
)
from ray._private.gcs_utils import GcsClient
from ray.serve.context import get_global_client
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import ActorNameList
@ -87,6 +88,7 @@ def lower_slow_startup_threshold_and_reset():
def test_shutdown(ray_shutdown):
ray.init(num_cpus=16)
serve.start(http_options=dict(port=8003))
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
@serve.deployment
def f():
@ -100,7 +102,7 @@ def test_shutdown(ray_shutdown):
format_actor_name(
SERVE_PROXY_NAME,
serve.context._global_client._controller_name,
get_all_node_ids()[0][0],
get_all_node_ids(gcs_client)[0][0],
),
]
@ -239,10 +241,11 @@ def test_multiple_routers(ray_cluster):
node_ids = ray._private.state.node_ids()
assert len(node_ids) == 2
serve.start(http_options=dict(port=8005, location="EveryNode"))
gcs_client = GcsClient(address=ray.get_runtime_context().gcs_address)
def get_proxy_names():
proxy_names = []
for node_id, _ in get_all_node_ids():
for node_id, _ in get_all_node_ids(gcs_client):
proxy_names.append(
format_actor_name(
SERVE_PROXY_NAME,


@ -1,4 +1,5 @@
import sys
import os
import threading
from time import sleep
@ -609,8 +610,37 @@ def test_pg_actor_workloads(ray_start_regular_with_external_redis):
assert pid == ray.get(c.pid.remote())
@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
generate_system_config_map(
gcs_failover_worker_reconnect_timeout=20,
gcs_rpc_server_reconnect_timeout_s=60,
gcs_server_request_timeout_seconds=10,
)
],
indirect=True,
)
def test_get_actor_when_gcs_is_down(ray_start_regular_with_external_redis):
@ray.remote
def create_actor():
@ray.remote
class A:
def pid(self):
return os.getpid()
a = A.options(lifetime="detached", name="A").remote()
ray.get(a.pid.remote())
ray.get(create_actor.remote())
ray._private.worker._global_node.kill_gcs_server()
with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get_actor("A")
if __name__ == "__main__":
import os
import pytest


@ -1,48 +0,0 @@
import os
import pytest
import sys
import ray
@pytest.mark.parametrize(
"call_ray_start_with_external_redis",
[
"6379",
"6379,6380",
"6379,6380,6381",
],
indirect=True,
)
def test_using_hostnames(call_ray_start_with_external_redis):
ray.init(address="127.0.0.1:6379", _redis_password="123")
@ray.remote
def f():
return 1
assert ray.get(f.remote()) == 1
@ray.remote
class Counter:
def __init__(self):
self.count = 0
def inc_and_get(self):
self.count += 1
return self.count
counter = Counter.remote()
assert ray.get(counter.inc_and_get.remote()) == 1
if __name__ == "__main__":
import pytest
# Make subprocess happy in bazel.
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LANG"] = "en_US.UTF-8"
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))