[State Observability API] Error handling (#24413)

This improves error handling per https://docs.google.com/document/d/1IeEsJOiurg-zctOcBjY-tQVbsCmURFSnUCTkx_4a7Cw/edit#heading=h.pdzl9cil9e8z (the RPC part).

Semantics
If all queries to the source failed, raise a RayStateApiException.

If partial queries are failed, warnings.warn the partial failure when print_api_stats=True. It is true for CLI. It is false when it is used within Python API or json / yaml format is required.
This commit is contained in:
SangBin Cho 2022-05-24 19:56:49 +09:00 committed by GitHub
parent e73c37cc17
commit a7e759317b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 677 additions and 169 deletions

View file

@ -1,4 +1,5 @@
import logging
import aiohttp.web
import dataclasses
@ -10,6 +11,7 @@ import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
from ray.dashboard.state_aggregator import StateAPIManager
from ray.experimental.state.common import ListApiOptions
from ray.experimental.state.exception import DataSourceUnavailable
from ray.experimental.state.state_manager import StateDataSourceClient
logger = logging.getLogger(__name__)
@ -33,13 +35,19 @@ class StateHead(dashboard_utils.DashboardHeadModule):
def _options_from_req(self, req) -> ListApiOptions:
"""Obtain `ListApiOptions` from the aiohttp request."""
limit = int(req.query.get("limit"))
# Only apply 80% of the timeout so that
# the API will reply before client times out if query to the source fails.
timeout = int(req.query.get("timeout"))
return ListApiOptions(limit=limit, timeout=timeout)
def _reply(self, success: bool, message: str, result: dict):
def _reply(self, success: bool, error_message: str, result: dict, **kwargs):
"""Reply to the client."""
return rest_response(
success=success, message=message, result=result, convert_google_style=False
success=success,
message=error_message,
result=result,
convert_google_style=False,
**kwargs,
)
async def _update_raylet_stubs(self, change: Change):
@ -85,57 +93,61 @@ class StateHead(dashboard_utils.DashboardHeadModule):
int(ports[1]),
)
async def _handle_list_api(self, list_api_fn, req):
try:
result = await list_api_fn(option=self._options_from_req(req))
return self._reply(
success=True,
error_message="",
result=result.result,
partial_failure_warning=result.partial_failure_warning,
)
except DataSourceUnavailable as e:
return self._reply(success=False, error_message=str(e), result=None)
@routes.get("/api/v0/actors")
async def list_actors(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_actors(option=self._options_from_req(req))
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_actors, req)
@routes.get("/api/v0/jobs")
async def list_jobs(self, req) -> aiohttp.web.Response:
data = self._state_api.list_jobs(option=self._options_from_req(req))
return self._reply(
success=True,
message="",
result={
job_id: dataclasses.asdict(job_info)
for job_id, job_info in data.items()
},
)
try:
result = self._state_api.list_jobs(option=self._options_from_req(req))
return self._reply(
success=True,
error_message="",
result={
job_id: dataclasses.asdict(job_info)
for job_id, job_info in result.result.items()
},
partial_failure_warning=result.partial_failure_warning,
)
except DataSourceUnavailable as e:
return self._reply(success=False, error_message=str(e), result=None)
@routes.get("/api/v0/nodes")
async def list_nodes(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_nodes(option=self._options_from_req(req))
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_nodes, req)
@routes.get("/api/v0/placement_groups")
async def list_placement_groups(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_placement_groups(
option=self._options_from_req(req)
)
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_placement_groups, req)
@routes.get("/api/v0/workers")
async def list_workers(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_workers(option=self._options_from_req(req))
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_workers, req)
@routes.get("/api/v0/tasks")
async def list_tasks(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_tasks(option=self._options_from_req(req))
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_tasks, req)
@routes.get("/api/v0/objects")
async def list_objects(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_objects(option=self._options_from_req(req))
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_objects, req)
@routes.get("/api/v0/runtime_envs")
@dashboard_optional_utils.aiohttp_cache
async def list_runtime_envs(self, req) -> aiohttp.web.Response:
data = await self._state_api.list_runtime_envs(
option=self._options_from_req(req)
)
return self._reply(success=True, message="", result=data)
return await self._handle_list_api(self._state_api.list_runtime_envs, req)
async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel

View file

@ -1,13 +1,12 @@
import asyncio
import logging
from typing import List, Dict
from itertools import islice
from typing import List
from ray.core.generated.common_pb2 import TaskStatus
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.memory_utils as memory_utils
from ray.dashboard.modules.job.common import JobInfo
from ray.experimental.state.common import (
filter_fields,
@ -19,13 +18,34 @@ from ray.experimental.state.common import (
ObjectState,
RuntimeEnvState,
ListApiOptions,
ListApiResponse,
)
from ray.experimental.state.state_manager import (
StateDataSourceClient,
DataSourceUnavailable,
)
from ray.experimental.state.state_manager import StateDataSourceClient
from ray.runtime_env import RuntimeEnv
from ray._private.utils import binary_to_hex
logger = logging.getLogger(__name__)
GCS_QUERY_FAILURE_WARNING = (
"Failed to query data from GCS. It is due to "
"(1) GCS is unexpectedly failed. "
"(2) GCS is overloaded. "
"(3) There's an unexpected network issue. "
"Please check the gcs_server.out log to find the root cause."
)
NODE_QUERY_FAILURE_WARNING = (
"Failed to query data from {type}. "
"Queryed {total} {type} "
"and {network_failures} {type} failed to reply. It is due to "
"(1) {type} is unexpectedly failed. "
"(2) {type} is overloaded. "
"(3) There's an unexpected network issue. Please check the "
"{log_command} to find the root cause."
)
# TODO(sang): Move the class to state/state_manager.py.
# TODO(sang): Remove *State and replaces with Pydantic or protobuf.
@ -42,14 +62,19 @@ class StateAPIManager:
def data_source_client(self):
return self._client
async def list_actors(self, *, option: ListApiOptions) -> dict:
async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all actor information from the cluster.
Returns:
{actor_id -> actor_data_in_dict}
actor_data_in_dict's schema is in ActorState
"""
reply = await self._client.get_all_actor_info(timeout=option.timeout)
try:
reply = await self._client.get_all_actor_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.actor_table_data:
data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
@ -58,16 +83,24 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["actor_id"])
return {d["actor_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["actor_id"]: d for d in islice(result, option.limit)}
)
async def list_placement_groups(self, *, option: ListApiOptions) -> dict:
async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all placement group information from the cluster.
Returns:
{pg_id -> pg_data_in_dict}
pg_data_in_dict's schema is in PlacementGroupState
"""
reply = await self._client.get_all_placement_group_info(timeout=option.timeout)
try:
reply = await self._client.get_all_placement_group_info(
timeout=option.timeout
)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.placement_group_table_data:
@ -80,16 +113,22 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["placement_group_id"])
return {d["placement_group_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["placement_group_id"]: d for d in islice(result, option.limit)}
)
async def list_nodes(self, *, option: ListApiOptions) -> dict:
async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all node information from the cluster.
Returns:
{node_id -> node_data_in_dict}
node_data_in_dict's schema is in NodeState
"""
reply = await self._client.get_all_node_info(timeout=option.timeout)
try:
reply = await self._client.get_all_node_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.node_info_list:
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
@ -98,16 +137,22 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["node_id"])
return {d["node_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["node_id"]: d for d in islice(result, option.limit)}
)
async def list_workers(self, *, option: ListApiOptions) -> dict:
async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all worker information from the cluster.
Returns:
{worker_id -> worker_data_in_dict}
worker_data_in_dict's schema is in WorkerState
"""
reply = await self._client.get_all_worker_info(timeout=option.timeout)
try:
reply = await self._client.get_all_worker_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.worker_table_data:
data = self._message_to_dict(
@ -119,34 +164,65 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["worker_id"])
return {d["worker_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["worker_id"]: d for d in islice(result, option.limit)}
)
def list_jobs(self, *, option: ListApiOptions) -> Dict[str, JobInfo]:
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
# TODO(sang): Support limit & timeout & async calls.
return self._client.get_job_info()
try:
result = self._client.get_job_info()
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
return ListApiResponse(result=result)
async def list_tasks(self, *, option: ListApiOptions) -> dict:
async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all task information from the cluster.
Returns:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
raylet_ids = self._client.get_all_registered_raylet_ids()
replies = await asyncio.gather(
*[
self._client.get_task_info(node_id, timeout=option.timeout)
for node_id in self._client.get_all_registered_raylet_ids()
]
for node_id in raylet_ids
],
return_exceptions=True,
)
unresponsive_nodes = 0
running_task_id = set()
successful_replies = []
for reply in replies:
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
successful_replies.append(reply)
for task_id in reply.running_task_ids:
running_task_id.add(binary_to_hex(task_id))
partial_failure_warning = None
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="raylet",
total=len(raylet_ids),
network_failures=unresponsive_nodes,
log_command="raylet.out",
)
if unresponsive_nodes == len(raylet_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
result = []
for reply in replies:
logger.info(reply)
for reply in successful_replies:
assert not isinstance(reply, Exception)
tasks = reply.owned_task_info_entries
for task in tasks:
data = self._message_to_dict(
@ -162,24 +238,36 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["task_id"])
return {d["task_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["task_id"]: d for d in islice(result, option.limit)},
partial_failure_warning=partial_failure_warning,
)
async def list_objects(self, *, option: ListApiOptions) -> dict:
async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all object information from the cluster.
Returns:
{object_id -> object_data_in_dict}
object_data_in_dict's schema is in ObjectState
"""
raylet_ids = self._client.get_all_registered_raylet_ids()
replies = await asyncio.gather(
*[
self._client.get_object_info(node_id, timeout=option.timeout)
for node_id in self._client.get_all_registered_raylet_ids()
]
for node_id in raylet_ids
],
return_exceptions=True,
)
unresponsive_nodes = 0
worker_stats = []
for reply in replies:
for reply, node_id in zip(replies, raylet_ids):
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
for core_worker_stat in reply.core_workers_stats:
# NOTE: Set preserving_proto_field_name=False here because
# `construct_memory_table` requires a dictionary that has
@ -193,6 +281,20 @@ class StateAPIManager:
)
)
partial_failure_warning = None
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="raylet",
total=len(raylet_ids),
network_failures=unresponsive_nodes,
log_command="raylet.out",
)
if unresponsive_nodes == len(raylet_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
result = []
memory_table = memory_utils.construct_memory_table(worker_stats)
for entry in memory_table.table:
@ -207,9 +309,12 @@ class StateAPIManager:
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["object_id"])
return {d["object_id"]: d for d in islice(result, option.limit)}
return ListApiResponse(
result={d["object_id"]: d for d in islice(result, option.limit)},
partial_failure_warning=partial_failure_warning,
)
async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]:
async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all runtime env information from the cluster.
Returns:
@ -219,14 +324,24 @@ class StateAPIManager:
We don't have id -> data mapping like other API because runtime env
doesn't have unique ids.
"""
agent_ids = self._client.get_all_registered_agent_ids()
replies = await asyncio.gather(
*[
self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
for node_id in self._client.get_all_registered_agent_ids()
]
for node_id in agent_ids
],
return_exceptions=True,
)
result = []
unresponsive_nodes = 0
for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies):
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
states = reply.runtime_env_states
for state in states:
data = self._message_to_dict(message=state, fields_to_decode=[])
@ -238,6 +353,20 @@ class StateAPIManager:
data = filter_fields(data, RuntimeEnvState)
result.append(data)
partial_failure_warning = None
if len(agent_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="agent",
total=len(agent_ids),
network_failures=unresponsive_nodes,
log_command="dashboard_agent.log",
)
if unresponsive_nodes == len(agent_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
# Sort to make the output deterministic.
def sort_func(entry):
# If creation time is not there yet (runtime env is failed
@ -251,7 +380,10 @@ class StateAPIManager:
return float(entry["creation_time_ms"])
result.sort(key=sort_func, reverse=True)
return list(islice(result, option.limit))
return ListApiResponse(
result=list(islice(result, option.limit)),
partial_failure_warning=partial_failure_warning,
)
def _message_to_dict(
self,

View file

@ -1,4 +1,5 @@
import requests
import warnings
from dataclasses import fields
@ -8,18 +9,26 @@ from ray.experimental.state.common import (
DEFAULT_RPC_TIMEOUT,
DEFAULT_LIMIT,
)
from ray.experimental.state.exception import RayStateApiException
# TODO(sang): Replace it with auto-generated methods.
def _list(resource_name: str, options: ListApiOptions, api_server_url: str = None):
def _list(
resource_name: str,
options: ListApiOptions,
api_server_url: str = None,
_explain: bool = False,
):
"""Query the API server in address to list "resource_name" states.
Args:
resource_name: The name of the resource. E.g., actor, task.
options: The options for the REST API that are translated to query strings.
address: The address of API server. If it is not give, it assumes the ray
api_server_url: The address of API server. If it is not give, it assumes the ray
is already connected and obtains the API server address using
Ray API.
explain: Print the API information such as API
latency or failed query information.
"""
if api_server_url is None:
assert ray.is_initialized()
@ -40,10 +49,18 @@ def _list(resource_name: str, options: ListApiOptions, api_server_url: str = Non
r.raise_for_status()
response = r.json()
if not response["result"]:
raise ValueError(
"API server internal error. See dashboard.log file for more details."
if response["result"] is False:
raise RayStateApiException(
"API server internal error. See dashboard.log file for more details. "
f"Error: {response['msg']}"
)
if _explain:
# Print warnings if anything was given.
warning_msg = response["data"].get("partial_failure_warning", None)
if warning_msg is not None:
warnings.warn(warning_msg, RuntimeWarning)
return r.json()["data"]["result"]
@ -51,11 +68,13 @@ def list_actors(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"actors",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -63,11 +82,13 @@ def list_placement_groups(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"placement_groups",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -75,11 +96,13 @@ def list_nodes(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"nodes",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -87,11 +110,13 @@ def list_jobs(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"jobs",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -99,11 +124,13 @@ def list_workers(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"workers",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -111,11 +138,13 @@ def list_tasks(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"tasks",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
@ -123,17 +152,25 @@ def list_objects(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"objects",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)
def list_runtime_envs(api_server_url: str = None, limit: int = 1000, timeout: int = 30):
def list_runtime_envs(
api_server_url: str = None,
limit: int = DEFAULT_LIMIT,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
):
return _list(
"runtime_envs",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
_explain=_explain,
)

View file

@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass, fields
from typing import List, Dict, Union
from ray.dashboard.modules.job.common import JobInfo
@ -23,11 +24,20 @@ def filter_fields(data: dict, state_dataclass) -> dict:
class ListApiOptions:
limit: int
timeout: int
# When the request is processed on the server side,
# we should apply multiplier so that server side can finish
# processing a request within timeout. Otherwise,
# timeout will always lead Http timeout.
_server_timeout_multiplier: float = 0.8
# TODO(sang): Use Pydantic instead.
def __post_init__(self):
assert isinstance(self.limit, int)
assert isinstance(self.timeout, int)
# To return the data to users, when there's a partial failure
# we need to have a timeout that's smaller than the users' timeout.
# 80% is configured arbitrarily.
self.timeout = int(self.timeout * self._server_timeout_multiplier)
# TODO(sang): Replace it with Pydantic or gRPC schema (once interface is finalized).
@ -94,3 +104,30 @@ class RuntimeEnvState:
error: str
creation_time_ms: float
node_id: str
@dataclass(init=True)
class ListApiResponse:
# Returned data. None if no data is returned.
result: Union[
Dict[
str,
Union[
ActorState,
PlacementGroupState,
NodeState,
JobInfo,
WorkerState,
TaskState,
ObjectState,
],
],
List[RuntimeEnvState],
] = None
# List API can have a partial failure if queries to
# all sources fail. For example, getting object states
# require to ping all raylets, and it is possible some of
# them fails. Note that it is impossible to guarantee high
# availability of data because ray's state information is
# not replicated.
partial_failure_warning: str = ""

View file

@ -0,0 +1,12 @@
"""Internal Error"""
class DataSourceUnavailable(Exception):
pass
"""User-facing Error"""
class RayStateApiException(Exception):
pass

View file

@ -59,6 +59,12 @@ def get_state_api_output_to_print(
)
def _should_explain(format: AvailableFormat):
# If the format is json or yaml, it should not print stats because
# users don't want additional strings.
return format == AvailableFormat.DEFAULT or format == AvailableFormat.TABLE
@click.group("list")
@click.pass_context
def list_state_cli_group(ctx):
@ -71,6 +77,7 @@ def list_state_cli_group(ctx):
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
num_retries=20,
)
if api_server_url is None:
raise ValueError(
(
@ -92,9 +99,11 @@ def list_state_cli_group(ctx):
@click.pass_context
def actors(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_actors(api_server_url=url), format=AvailableFormat(format)
list_actors(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -106,10 +115,11 @@ def actors(ctx, format: str):
@click.pass_context
def placement_groups(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_placement_groups(api_server_url=url),
format=AvailableFormat(format),
list_placement_groups(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -121,9 +131,11 @@ def placement_groups(ctx, format: str):
@click.pass_context
def nodes(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_nodes(api_server_url=url), format=AvailableFormat(format)
list_nodes(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -135,9 +147,11 @@ def nodes(ctx, format: str):
@click.pass_context
def jobs(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_jobs(api_server_url=url), format=AvailableFormat(format)
list_jobs(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -149,9 +163,11 @@ def jobs(ctx, format: str):
@click.pass_context
def workers(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_workers(api_server_url=url), format=AvailableFormat(format)
list_workers(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -163,9 +179,11 @@ def workers(ctx, format: str):
@click.pass_context
def tasks(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_tasks(api_server_url=url), format=AvailableFormat(format)
list_tasks(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -177,9 +195,11 @@ def tasks(ctx, format: str):
@click.pass_context
def objects(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_objects(api_server_url=url), format=AvailableFormat(format)
list_objects(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)
@ -191,9 +211,10 @@ def objects(ctx, format: str):
@click.pass_context
def runtime_envs(ctx, format: str):
url = ctx.obj["api_server_url"]
format = AvailableFormat(format)
print(
get_state_api_output_to_print(
list_runtime_envs(api_server_url=url),
format=AvailableFormat(format),
list_runtime_envs(api_server_url=url, _explain=_should_explain(format)),
format=format,
)
)

View file

@ -6,7 +6,7 @@ from functools import wraps
import grpc
import ray
from typing import Dict, List
from typing import Dict, List, Optional
from ray import ray_constants
from ray.core.generated.gcs_service_pb2 import (
@ -33,19 +33,13 @@ from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub
from ray.dashboard.modules.job.common import JobInfoStorageClient, JobInfo
from ray.experimental.state.exception import DataSourceUnavailable
logger = logging.getLogger(__name__)
class StateSourceNetworkException(Exception):
"""Exceptions raised when there's a network error from data source query."""
pass
def handle_network_errors(func):
"""Apply the network error handling logic to each APIs,
such as retry or exception policies.
def handle_grpc_network_errors(func):
"""Decorator to add a network handling logic.
It is a helper method for `StateDataSourceClient`.
The method can only be used for async methods.
@ -54,20 +48,32 @@ def handle_network_errors(func):
@wraps(func)
async def api_with_network_error_handler(*args, **kwargs):
"""Apply the network error handling logic to each APIs,
such as retry or exception policies.
Returns:
If RPC succeeds, it returns what the original function returns.
If RPC fails, it raises exceptions.
Exceptions:
DataSourceUnavailable: if the source is unavailable because it is down
or there's a slow network issue causing timeout.
Otherwise, the raw network exceptions (e.g., gRPC) will be raised.
"""
# TODO(sang): Add a retry policy.
try:
return await func(*args, **kwargs)
except (
# https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc-exceptions
grpc.aio.AioRpcError,
grpc.aio.InternalError,
grpc.aio.AbortError,
grpc.aio.BaseError,
grpc.aio.UsageError,
) as e:
raise StateSourceNetworkException(
f"Failed to query the data source, {func}"
) from e
except grpc.aio.AioRpcError as e:
if (
e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
or e.code() == grpc.StatusCode.UNAVAILABLE
):
raise DataSourceUnavailable(
"Failed to query the data source. "
"It is either there's a network issue, or the source is down."
)
else:
logger.exception(e)
raise e
return api_with_network_error_handler
@ -81,8 +87,9 @@ class StateDataSourceClient:
finding services and register stubs through `register*` APIs.
Non `register*` APIs
- Return the protobuf directly if it succeeds to query the source.
- Raises an exception if there's any network issue.
- throw a ValueError if it cannot find the source.
- throw `StateSourceNetworkException` if there's any network errors.
"""
def __init__(self, gcs_channel: grpc.aio.Channel):
@ -132,50 +139,66 @@ class StateDataSourceClient:
def get_all_registered_agent_ids(self) -> List[str]:
return self._agent_stubs.keys()
@handle_network_errors
async def get_all_actor_info(self, timeout: int = None) -> GetAllActorInfoReply:
@handle_grpc_network_errors
async def get_all_actor_info(
self, timeout: int = None
) -> Optional[GetAllActorInfoReply]:
request = GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=timeout
)
return reply
@handle_network_errors
@handle_grpc_network_errors
async def get_all_placement_group_info(
self, timeout: int = None
) -> GetAllPlacementGroupReply:
) -> Optional[GetAllPlacementGroupReply]:
request = GetAllPlacementGroupRequest()
reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
request, timeout=timeout
)
return reply
@handle_network_errors
async def get_all_node_info(self, timeout: int = None) -> GetAllNodeInfoReply:
@handle_grpc_network_errors
async def get_all_node_info(
self, timeout: int = None
) -> Optional[GetAllNodeInfoReply]:
request = GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout)
return reply
@handle_network_errors
async def get_all_worker_info(self, timeout: int = None) -> GetAllWorkerInfoReply:
@handle_grpc_network_errors
async def get_all_worker_info(
self, timeout: int = None
) -> Optional[GetAllWorkerInfoReply]:
request = GetAllWorkerInfoRequest()
reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
request, timeout=timeout
)
return reply
def get_job_info(self) -> Dict[str, JobInfo]:
# Cannot use @handle_network_errors because async def is not supported yet.
def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
# Cannot use @handle_grpc_network_errors because async def is not supported yet.
# TODO(sang): Support timeout & make it async
try:
return self._job_client.get_all_jobs()
except Exception as e:
raise StateSourceNetworkException("Failed to query the job info.") from e
except grpc.aio.AioRpcError as e:
if (
e.code == grpc.StatusCode.DEADLINE_EXCEEDED
or e.code == grpc.StatusCode.UNAVAILABLE
):
raise DataSourceUnavailable(
"Failed to query the data source. "
"It is either there's a network issue, or the source is down."
)
else:
logger.exception(e)
raise e
@handle_network_errors
@handle_grpc_network_errors
async def get_task_info(
self, node_id: str, timeout: int = None
) -> GetTasksInfoReply:
) -> Optional[GetTasksInfoReply]:
stub = self._raylet_stubs.get(node_id)
if not stub:
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
@ -183,10 +206,10 @@ class StateDataSourceClient:
reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout)
return reply
@handle_network_errors
@handle_grpc_network_errors
async def get_object_info(
self, node_id: str, timeout: int = None
) -> GetNodeStatsReply:
) -> Optional[GetNodeStatsReply]:
stub = self._raylet_stubs.get(node_id)
if not stub:
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
@ -197,10 +220,10 @@ class StateDataSourceClient:
)
return reply
@handle_network_errors
@handle_grpc_network_errors
async def get_runtime_envs_info(
self, node_id: str, timeout: int = None
) -> GetRuntimeEnvsInfoReply:
) -> Optional[GetRuntimeEnvsInfoReply]:
stub = self._agent_stubs.get(node_id)
if not stub:
raise ValueError(f"Agent for a node id, {node_id} doesn't exist.")

View file

@ -8,7 +8,10 @@ from dataclasses import fields
from unittest.mock import MagicMock
from asyncmock import AsyncMock
if sys.version_info > (3, 7, 0):
from unittest.mock import AsyncMock
else:
from asyncmock import AsyncMock
import ray
import ray.ray_constants as ray_constants
@ -41,7 +44,11 @@ from ray.core.generated.runtime_env_common_pb2 import (
)
from ray.core.generated.runtime_env_agent_pb2 import GetRuntimeEnvsInfoReply
import ray.dashboard.consts as dashboard_consts
from ray.dashboard.state_aggregator import StateAPIManager
from ray.dashboard.state_aggregator import (
StateAPIManager,
GCS_QUERY_FAILURE_WARNING,
NODE_QUERY_FAILURE_WARNING,
)
from ray.experimental.state.api import (
list_actors,
list_placement_groups,
@ -64,9 +71,9 @@ from ray.experimental.state.common import (
DEFAULT_RPC_TIMEOUT,
DEFAULT_LIMIT,
)
from ray.experimental.state.exception import DataSourceUnavailable, RayStateApiException
from ray.experimental.state.state_manager import (
StateDataSourceClient,
StateSourceNetworkException,
)
from ray.experimental.state.state_cli import (
list_state_cli_group,
@ -188,7 +195,7 @@ def generate_runtime_env_info(runtime_env, creation_time=None):
def list_api_options(timeout: int = DEFAULT_RPC_TIMEOUT, limit: int = DEFAULT_LIMIT):
return ListApiOptions(limit=limit, timeout=timeout)
return ListApiOptions(limit=limit, timeout=timeout, _server_timeout_multiplier=1.0)
@pytest.mark.asyncio
@ -199,15 +206,25 @@ async def test_api_manager_list_actors(state_api_manager):
actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")]
)
result = await state_api_manager.list_actors(option=list_api_options())
actor_data = list(result.values())[0]
data = result.result
actor_data = list(data.values())[0]
verify_schema(ActorState, actor_data)
"""
Test limit
"""
assert len(result) == 2
assert len(data) == 2
result = await state_api_manager.list_actors(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_all_actor_info.side_effect = DataSourceUnavailable()
with pytest.raises(DataSourceUnavailable) as exc_info:
result = await state_api_manager.list_actors(option=list_api_options(limit=1))
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
@pytest.mark.asyncio
@ -223,17 +240,31 @@ async def test_api_manager_list_pgs(state_api_manager):
)
)
result = await state_api_manager.list_placement_groups(option=list_api_options())
data = list(result.values())[0]
data = result.result
data = list(data.values())[0]
verify_schema(PlacementGroupState, data)
"""
Test limit
"""
assert len(result) == 2
assert len(data) == 2
result = await state_api_manager.list_placement_groups(
option=list_api_options(limit=1)
)
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_all_placement_group_info.side_effect = (
DataSourceUnavailable()
)
with pytest.raises(DataSourceUnavailable) as exc_info:
result = await state_api_manager.list_placement_groups(
option=list_api_options(limit=1)
)
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
@pytest.mark.asyncio
@ -244,15 +275,25 @@ async def test_api_manager_list_nodes(state_api_manager):
node_info_list=[generate_node_data(id), generate_node_data(b"12345")]
)
result = await state_api_manager.list_nodes(option=list_api_options())
data = list(result.values())[0]
data = result.result
data = list(data.values())[0]
verify_schema(NodeState, data)
"""
Test limit
"""
assert len(result) == 2
assert len(data) == 2
result = await state_api_manager.list_nodes(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_all_node_info.side_effect = DataSourceUnavailable()
with pytest.raises(DataSourceUnavailable) as exc_info:
result = await state_api_manager.list_nodes(option=list_api_options(limit=1))
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
@pytest.mark.asyncio
@ -266,19 +307,30 @@ async def test_api_manager_list_workers(state_api_manager):
]
)
result = await state_api_manager.list_workers(option=list_api_options())
data = list(result.values())[0]
data = result.result
data = list(data.values())[0]
verify_schema(WorkerState, data)
"""
Test limit
"""
assert len(result) == 2
assert len(result.result) == 2
result = await state_api_manager.list_workers(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_all_worker_info.side_effect = DataSourceUnavailable()
with pytest.raises(DataSourceUnavailable) as exc_info:
result = await state_api_manager.list_workers(option=list_api_options(limit=1))
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
@pytest.mark.skip(
reason=("Not passing in CI although it works locally. Will handle it later.")
@pytest.mark.skipif(
sys.version_info <= (3, 7, 0),
reason=("Not passing in CI although it works locally. Will handle it later."),
)
@pytest.mark.asyncio
async def test_api_manager_list_tasks(state_api_manager):
@ -288,17 +340,19 @@ async def test_api_manager_list_tasks(state_api_manager):
first_task_name = "1"
second_task_name = "2"
data_source_client.get_task_info = AsyncMock()
data_source_client.get_task_info.side_effect = [
generate_task_data(b"1234", first_task_name),
generate_task_data(b"2345", second_task_name),
]
result = await state_api_manager.list_tasks(option=list_api_options())
data_source_client.get_task_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
data_source_client.get_task_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
result = list(result.values())
assert len(result) == 2
verify_schema(TaskState, result[0])
verify_schema(TaskState, result[1])
data_source_client.get_task_info.assert_any_await("1", timeout=DEFAULT_RPC_TIMEOUT)
data_source_client.get_task_info.assert_any_await("2", timeout=DEFAULT_RPC_TIMEOUT)
data = result.result
data = list(data.values())
assert len(data) == 2
verify_schema(TaskState, data[0])
verify_schema(TaskState, data[1])
"""
Test limit
@ -308,11 +362,38 @@ async def test_api_manager_list_tasks(state_api_manager):
generate_task_data(b"2345", second_task_name),
]
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_task_info.side_effect = [
DataSourceUnavailable(),
generate_task_data(b"2345", second_task_name),
]
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
# Make sure warnings are returned.
warning = result.partial_failure_warning
assert (
NODE_QUERY_FAILURE_WARNING.format(
type="raylet", total=2, network_failures=1, log_command="raylet.out"
)
in warning
)
# Test if all RPCs fail, it will raise an exception.
data_source_client.get_task_info.side_effect = [
DataSourceUnavailable(),
DataSourceUnavailable(),
]
with pytest.raises(DataSourceUnavailable):
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
@pytest.mark.skip(
reason=("Not passing in CI although it works locally. Will handle it later.")
@pytest.mark.skipif(
sys.version_info <= (3, 7, 0),
reason=("Not passing in CI although it works locally. Will handle it later."),
)
@pytest.mark.asyncio
async def test_api_manager_list_objects(state_api_manager):
@ -322,17 +403,23 @@ async def test_api_manager_list_objects(state_api_manager):
data_source_client.get_all_registered_raylet_ids = MagicMock()
data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]
data_source_client.get_object_info = AsyncMock()
data_source_client.get_object_info.side_effect = [
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_1_id)]),
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
]
result = await state_api_manager.list_objects(option=list_api_options())
data_source_client.get_object_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
data_source_client.get_object_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
result = list(result.values())
assert len(result) == 2
verify_schema(ObjectState, result[0])
verify_schema(ObjectState, result[1])
data = result.result
data_source_client.get_object_info.assert_any_await(
"1", timeout=DEFAULT_RPC_TIMEOUT
)
data_source_client.get_object_info.assert_any_await(
"2", timeout=DEFAULT_RPC_TIMEOUT
)
data = list(data.values())
assert len(data) == 2
verify_schema(ObjectState, data[0])
verify_schema(ObjectState, data[1])
"""
Test limit
@ -342,11 +429,38 @@ async def test_api_manager_list_objects(state_api_manager):
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
]
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_object_info.side_effect = [
DataSourceUnavailable(),
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
]
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
# Make sure warnings are returned.
warning = result.partial_failure_warning
assert (
NODE_QUERY_FAILURE_WARNING.format(
type="raylet", total=2, network_failures=1, log_command="raylet.out"
)
in warning
)
# Test if all RPCs fail, it will raise an exception.
data_source_client.get_object_info.side_effect = [
DataSourceUnavailable(),
DataSourceUnavailable(),
]
with pytest.raises(DataSourceUnavailable):
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
@pytest.mark.skip(
reason=("Not passing in CI although it works locally. Will handle it later.")
@pytest.mark.skipif(
sys.version_info <= (3, 7, 0),
reason=("Not passing in CI although it works locally. Will handle it later."),
)
@pytest.mark.asyncio
async def test_api_manager_list_runtime_envs(state_api_manager):
@ -354,6 +468,7 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
data_source_client.get_all_registered_agent_ids = MagicMock()
data_source_client.get_all_registered_agent_ids.return_value = ["1", "2", "3"]
data_source_client.get_runtime_envs_info = AsyncMock()
data_source_client.get_runtime_envs_info.side_effect = [
generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
generate_runtime_env_info(
@ -362,23 +477,25 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}), creation_time=10),
]
result = await state_api_manager.list_runtime_envs(option=list_api_options())
data_source_client.get_runtime_envs_info.assert_any_call(
data = result.result
data_source_client.get_runtime_envs_info.assert_any_await(
"1", timeout=DEFAULT_RPC_TIMEOUT
)
data_source_client.get_runtime_envs_info.assert_any_call(
data_source_client.get_runtime_envs_info.assert_any_await(
"2", timeout=DEFAULT_RPC_TIMEOUT
)
data_source_client.get_runtime_envs_info.assert_any_call(
data_source_client.get_runtime_envs_info.assert_any_await(
"3", timeout=DEFAULT_RPC_TIMEOUT
)
assert len(result) == 3
verify_schema(RuntimeEnvState, result[0])
verify_schema(RuntimeEnvState, result[1])
verify_schema(RuntimeEnvState, result[2])
assert len(data) == 3
verify_schema(RuntimeEnvState, data[0])
verify_schema(RuntimeEnvState, data[1])
verify_schema(RuntimeEnvState, data[2])
# Make sure the higher creation time is sorted first.
assert "creation_time_ms" not in result[0]
result[1]["creation_time_ms"] > result[2]["creation_time_ms"]
assert "creation_time_ms" not in data[0]
data[1]["creation_time_ms"] > data[2]["creation_time_ms"]
"""
Test limit
@ -391,7 +508,38 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
]
result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1))
assert len(result) == 1
data = result.result
assert len(data) == 1
"""
Test error handling
"""
data_source_client.get_runtime_envs_info.side_effect = [
DataSourceUnavailable(),
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
]
result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1))
# Make sure warnings are returned.
warning = result.partial_failure_warning
print(warning)
assert (
NODE_QUERY_FAILURE_WARNING.format(
type="agent", total=3, network_failures=1, log_command="dashboard_agent.log"
)
in warning
)
# Test if all RPCs fail, it will raise an exception.
data_source_client.get_runtime_envs_info.side_effect = [
DataSourceUnavailable(),
DataSourceUnavailable(),
DataSourceUnavailable(),
]
with pytest.raises(DataSourceUnavailable):
result = await state_api_manager.list_runtime_envs(
option=list_api_options(limit=1)
)
"""
@ -495,7 +643,7 @@ async def test_state_data_source_client(ray_start_cluster):
"""
with pytest.raises(ValueError):
# Since we didn't register this node id, it should raise an exception.
result = await client.get_object_info("1234")
result = await client.get_runtime_envs_info("1234")
wait_for_condition(lambda: len(ray.nodes()) == 2)
for node in ray.nodes():
node_id = node["NodeID"]
@ -527,10 +675,9 @@ async def test_state_data_source_client(ray_start_cluster):
if node["Alive"]:
continue
# Querying to the dead node raises gRPC error, which should be
# translated into `StateSourceNetworkException`
with pytest.raises(StateSourceNetworkException):
result = await client.get_object_info(node_id)
# Querying to the dead node raises gRPC error, which should raise an exception.
with pytest.raises(DataSourceUnavailable):
await client.get_object_info(node_id)
# Make sure unregister API works as expected.
client.unregister_raylet_client(node_id)
@ -685,6 +832,7 @@ def test_list_jobs(shutdown_only):
def verify():
job_data = list(list_jobs().values())[0]
print(job_data)
job_id_from_api = list(list_jobs().keys())[0]
correct_state = job_data["status"] == "SUCCEEDED"
correct_id = job_id == job_id_from_api
@ -914,6 +1062,91 @@ def test_limit(shutdown_only):
assert output == list_actors(limit=2)
def test_network_failure(shutdown_only):
"""When the request fails due to network failure,
verifies it raises an exception."""
ray.init()
@ray.remote
def f():
import time
time.sleep(30)
a = [f.remote() for _ in range(4)] # noqa
wait_for_condition(lambda: len(list_tasks()) == 4)
# Kill raylet so that list_tasks will have network error on querying raylets.
ray.worker._global_node.kill_raylet()
with pytest.raises(RayStateApiException):
list_tasks(_explain=True)
def test_network_partial_failures(ray_start_cluster):
"""When the request fails due to network failure,
verifies it prints proper warning."""
cluster = ray_start_cluster
cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
n = cluster.add_node(num_cpus=2)
@ray.remote
def f():
import time
time.sleep(30)
a = [f.remote() for _ in range(4)] # noqa
wait_for_condition(lambda: len(list_tasks()) == 4)
# Make sure when there's 0 node failure, it doesn't print the error.
with pytest.warns(None) as record:
list_tasks(_explain=True)
assert len(record) == 0
# Kill raylet so that list_tasks will have network error on querying raylets.
cluster.remove_node(n, allow_graceful=False)
with pytest.warns(RuntimeWarning):
list_tasks(_explain=True)
# Make sure when _explain == False, warning is not printed.
with pytest.warns(None) as record:
list_tasks(_explain=False)
assert len(record) == 0
def test_network_partial_failures_timeout(monkeypatch, ray_start_cluster):
"""When the request fails due to network timeout,
verifies it prints proper warning."""
cluster = ray_start_cluster
cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
with monkeypatch.context() as m:
# defer for 10s for the second node.
m.setenv(
"RAY_testing_asio_delay_us",
"NodeManagerService.grpc_server.GetTasksInfo=10000000:10000000",
)
cluster.add_node(num_cpus=2)
@ray.remote
def f():
import time
time.sleep(30)
a = [f.remote() for _ in range(4)] # noqa
def verify():
with pytest.warns(None) as record:
list_tasks(_explain=True, timeout=5)
return len(record) == 1
wait_for_condition(verify)
@pytest.mark.asyncio
async def test_cli_format_print(state_api_manager):
data_source_client = state_api_manager.data_source_client
@ -922,6 +1155,7 @@ async def test_cli_format_print(state_api_manager):
actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")]
)
result = await state_api_manager.list_actors(option=list_api_options())
result = result.result
# If the format is not yaml, it will raise an exception.
yaml.load(
get_state_api_output_to_print(result, format=AvailableFormat.YAML),