mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[State Observability API] Error handling (#24413)
This improves error handling per https://docs.google.com/document/d/1IeEsJOiurg-zctOcBjY-tQVbsCmURFSnUCTkx_4a7Cw/edit#heading=h.pdzl9cil9e8z (the RPC part). Semantics: (1) If all queries to the source fail, raise a RayStateApiException. (2) If only some of the queries fail, report the partial failure via warnings.warn when print_api_stats=True (the `_explain` argument in this change). It is true for the CLI with the default / table format. It is false when the API is used from Python or when json / yaml output is requested.
This commit is contained in:
parent
e73c37cc17
commit
a7e759317b
8 changed files with 677 additions and 169 deletions
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
|
||||
import aiohttp.web
|
||||
|
||||
import dataclasses
|
||||
|
@ -10,6 +11,7 @@ import ray.dashboard.optional_utils as dashboard_optional_utils
|
|||
from ray.dashboard.optional_utils import rest_response
|
||||
from ray.dashboard.state_aggregator import StateAPIManager
|
||||
from ray.experimental.state.common import ListApiOptions
|
||||
from ray.experimental.state.exception import DataSourceUnavailable
|
||||
from ray.experimental.state.state_manager import StateDataSourceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -33,13 +35,19 @@ class StateHead(dashboard_utils.DashboardHeadModule):
|
|||
def _options_from_req(self, req) -> ListApiOptions:
|
||||
"""Obtain `ListApiOptions` from the aiohttp request."""
|
||||
limit = int(req.query.get("limit"))
|
||||
# Only apply 80% of the timeout so that
|
||||
# the API will reply before the client times out if a query to the source fails.
|
||||
timeout = int(req.query.get("timeout"))
|
||||
return ListApiOptions(limit=limit, timeout=timeout)
|
||||
|
||||
def _reply(self, success: bool, message: str, result: dict):
|
||||
def _reply(self, success: bool, error_message: str, result: dict, **kwargs):
|
||||
"""Reply to the client."""
|
||||
return rest_response(
|
||||
success=success, message=message, result=result, convert_google_style=False
|
||||
success=success,
|
||||
message=error_message,
|
||||
result=result,
|
||||
convert_google_style=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _update_raylet_stubs(self, change: Change):
|
||||
|
@ -85,57 +93,61 @@ class StateHead(dashboard_utils.DashboardHeadModule):
|
|||
int(ports[1]),
|
||||
)
|
||||
|
||||
async def _handle_list_api(self, list_api_fn, req):
|
||||
try:
|
||||
result = await list_api_fn(option=self._options_from_req(req))
|
||||
return self._reply(
|
||||
success=True,
|
||||
error_message="",
|
||||
result=result.result,
|
||||
partial_failure_warning=result.partial_failure_warning,
|
||||
)
|
||||
except DataSourceUnavailable as e:
|
||||
return self._reply(success=False, error_message=str(e), result=None)
|
||||
|
||||
@routes.get("/api/v0/actors")
|
||||
async def list_actors(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_actors(option=self._options_from_req(req))
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_actors, req)
|
||||
|
||||
@routes.get("/api/v0/jobs")
|
||||
async def list_jobs(self, req) -> aiohttp.web.Response:
|
||||
data = self._state_api.list_jobs(option=self._options_from_req(req))
|
||||
return self._reply(
|
||||
success=True,
|
||||
message="",
|
||||
result={
|
||||
job_id: dataclasses.asdict(job_info)
|
||||
for job_id, job_info in data.items()
|
||||
},
|
||||
)
|
||||
try:
|
||||
result = self._state_api.list_jobs(option=self._options_from_req(req))
|
||||
return self._reply(
|
||||
success=True,
|
||||
error_message="",
|
||||
result={
|
||||
job_id: dataclasses.asdict(job_info)
|
||||
for job_id, job_info in result.result.items()
|
||||
},
|
||||
partial_failure_warning=result.partial_failure_warning,
|
||||
)
|
||||
except DataSourceUnavailable as e:
|
||||
return self._reply(success=False, error_message=str(e), result=None)
|
||||
|
||||
@routes.get("/api/v0/nodes")
|
||||
async def list_nodes(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_nodes(option=self._options_from_req(req))
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_nodes, req)
|
||||
|
||||
@routes.get("/api/v0/placement_groups")
|
||||
async def list_placement_groups(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_placement_groups(
|
||||
option=self._options_from_req(req)
|
||||
)
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_placement_groups, req)
|
||||
|
||||
@routes.get("/api/v0/workers")
|
||||
async def list_workers(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_workers(option=self._options_from_req(req))
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_workers, req)
|
||||
|
||||
@routes.get("/api/v0/tasks")
|
||||
async def list_tasks(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_tasks(option=self._options_from_req(req))
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_tasks, req)
|
||||
|
||||
@routes.get("/api/v0/objects")
|
||||
async def list_objects(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_objects(option=self._options_from_req(req))
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_objects, req)
|
||||
|
||||
@routes.get("/api/v0/runtime_envs")
|
||||
@dashboard_optional_utils.aiohttp_cache
|
||||
async def list_runtime_envs(self, req) -> aiohttp.web.Response:
|
||||
data = await self._state_api.list_runtime_envs(
|
||||
option=self._options_from_req(req)
|
||||
)
|
||||
return self._reply(success=True, message="", result=data)
|
||||
return await self._handle_list_api(self._state_api.list_runtime_envs, req)
|
||||
|
||||
async def run(self, server):
|
||||
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import List, Dict
|
||||
from itertools import islice
|
||||
from typing import List
|
||||
|
||||
from ray.core.generated.common_pb2 import TaskStatus
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
import ray.dashboard.memory_utils as memory_utils
|
||||
from ray.dashboard.modules.job.common import JobInfo
|
||||
|
||||
from ray.experimental.state.common import (
|
||||
filter_fields,
|
||||
|
@ -19,13 +18,34 @@ from ray.experimental.state.common import (
|
|||
ObjectState,
|
||||
RuntimeEnvState,
|
||||
ListApiOptions,
|
||||
ListApiResponse,
|
||||
)
|
||||
from ray.experimental.state.state_manager import (
|
||||
StateDataSourceClient,
|
||||
DataSourceUnavailable,
|
||||
)
|
||||
from ray.experimental.state.state_manager import StateDataSourceClient
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray._private.utils import binary_to_hex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GCS_QUERY_FAILURE_WARNING = (
|
||||
"Failed to query data from GCS. It is due to "
|
||||
"(1) GCS is unexpectedly failed. "
|
||||
"(2) GCS is overloaded. "
|
||||
"(3) There's an unexpected network issue. "
|
||||
"Please check the gcs_server.out log to find the root cause."
|
||||
)
|
||||
NODE_QUERY_FAILURE_WARNING = (
|
||||
"Failed to query data from {type}. "
|
||||
"Queryed {total} {type} "
|
||||
"and {network_failures} {type} failed to reply. It is due to "
|
||||
"(1) {type} is unexpectedly failed. "
|
||||
"(2) {type} is overloaded. "
|
||||
"(3) There's an unexpected network issue. Please check the "
|
||||
"{log_command} to find the root cause."
|
||||
)
|
||||
|
||||
|
||||
# TODO(sang): Move the class to state/state_manager.py.
|
||||
# TODO(sang): Remove *State and replaces with Pydantic or protobuf.
|
||||
|
@ -42,14 +62,19 @@ class StateAPIManager:
|
|||
def data_source_client(self):
|
||||
return self._client
|
||||
|
||||
async def list_actors(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all actor information from the cluster.
|
||||
|
||||
Returns:
|
||||
{actor_id -> actor_data_in_dict}
|
||||
actor_data_in_dict's schema is in ActorState
|
||||
|
||||
"""
|
||||
reply = await self._client.get_all_actor_info(timeout=option.timeout)
|
||||
try:
|
||||
reply = await self._client.get_all_actor_info(timeout=option.timeout)
|
||||
except DataSourceUnavailable:
|
||||
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
|
||||
|
||||
result = []
|
||||
for message in reply.actor_table_data:
|
||||
data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
|
||||
|
@ -58,16 +83,24 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["actor_id"])
|
||||
return {d["actor_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["actor_id"]: d for d in islice(result, option.limit)}
|
||||
)
|
||||
|
||||
async def list_placement_groups(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all placement group information from the cluster.
|
||||
|
||||
Returns:
|
||||
{pg_id -> pg_data_in_dict}
|
||||
pg_data_in_dict's schema is in PlacementGroupState
|
||||
"""
|
||||
reply = await self._client.get_all_placement_group_info(timeout=option.timeout)
|
||||
try:
|
||||
reply = await self._client.get_all_placement_group_info(
|
||||
timeout=option.timeout
|
||||
)
|
||||
except DataSourceUnavailable:
|
||||
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
|
||||
|
||||
result = []
|
||||
for message in reply.placement_group_table_data:
|
||||
|
||||
|
@ -80,16 +113,22 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["placement_group_id"])
|
||||
return {d["placement_group_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["placement_group_id"]: d for d in islice(result, option.limit)}
|
||||
)
|
||||
|
||||
async def list_nodes(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all node information from the cluster.
|
||||
|
||||
Returns:
|
||||
{node_id -> node_data_in_dict}
|
||||
node_data_in_dict's schema is in NodeState
|
||||
"""
|
||||
reply = await self._client.get_all_node_info(timeout=option.timeout)
|
||||
try:
|
||||
reply = await self._client.get_all_node_info(timeout=option.timeout)
|
||||
except DataSourceUnavailable:
|
||||
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
|
||||
|
||||
result = []
|
||||
for message in reply.node_info_list:
|
||||
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
|
||||
|
@ -98,16 +137,22 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["node_id"])
|
||||
return {d["node_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["node_id"]: d for d in islice(result, option.limit)}
|
||||
)
|
||||
|
||||
async def list_workers(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all worker information from the cluster.
|
||||
|
||||
Returns:
|
||||
{worker_id -> worker_data_in_dict}
|
||||
worker_data_in_dict's schema is in WorkerState
|
||||
"""
|
||||
reply = await self._client.get_all_worker_info(timeout=option.timeout)
|
||||
try:
|
||||
reply = await self._client.get_all_worker_info(timeout=option.timeout)
|
||||
except DataSourceUnavailable:
|
||||
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
|
||||
|
||||
result = []
|
||||
for message in reply.worker_table_data:
|
||||
data = self._message_to_dict(
|
||||
|
@ -119,34 +164,65 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["worker_id"])
|
||||
return {d["worker_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["worker_id"]: d for d in islice(result, option.limit)}
|
||||
)
|
||||
|
||||
def list_jobs(self, *, option: ListApiOptions) -> Dict[str, JobInfo]:
|
||||
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
# TODO(sang): Support limit & timeout & async calls.
|
||||
return self._client.get_job_info()
|
||||
try:
|
||||
result = self._client.get_job_info()
|
||||
except DataSourceUnavailable:
|
||||
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
|
||||
return ListApiResponse(result=result)
|
||||
|
||||
async def list_tasks(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all task information from the cluster.
|
||||
|
||||
Returns:
|
||||
{task_id -> task_data_in_dict}
|
||||
task_data_in_dict's schema is in TaskState
|
||||
"""
|
||||
raylet_ids = self._client.get_all_registered_raylet_ids()
|
||||
replies = await asyncio.gather(
|
||||
*[
|
||||
self._client.get_task_info(node_id, timeout=option.timeout)
|
||||
for node_id in self._client.get_all_registered_raylet_ids()
|
||||
]
|
||||
for node_id in raylet_ids
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
unresponsive_nodes = 0
|
||||
running_task_id = set()
|
||||
successful_replies = []
|
||||
for reply in replies:
|
||||
if isinstance(reply, DataSourceUnavailable):
|
||||
unresponsive_nodes += 1
|
||||
continue
|
||||
elif isinstance(reply, Exception):
|
||||
raise reply
|
||||
|
||||
successful_replies.append(reply)
|
||||
for task_id in reply.running_task_ids:
|
||||
running_task_id.add(binary_to_hex(task_id))
|
||||
|
||||
partial_failure_warning = None
|
||||
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
|
||||
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="raylet",
|
||||
total=len(raylet_ids),
|
||||
network_failures=unresponsive_nodes,
|
||||
log_command="raylet.out",
|
||||
)
|
||||
if unresponsive_nodes == len(raylet_ids):
|
||||
raise DataSourceUnavailable(warning_msg)
|
||||
partial_failure_warning = (
|
||||
f"The returned data may contain incomplete result. {warning_msg}"
|
||||
)
|
||||
|
||||
result = []
|
||||
for reply in replies:
|
||||
logger.info(reply)
|
||||
for reply in successful_replies:
|
||||
assert not isinstance(reply, Exception)
|
||||
tasks = reply.owned_task_info_entries
|
||||
for task in tasks:
|
||||
data = self._message_to_dict(
|
||||
|
@ -162,24 +238,36 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["task_id"])
|
||||
return {d["task_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["task_id"]: d for d in islice(result, option.limit)},
|
||||
partial_failure_warning=partial_failure_warning,
|
||||
)
|
||||
|
||||
async def list_objects(self, *, option: ListApiOptions) -> dict:
|
||||
async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all object information from the cluster.
|
||||
|
||||
Returns:
|
||||
{object_id -> object_data_in_dict}
|
||||
object_data_in_dict's schema is in ObjectState
|
||||
"""
|
||||
raylet_ids = self._client.get_all_registered_raylet_ids()
|
||||
replies = await asyncio.gather(
|
||||
*[
|
||||
self._client.get_object_info(node_id, timeout=option.timeout)
|
||||
for node_id in self._client.get_all_registered_raylet_ids()
|
||||
]
|
||||
for node_id in raylet_ids
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
unresponsive_nodes = 0
|
||||
worker_stats = []
|
||||
for reply in replies:
|
||||
for reply, node_id in zip(replies, raylet_ids):
|
||||
if isinstance(reply, DataSourceUnavailable):
|
||||
unresponsive_nodes += 1
|
||||
continue
|
||||
elif isinstance(reply, Exception):
|
||||
raise reply
|
||||
|
||||
for core_worker_stat in reply.core_workers_stats:
|
||||
# NOTE: Set preserving_proto_field_name=False here because
|
||||
# `construct_memory_table` requires a dictionary that has
|
||||
|
@ -193,6 +281,20 @@ class StateAPIManager:
|
|||
)
|
||||
)
|
||||
|
||||
partial_failure_warning = None
|
||||
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
|
||||
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="raylet",
|
||||
total=len(raylet_ids),
|
||||
network_failures=unresponsive_nodes,
|
||||
log_command="raylet.out",
|
||||
)
|
||||
if unresponsive_nodes == len(raylet_ids):
|
||||
raise DataSourceUnavailable(warning_msg)
|
||||
partial_failure_warning = (
|
||||
f"The returned data may contain incomplete result. {warning_msg}"
|
||||
)
|
||||
|
||||
result = []
|
||||
memory_table = memory_utils.construct_memory_table(worker_stats)
|
||||
for entry in memory_table.table:
|
||||
|
@ -207,9 +309,12 @@ class StateAPIManager:
|
|||
|
||||
# Sort to make the output deterministic.
|
||||
result.sort(key=lambda entry: entry["object_id"])
|
||||
return {d["object_id"]: d for d in islice(result, option.limit)}
|
||||
return ListApiResponse(
|
||||
result={d["object_id"]: d for d in islice(result, option.limit)},
|
||||
partial_failure_warning=partial_failure_warning,
|
||||
)
|
||||
|
||||
async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]:
|
||||
async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
|
||||
"""List all runtime env information from the cluster.
|
||||
|
||||
Returns:
|
||||
|
@ -219,14 +324,24 @@ class StateAPIManager:
|
|||
We don't have id -> data mapping like other API because runtime env
|
||||
doesn't have unique ids.
|
||||
"""
|
||||
agent_ids = self._client.get_all_registered_agent_ids()
|
||||
replies = await asyncio.gather(
|
||||
*[
|
||||
self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
|
||||
for node_id in self._client.get_all_registered_agent_ids()
|
||||
]
|
||||
for node_id in agent_ids
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
result = []
|
||||
unresponsive_nodes = 0
|
||||
for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies):
|
||||
if isinstance(reply, DataSourceUnavailable):
|
||||
unresponsive_nodes += 1
|
||||
continue
|
||||
elif isinstance(reply, Exception):
|
||||
raise reply
|
||||
|
||||
states = reply.runtime_env_states
|
||||
for state in states:
|
||||
data = self._message_to_dict(message=state, fields_to_decode=[])
|
||||
|
@ -238,6 +353,20 @@ class StateAPIManager:
|
|||
data = filter_fields(data, RuntimeEnvState)
|
||||
result.append(data)
|
||||
|
||||
partial_failure_warning = None
|
||||
if len(agent_ids) > 0 and unresponsive_nodes > 0:
|
||||
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="agent",
|
||||
total=len(agent_ids),
|
||||
network_failures=unresponsive_nodes,
|
||||
log_command="dashboard_agent.log",
|
||||
)
|
||||
if unresponsive_nodes == len(agent_ids):
|
||||
raise DataSourceUnavailable(warning_msg)
|
||||
partial_failure_warning = (
|
||||
f"The returned data may contain incomplete result. {warning_msg}"
|
||||
)
|
||||
|
||||
# Sort to make the output deterministic.
|
||||
def sort_func(entry):
|
||||
# If creation time is not there yet (runtime env is failed
|
||||
|
@ -251,7 +380,10 @@ class StateAPIManager:
|
|||
return float(entry["creation_time_ms"])
|
||||
|
||||
result.sort(key=sort_func, reverse=True)
|
||||
return list(islice(result, option.limit))
|
||||
return ListApiResponse(
|
||||
result=list(islice(result, option.limit)),
|
||||
partial_failure_warning=partial_failure_warning,
|
||||
)
|
||||
|
||||
def _message_to_dict(
|
||||
self,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import requests
|
||||
import warnings
|
||||
|
||||
from dataclasses import fields
|
||||
|
||||
|
@ -8,18 +9,26 @@ from ray.experimental.state.common import (
|
|||
DEFAULT_RPC_TIMEOUT,
|
||||
DEFAULT_LIMIT,
|
||||
)
|
||||
from ray.experimental.state.exception import RayStateApiException
|
||||
|
||||
|
||||
# TODO(sang): Replace it with auto-generated methods.
|
||||
def _list(resource_name: str, options: ListApiOptions, api_server_url: str = None):
|
||||
def _list(
|
||||
resource_name: str,
|
||||
options: ListApiOptions,
|
||||
api_server_url: str = None,
|
||||
_explain: bool = False,
|
||||
):
|
||||
"""Query the API server in address to list "resource_name" states.
|
||||
|
||||
Args:
|
||||
resource_name: The name of the resource. E.g., actor, task.
|
||||
options: The options for the REST API that are translated to query strings.
|
||||
address: The address of API server. If it is not give, it assumes the ray
|
||||
api_server_url: The address of API server. If it is not given, it assumes the ray
|
||||
is already connected and obtains the API server address using
|
||||
Ray API.
|
||||
_explain: Whether to print API information such as API
|
||||
latency or failed query information.
|
||||
"""
|
||||
if api_server_url is None:
|
||||
assert ray.is_initialized()
|
||||
|
@ -40,10 +49,18 @@ def _list(resource_name: str, options: ListApiOptions, api_server_url: str = Non
|
|||
r.raise_for_status()
|
||||
|
||||
response = r.json()
|
||||
if not response["result"]:
|
||||
raise ValueError(
|
||||
"API server internal error. See dashboard.log file for more details."
|
||||
if response["result"] is False:
|
||||
raise RayStateApiException(
|
||||
"API server internal error. See dashboard.log file for more details. "
|
||||
f"Error: {response['msg']}"
|
||||
)
|
||||
|
||||
if _explain:
|
||||
# Print warnings if anything was given.
|
||||
warning_msg = response["data"].get("partial_failure_warning", None)
|
||||
if warning_msg is not None:
|
||||
warnings.warn(warning_msg, RuntimeWarning)
|
||||
|
||||
return r.json()["data"]["result"]
|
||||
|
||||
|
||||
|
@ -51,11 +68,13 @@ def list_actors(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"actors",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -63,11 +82,13 @@ def list_placement_groups(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"placement_groups",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -75,11 +96,13 @@ def list_nodes(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"nodes",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -87,11 +110,13 @@ def list_jobs(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"jobs",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -99,11 +124,13 @@ def list_workers(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"workers",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -111,11 +138,13 @@ def list_tasks(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"tasks",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
|
@ -123,17 +152,25 @@ def list_objects(
|
|||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"objects",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
||||
|
||||
def list_runtime_envs(api_server_url: str = None, limit: int = 1000, timeout: int = 30):
|
||||
def list_runtime_envs(
|
||||
api_server_url: str = None,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
timeout: int = DEFAULT_RPC_TIMEOUT,
|
||||
_explain: bool = False,
|
||||
):
|
||||
return _list(
|
||||
"runtime_envs",
|
||||
ListApiOptions(limit=limit, timeout=timeout),
|
||||
api_server_url=api_server_url,
|
||||
_explain=_explain,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import List, Dict, Union
|
||||
|
||||
from ray.dashboard.modules.job.common import JobInfo
|
||||
|
||||
|
@ -23,11 +24,20 @@ def filter_fields(data: dict, state_dataclass) -> dict:
|
|||
class ListApiOptions:
|
||||
limit: int
|
||||
timeout: int
|
||||
# When the request is processed on the server side,
|
||||
# we should apply a multiplier so that the server side can finish
|
||||
# processing a request within timeout. Otherwise,
|
||||
# the timeout will always lead to an HTTP timeout.
|
||||
_server_timeout_multiplier: float = 0.8
|
||||
|
||||
# TODO(sang): Use Pydantic instead.
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.limit, int)
|
||||
assert isinstance(self.timeout, int)
|
||||
# To return the data to users, when there's a partial failure
|
||||
# we need to have a timeout that's smaller than the users' timeout.
|
||||
# 80% is configured arbitrarily.
|
||||
self.timeout = int(self.timeout * self._server_timeout_multiplier)
|
||||
|
||||
|
||||
# TODO(sang): Replace it with Pydantic or gRPC schema (once interface is finalized).
|
||||
|
@ -94,3 +104,30 @@ class RuntimeEnvState:
|
|||
error: str
|
||||
creation_time_ms: float
|
||||
node_id: str
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class ListApiResponse:
|
||||
# Returned data. None if no data is returned.
|
||||
result: Union[
|
||||
Dict[
|
||||
str,
|
||||
Union[
|
||||
ActorState,
|
||||
PlacementGroupState,
|
||||
NodeState,
|
||||
JobInfo,
|
||||
WorkerState,
|
||||
TaskState,
|
||||
ObjectState,
|
||||
],
|
||||
],
|
||||
List[RuntimeEnvState],
|
||||
] = None
|
||||
# List API can have a partial failure if queries to
|
||||
# some of the sources fail. For example, getting object states
|
||||
# requires pinging all raylets, and it is possible that some of
|
||||
# them fail. Note that it is impossible to guarantee high
|
||||
# availability of data because ray's state information is
|
||||
# not replicated.
|
||||
partial_failure_warning: str = ""
|
||||
|
|
12
python/ray/experimental/state/exception.py
Normal file
12
python/ray/experimental/state/exception.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
"""Internal Error"""
|
||||
|
||||
|
||||
class DataSourceUnavailable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
"""User-facing Error"""
|
||||
|
||||
|
||||
class RayStateApiException(Exception):
|
||||
pass
|
|
@ -59,6 +59,12 @@ def get_state_api_output_to_print(
|
|||
)
|
||||
|
||||
|
||||
def _should_explain(format: AvailableFormat):
|
||||
# If the format is json or yaml, it should not print stats because
|
||||
# users don't want additional strings.
|
||||
return format == AvailableFormat.DEFAULT or format == AvailableFormat.TABLE
|
||||
|
||||
|
||||
@click.group("list")
|
||||
@click.pass_context
|
||||
def list_state_cli_group(ctx):
|
||||
|
@ -71,6 +77,7 @@ def list_state_cli_group(ctx):
|
|||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||
num_retries=20,
|
||||
)
|
||||
|
||||
if api_server_url is None:
|
||||
raise ValueError(
|
||||
(
|
||||
|
@ -92,9 +99,11 @@ def list_state_cli_group(ctx):
|
|||
@click.pass_context
|
||||
def actors(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_actors(api_server_url=url), format=AvailableFormat(format)
|
||||
list_actors(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -106,10 +115,11 @@ def actors(ctx, format: str):
|
|||
@click.pass_context
|
||||
def placement_groups(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_placement_groups(api_server_url=url),
|
||||
format=AvailableFormat(format),
|
||||
list_placement_groups(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -121,9 +131,11 @@ def placement_groups(ctx, format: str):
|
|||
@click.pass_context
|
||||
def nodes(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_nodes(api_server_url=url), format=AvailableFormat(format)
|
||||
list_nodes(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -135,9 +147,11 @@ def nodes(ctx, format: str):
|
|||
@click.pass_context
|
||||
def jobs(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_jobs(api_server_url=url), format=AvailableFormat(format)
|
||||
list_jobs(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -149,9 +163,11 @@ def jobs(ctx, format: str):
|
|||
@click.pass_context
|
||||
def workers(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_workers(api_server_url=url), format=AvailableFormat(format)
|
||||
list_workers(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -163,9 +179,11 @@ def workers(ctx, format: str):
|
|||
@click.pass_context
|
||||
def tasks(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_tasks(api_server_url=url), format=AvailableFormat(format)
|
||||
list_tasks(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -177,9 +195,11 @@ def tasks(ctx, format: str):
|
|||
@click.pass_context
|
||||
def objects(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_objects(api_server_url=url), format=AvailableFormat(format)
|
||||
list_objects(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -191,9 +211,10 @@ def objects(ctx, format: str):
|
|||
@click.pass_context
|
||||
def runtime_envs(ctx, format: str):
|
||||
url = ctx.obj["api_server_url"]
|
||||
format = AvailableFormat(format)
|
||||
print(
|
||||
get_state_api_output_to_print(
|
||||
list_runtime_envs(api_server_url=url),
|
||||
format=AvailableFormat(format),
|
||||
list_runtime_envs(api_server_url=url, _explain=_should_explain(format)),
|
||||
format=format,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@ from functools import wraps
|
|||
import grpc
|
||||
import ray
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
from ray import ray_constants
|
||||
|
||||
from ray.core.generated.gcs_service_pb2 import (
|
||||
|
@ -33,19 +33,13 @@ from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub
|
|||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub
|
||||
from ray.dashboard.modules.job.common import JobInfoStorageClient, JobInfo
|
||||
from ray.experimental.state.exception import DataSourceUnavailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateSourceNetworkException(Exception):
|
||||
"""Exceptions raised when there's a network error from data source query."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def handle_network_errors(func):
|
||||
"""Apply the network error handling logic to each APIs,
|
||||
such as retry or exception policies.
|
||||
def handle_grpc_network_errors(func):
|
||||
"""Decorator to add a network handling logic.
|
||||
|
||||
It is a helper method for `StateDataSourceClient`.
|
||||
The method can only be used for async methods.
|
||||
|
@ -54,20 +48,32 @@ def handle_network_errors(func):
|
|||
|
||||
@wraps(func)
async def api_with_network_error_handler(*args, **kwargs):
    """Await the wrapped coroutine and translate gRPC availability errors.

    Returns:
        Whatever the wrapped function returns when the RPC succeeds.

    Raises:
        DataSourceUnavailable: when the RPC fails with the gRPC status
            ``DEADLINE_EXCEEDED`` or ``UNAVAILABLE``.
        grpc.aio.AioRpcError: for every other gRPC status; the error is
            logged and then re-raised as is.
    """
    # TODO(sang): Add a retry policy.
    try:
        return await func(*args, **kwargs)
    except grpc.aio.AioRpcError as e:
        status = e.code()
        source_unreachable = (
            status == grpc.StatusCode.DEADLINE_EXCEEDED
            or status == grpc.StatusCode.UNAVAILABLE
        )
        if not source_unreachable:
            # Unexpected status: record it and let the raw error propagate.
            logger.exception(e)
            raise e
        raise DataSourceUnavailable(
            "Failed to query the data source. "
            "It is either there's a network issue, or the source is down."
        )
|
||||
|
||||
return api_with_network_error_handler
|
||||
|
||||
|
@ -81,8 +87,9 @@ class StateDataSourceClient:
|
|||
finding services and register stubs through `register*` APIs.
|
||||
|
||||
Non `register*` APIs
|
||||
- Return the protobuf directly if it succeeds to query the source.
|
||||
- Raises an exception if there's any network issue.
|
||||
- throw a ValueError if it cannot find the source.
|
||||
- throw `StateSourceNetworkException` if there's any network errors.
|
||||
"""
|
||||
|
||||
def __init__(self, gcs_channel: grpc.aio.Channel):
|
||||
|
@ -132,50 +139,66 @@ class StateDataSourceClient:
|
|||
def get_all_registered_agent_ids(self) -> List[str]:
|
||||
return self._agent_stubs.keys()
|
||||
|
||||
@handle_grpc_network_errors
async def get_all_actor_info(
    self, timeout: int = None
) -> Optional[GetAllActorInfoReply]:
    """Issue a ``GetAllActorInfo`` RPC through the GCS actor-info stub.

    Args:
        timeout: Passed through unchanged as the RPC's ``timeout`` keyword.

    Returns:
        The reply object produced by the RPC.
    """
    actor_query = GetAllActorInfoRequest()
    return await self._gcs_actor_info_stub.GetAllActorInfo(
        actor_query, timeout=timeout
    )
|
||||
|
||||
@handle_grpc_network_errors
async def get_all_placement_group_info(
    self, timeout: int = None
) -> Optional[GetAllPlacementGroupReply]:
    """Issue a ``GetAllPlacementGroup`` RPC through the GCS placement-group stub.

    Args:
        timeout: Passed through unchanged as the RPC's ``timeout`` keyword.

    Returns:
        The reply object produced by the RPC.
    """
    pg_query = GetAllPlacementGroupRequest()
    return await self._gcs_pg_info_stub.GetAllPlacementGroup(
        pg_query, timeout=timeout
    )
|
||||
|
||||
@handle_grpc_network_errors
async def get_all_node_info(
    self, timeout: int = None
) -> Optional[GetAllNodeInfoReply]:
    """Issue a ``GetAllNodeInfo`` RPC through the GCS node-info stub.

    Args:
        timeout: Passed through unchanged as the RPC's ``timeout`` keyword.

    Returns:
        The reply object produced by the RPC.
    """
    node_query = GetAllNodeInfoRequest()
    return await self._gcs_node_info_stub.GetAllNodeInfo(
        node_query, timeout=timeout
    )
|
||||
|
||||
@handle_grpc_network_errors
async def get_all_worker_info(
    self, timeout: int = None
) -> Optional[GetAllWorkerInfoReply]:
    """Issue a ``GetAllWorkerInfo`` RPC through the GCS worker-info stub.

    Args:
        timeout: Passed through unchanged as the RPC's ``timeout`` keyword.

    Returns:
        The reply object produced by the RPC.
    """
    worker_query = GetAllWorkerInfoRequest()
    return await self._gcs_worker_info_stub.GetAllWorkerInfo(
        worker_query, timeout=timeout
    )
|
||||
|
||||
def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
    """Return all jobs reported by the job client.

    This is a synchronous call to ``self._job_client.get_all_jobs()``.

    Returns:
        Whatever ``get_all_jobs()`` returns.

    Raises:
        DataSourceUnavailable: if the call fails with the gRPC status
            ``DEADLINE_EXCEEDED`` or ``UNAVAILABLE``.
        grpc.aio.AioRpcError: for any other gRPC status; the error is
            logged and re-raised.
    """
    # Cannot use @handle_grpc_network_errors because async def is not supported yet.
    # TODO(sang): Support timeout & make it async
    try:
        return self._job_client.get_all_jobs()
    except grpc.aio.AioRpcError as e:
        # `code` is a method: it has to be called. Comparing the bound method
        # itself against a StatusCode is always False, which would make the
        # DataSourceUnavailable branch unreachable.
        if (
            e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
            or e.code() == grpc.StatusCode.UNAVAILABLE
        ):
            raise DataSourceUnavailable(
                "Failed to query the data source. "
                "It is either there's a network issue, or the source is down."
            )
        else:
            logger.exception(e)
            raise e
|
||||
|
||||
@handle_network_errors
|
||||
@handle_grpc_network_errors
|
||||
async def get_task_info(
|
||||
self, node_id: str, timeout: int = None
|
||||
) -> GetTasksInfoReply:
|
||||
) -> Optional[GetTasksInfoReply]:
|
||||
stub = self._raylet_stubs.get(node_id)
|
||||
if not stub:
|
||||
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
|
||||
|
@ -183,10 +206,10 @@ class StateDataSourceClient:
|
|||
reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout)
|
||||
return reply
|
||||
|
||||
@handle_network_errors
|
||||
@handle_grpc_network_errors
|
||||
async def get_object_info(
|
||||
self, node_id: str, timeout: int = None
|
||||
) -> GetNodeStatsReply:
|
||||
) -> Optional[GetNodeStatsReply]:
|
||||
stub = self._raylet_stubs.get(node_id)
|
||||
if not stub:
|
||||
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
|
||||
|
@ -197,10 +220,10 @@ class StateDataSourceClient:
|
|||
)
|
||||
return reply
|
||||
|
||||
@handle_network_errors
|
||||
@handle_grpc_network_errors
|
||||
async def get_runtime_envs_info(
|
||||
self, node_id: str, timeout: int = None
|
||||
) -> GetRuntimeEnvsInfoReply:
|
||||
) -> Optional[GetRuntimeEnvsInfoReply]:
|
||||
stub = self._agent_stubs.get(node_id)
|
||||
if not stub:
|
||||
raise ValueError(f"Agent for a node id, {node_id} doesn't exist.")
|
||||
|
|
|
@ -8,7 +8,10 @@ from dataclasses import fields
|
|||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# `unittest.mock.AsyncMock` was added in Python 3.8. Comparing against
# (3, 7, 0) with `>` is also true for every 3.7.x interpreter (the longer
# version tuple sorts higher), which would make this import fail there.
# Older interpreters fall back to the `asyncmock` backport.
if sys.version_info >= (3, 8):
    from unittest.mock import AsyncMock
else:
    from asyncmock import AsyncMock
|
||||
|
||||
import ray
|
||||
import ray.ray_constants as ray_constants
|
||||
|
@ -41,7 +44,11 @@ from ray.core.generated.runtime_env_common_pb2 import (
|
|||
)
|
||||
from ray.core.generated.runtime_env_agent_pb2 import GetRuntimeEnvsInfoReply
|
||||
import ray.dashboard.consts as dashboard_consts
|
||||
from ray.dashboard.state_aggregator import StateAPIManager
|
||||
from ray.dashboard.state_aggregator import (
|
||||
StateAPIManager,
|
||||
GCS_QUERY_FAILURE_WARNING,
|
||||
NODE_QUERY_FAILURE_WARNING,
|
||||
)
|
||||
from ray.experimental.state.api import (
|
||||
list_actors,
|
||||
list_placement_groups,
|
||||
|
@ -64,9 +71,9 @@ from ray.experimental.state.common import (
|
|||
DEFAULT_RPC_TIMEOUT,
|
||||
DEFAULT_LIMIT,
|
||||
)
|
||||
from ray.experimental.state.exception import DataSourceUnavailable, RayStateApiException
|
||||
from ray.experimental.state.state_manager import (
|
||||
StateDataSourceClient,
|
||||
StateSourceNetworkException,
|
||||
)
|
||||
from ray.experimental.state.state_cli import (
|
||||
list_state_cli_group,
|
||||
|
@ -188,7 +195,7 @@ def generate_runtime_env_info(runtime_env, creation_time=None):
|
|||
|
||||
|
||||
def list_api_options(timeout: int = DEFAULT_RPC_TIMEOUT, limit: int = DEFAULT_LIMIT):
    """Build the ``ListApiOptions`` used by the tests in this module.

    Args:
        timeout: Value for the option's ``timeout`` field.
        limit: Value for the option's ``limit`` field.

    Returns:
        A ``ListApiOptions`` with ``_server_timeout_multiplier`` fixed at 1.0.
    """
    # NOTE(review): a multiplier of 1.0 presumably keeps the server-side
    # timeout equal to `timeout` -- confirm against ListApiOptions.
    option_fields = {
        "limit": limit,
        "timeout": timeout,
        "_server_timeout_multiplier": 1.0,
    }
    return ListApiOptions(**option_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -199,15 +206,25 @@ async def test_api_manager_list_actors(state_api_manager):
|
|||
actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")]
|
||||
)
|
||||
result = await state_api_manager.list_actors(option=list_api_options())
|
||||
actor_data = list(result.values())[0]
|
||||
data = result.result
|
||||
actor_data = list(data.values())[0]
|
||||
verify_schema(ActorState, actor_data)
|
||||
|
||||
"""
|
||||
Test limit
|
||||
"""
|
||||
assert len(result) == 2
|
||||
assert len(data) == 2
|
||||
result = await state_api_manager.list_actors(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_all_actor_info.side_effect = DataSourceUnavailable()
|
||||
with pytest.raises(DataSourceUnavailable) as exc_info:
|
||||
result = await state_api_manager.list_actors(option=list_api_options(limit=1))
|
||||
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -223,17 +240,31 @@ async def test_api_manager_list_pgs(state_api_manager):
|
|||
)
|
||||
)
|
||||
result = await state_api_manager.list_placement_groups(option=list_api_options())
|
||||
data = list(result.values())[0]
|
||||
data = result.result
|
||||
data = list(data.values())[0]
|
||||
verify_schema(PlacementGroupState, data)
|
||||
|
||||
"""
|
||||
Test limit
|
||||
"""
|
||||
assert len(result) == 2
|
||||
assert len(data) == 2
|
||||
result = await state_api_manager.list_placement_groups(
|
||||
option=list_api_options(limit=1)
|
||||
)
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_all_placement_group_info.side_effect = (
|
||||
DataSourceUnavailable()
|
||||
)
|
||||
with pytest.raises(DataSourceUnavailable) as exc_info:
|
||||
result = await state_api_manager.list_placement_groups(
|
||||
option=list_api_options(limit=1)
|
||||
)
|
||||
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -244,15 +275,25 @@ async def test_api_manager_list_nodes(state_api_manager):
|
|||
node_info_list=[generate_node_data(id), generate_node_data(b"12345")]
|
||||
)
|
||||
result = await state_api_manager.list_nodes(option=list_api_options())
|
||||
data = list(result.values())[0]
|
||||
data = result.result
|
||||
data = list(data.values())[0]
|
||||
verify_schema(NodeState, data)
|
||||
|
||||
"""
|
||||
Test limit
|
||||
"""
|
||||
assert len(result) == 2
|
||||
assert len(data) == 2
|
||||
result = await state_api_manager.list_nodes(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_all_node_info.side_effect = DataSourceUnavailable()
|
||||
with pytest.raises(DataSourceUnavailable) as exc_info:
|
||||
result = await state_api_manager.list_nodes(option=list_api_options(limit=1))
|
||||
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -266,19 +307,30 @@ async def test_api_manager_list_workers(state_api_manager):
|
|||
]
|
||||
)
|
||||
result = await state_api_manager.list_workers(option=list_api_options())
|
||||
data = list(result.values())[0]
|
||||
data = result.result
|
||||
data = list(data.values())[0]
|
||||
verify_schema(WorkerState, data)
|
||||
|
||||
"""
|
||||
Test limit
|
||||
"""
|
||||
assert len(result) == 2
|
||||
assert len(result.result) == 2
|
||||
result = await state_api_manager.list_workers(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_all_worker_info.side_effect = DataSourceUnavailable()
|
||||
with pytest.raises(DataSourceUnavailable) as exc_info:
|
||||
result = await state_api_manager.list_workers(option=list_api_options(limit=1))
|
||||
assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=("Not passing in CI although it works locally. Will handle it later.")
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info <= (3, 7, 0),
|
||||
reason=("Not passing in CI although it works locally. Will handle it later."),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_manager_list_tasks(state_api_manager):
|
||||
|
@ -288,17 +340,19 @@ async def test_api_manager_list_tasks(state_api_manager):
|
|||
|
||||
first_task_name = "1"
|
||||
second_task_name = "2"
|
||||
data_source_client.get_task_info = AsyncMock()
|
||||
data_source_client.get_task_info.side_effect = [
|
||||
generate_task_data(b"1234", first_task_name),
|
||||
generate_task_data(b"2345", second_task_name),
|
||||
]
|
||||
result = await state_api_manager.list_tasks(option=list_api_options())
|
||||
data_source_client.get_task_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
data_source_client.get_task_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
result = list(result.values())
|
||||
assert len(result) == 2
|
||||
verify_schema(TaskState, result[0])
|
||||
verify_schema(TaskState, result[1])
|
||||
data_source_client.get_task_info.assert_any_await("1", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
data_source_client.get_task_info.assert_any_await("2", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
data = result.result
|
||||
data = list(data.values())
|
||||
assert len(data) == 2
|
||||
verify_schema(TaskState, data[0])
|
||||
verify_schema(TaskState, data[1])
|
||||
|
||||
"""
|
||||
Test limit
|
||||
|
@ -308,11 +362,38 @@ async def test_api_manager_list_tasks(state_api_manager):
|
|||
generate_task_data(b"2345", second_task_name),
|
||||
]
|
||||
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_task_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
generate_task_data(b"2345", second_task_name),
|
||||
]
|
||||
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
|
||||
# Make sure warnings are returned.
|
||||
warning = result.partial_failure_warning
|
||||
assert (
|
||||
NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="raylet", total=2, network_failures=1, log_command="raylet.out"
|
||||
)
|
||||
in warning
|
||||
)
|
||||
|
||||
# Test if all RPCs fail, it will raise an exception.
|
||||
data_source_client.get_task_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
DataSourceUnavailable(),
|
||||
]
|
||||
with pytest.raises(DataSourceUnavailable):
|
||||
result = await state_api_manager.list_tasks(option=list_api_options(limit=1))
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=("Not passing in CI although it works locally. Will handle it later.")
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info <= (3, 7, 0),
|
||||
reason=("Not passing in CI although it works locally. Will handle it later."),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_manager_list_objects(state_api_manager):
|
||||
|
@ -322,17 +403,23 @@ async def test_api_manager_list_objects(state_api_manager):
|
|||
data_source_client.get_all_registered_raylet_ids = MagicMock()
|
||||
data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]
|
||||
|
||||
data_source_client.get_object_info = AsyncMock()
|
||||
data_source_client.get_object_info.side_effect = [
|
||||
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_1_id)]),
|
||||
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
|
||||
]
|
||||
result = await state_api_manager.list_objects(option=list_api_options())
|
||||
data_source_client.get_object_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
data_source_client.get_object_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
|
||||
result = list(result.values())
|
||||
assert len(result) == 2
|
||||
verify_schema(ObjectState, result[0])
|
||||
verify_schema(ObjectState, result[1])
|
||||
data = result.result
|
||||
data_source_client.get_object_info.assert_any_await(
|
||||
"1", timeout=DEFAULT_RPC_TIMEOUT
|
||||
)
|
||||
data_source_client.get_object_info.assert_any_await(
|
||||
"2", timeout=DEFAULT_RPC_TIMEOUT
|
||||
)
|
||||
data = list(data.values())
|
||||
assert len(data) == 2
|
||||
verify_schema(ObjectState, data[0])
|
||||
verify_schema(ObjectState, data[1])
|
||||
|
||||
"""
|
||||
Test limit
|
||||
|
@ -342,11 +429,38 @@ async def test_api_manager_list_objects(state_api_manager):
|
|||
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
|
||||
]
|
||||
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_object_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]),
|
||||
]
|
||||
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
|
||||
# Make sure warnings are returned.
|
||||
warning = result.partial_failure_warning
|
||||
assert (
|
||||
NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="raylet", total=2, network_failures=1, log_command="raylet.out"
|
||||
)
|
||||
in warning
|
||||
)
|
||||
|
||||
# Test if all RPCs fail, it will raise an exception.
|
||||
data_source_client.get_object_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
DataSourceUnavailable(),
|
||||
]
|
||||
with pytest.raises(DataSourceUnavailable):
|
||||
result = await state_api_manager.list_objects(option=list_api_options(limit=1))
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=("Not passing in CI although it works locally. Will handle it later.")
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info <= (3, 7, 0),
|
||||
reason=("Not passing in CI although it works locally. Will handle it later."),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_manager_list_runtime_envs(state_api_manager):
|
||||
|
@ -354,6 +468,7 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
|
|||
data_source_client.get_all_registered_agent_ids = MagicMock()
|
||||
data_source_client.get_all_registered_agent_ids.return_value = ["1", "2", "3"]
|
||||
|
||||
data_source_client.get_runtime_envs_info = AsyncMock()
|
||||
data_source_client.get_runtime_envs_info.side_effect = [
|
||||
generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
|
||||
generate_runtime_env_info(
|
||||
|
@ -362,23 +477,25 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
|
|||
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}), creation_time=10),
|
||||
]
|
||||
result = await state_api_manager.list_runtime_envs(option=list_api_options())
|
||||
data_source_client.get_runtime_envs_info.assert_any_call(
|
||||
data = result.result
|
||||
data_source_client.get_runtime_envs_info.assert_any_await(
|
||||
"1", timeout=DEFAULT_RPC_TIMEOUT
|
||||
)
|
||||
data_source_client.get_runtime_envs_info.assert_any_call(
|
||||
data_source_client.get_runtime_envs_info.assert_any_await(
|
||||
"2", timeout=DEFAULT_RPC_TIMEOUT
|
||||
)
|
||||
data_source_client.get_runtime_envs_info.assert_any_call(
|
||||
|
||||
data_source_client.get_runtime_envs_info.assert_any_await(
|
||||
"3", timeout=DEFAULT_RPC_TIMEOUT
|
||||
)
|
||||
assert len(result) == 3
|
||||
verify_schema(RuntimeEnvState, result[0])
|
||||
verify_schema(RuntimeEnvState, result[1])
|
||||
verify_schema(RuntimeEnvState, result[2])
|
||||
assert len(data) == 3
|
||||
verify_schema(RuntimeEnvState, data[0])
|
||||
verify_schema(RuntimeEnvState, data[1])
|
||||
verify_schema(RuntimeEnvState, data[2])
|
||||
|
||||
# Make sure the higher creation time is sorted first.
|
||||
assert "creation_time_ms" not in result[0]
|
||||
result[1]["creation_time_ms"] > result[2]["creation_time_ms"]
|
||||
assert "creation_time_ms" not in data[0]
|
||||
assert data[1]["creation_time_ms"] > data[2]["creation_time_ms"]
|
||||
|
||||
"""
|
||||
Test limit
|
||||
|
@ -391,7 +508,38 @@ async def test_api_manager_list_runtime_envs(state_api_manager):
|
|||
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
|
||||
]
|
||||
result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1))
|
||||
assert len(result) == 1
|
||||
data = result.result
|
||||
assert len(data) == 1
|
||||
|
||||
"""
|
||||
Test error handling
|
||||
"""
|
||||
data_source_client.get_runtime_envs_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
|
||||
generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
|
||||
]
|
||||
result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1))
|
||||
# Make sure warnings are returned.
|
||||
warning = result.partial_failure_warning
|
||||
print(warning)
|
||||
assert (
|
||||
NODE_QUERY_FAILURE_WARNING.format(
|
||||
type="agent", total=3, network_failures=1, log_command="dashboard_agent.log"
|
||||
)
|
||||
in warning
|
||||
)
|
||||
|
||||
# Test if all RPCs fail, it will raise an exception.
|
||||
data_source_client.get_runtime_envs_info.side_effect = [
|
||||
DataSourceUnavailable(),
|
||||
DataSourceUnavailable(),
|
||||
DataSourceUnavailable(),
|
||||
]
|
||||
with pytest.raises(DataSourceUnavailable):
|
||||
result = await state_api_manager.list_runtime_envs(
|
||||
option=list_api_options(limit=1)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
@ -495,7 +643,7 @@ async def test_state_data_source_client(ray_start_cluster):
|
|||
"""
|
||||
with pytest.raises(ValueError):
|
||||
# Since we didn't register this node id, it should raise an exception.
|
||||
result = await client.get_object_info("1234")
|
||||
result = await client.get_runtime_envs_info("1234")
|
||||
wait_for_condition(lambda: len(ray.nodes()) == 2)
|
||||
for node in ray.nodes():
|
||||
node_id = node["NodeID"]
|
||||
|
@ -527,10 +675,9 @@ async def test_state_data_source_client(ray_start_cluster):
|
|||
if node["Alive"]:
|
||||
continue
|
||||
|
||||
# Querying to the dead node raises gRPC error, which should be
|
||||
# translated into `StateSourceNetworkException`
|
||||
with pytest.raises(StateSourceNetworkException):
|
||||
result = await client.get_object_info(node_id)
|
||||
# Querying to the dead node raises gRPC error, which should raise an exception.
|
||||
with pytest.raises(DataSourceUnavailable):
|
||||
await client.get_object_info(node_id)
|
||||
|
||||
# Make sure unregister API works as expected.
|
||||
client.unregister_raylet_client(node_id)
|
||||
|
@ -685,6 +832,7 @@ def test_list_jobs(shutdown_only):
|
|||
|
||||
def verify():
|
||||
job_data = list(list_jobs().values())[0]
|
||||
print(job_data)
|
||||
job_id_from_api = list(list_jobs().keys())[0]
|
||||
correct_state = job_data["status"] == "SUCCEEDED"
|
||||
correct_id = job_id == job_id_from_api
|
||||
|
@ -914,6 +1062,91 @@ def test_limit(shutdown_only):
|
|||
assert output == list_actors(limit=2)
|
||||
|
||||
|
||||
def test_network_failure(shutdown_only):
    """Check that ``list_tasks`` raises once the raylet cannot be queried.

    Four sleeping tasks are submitted and awaited in the task listing, the
    local raylet is killed, and the following ``list_tasks`` call is expected
    to raise ``RayStateApiException``.
    """
    ray.init()

    @ray.remote
    def f():
        import time

        time.sleep(30)

    # Bound to a name but otherwise unused, hence the noqa.
    pending = [f.remote() for _ in range(4)]  # noqa

    def all_tasks_listed():
        return len(list_tasks()) == 4

    wait_for_condition(all_tasks_listed)

    # Kill raylet so that list_tasks will have network error on querying raylets.
    ray.worker._global_node.kill_raylet()

    with pytest.raises(RayStateApiException):
        list_tasks(_explain=True)
|
||||
|
||||
|
||||
def test_network_partial_failures(ray_start_cluster):
    """Check the warning behaviour of ``list_tasks`` on a partial failure.

    Three phases:
      1. With both nodes alive, ``list_tasks(_explain=True)`` emits no warning.
      2. After one node is removed non-gracefully, ``list_tasks(_explain=True)``
         emits a ``RuntimeWarning``.
      3. With ``_explain=False`` no warning is emitted.
    """
    # Function-scope import: the stdlib recorder replaces `pytest.warns(None)`,
    # which is deprecated in pytest 7 and rejected by pytest 8.
    import warnings

    cluster = ray_start_cluster
    cluster.add_node(num_cpus=2)
    ray.init(address=cluster.address)
    n = cluster.add_node(num_cpus=2)

    @ray.remote
    def f():
        import time

        time.sleep(30)

    a = [f.remote() for _ in range(4)]  # noqa
    wait_for_condition(lambda: len(list_tasks()) == 4)

    # Make sure when there's 0 node failure, it doesn't print the error.
    with warnings.catch_warnings(record=True) as record:
        # Record every warning, even ones already shown once in this process.
        warnings.simplefilter("always")
        list_tasks(_explain=True)
    assert len(record) == 0

    # Kill raylet so that list_tasks will have network error on querying raylets.
    cluster.remove_node(n, allow_graceful=False)

    with pytest.warns(RuntimeWarning):
        list_tasks(_explain=True)

    # Make sure when _explain == False, warning is not printed.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        list_tasks(_explain=False)
    assert len(record) == 0
|
||||
|
||||
|
||||
def test_network_partial_failures_timeout(monkeypatch, ray_start_cluster):
    """Check that a raylet timeout yields exactly one warning from ``list_tasks``.

    A second node is started while ``RAY_testing_asio_delay_us`` is set for
    ``GetTasksInfo``, then ``list_tasks(_explain=True, timeout=5)`` is polled
    until a call records exactly one warning.
    """
    # Function-scope import: the stdlib recorder replaces `pytest.warns(None)`,
    # which is deprecated in pytest 7 and rejected by pytest 8.
    import warnings

    cluster = ray_start_cluster
    cluster.add_node(num_cpus=2)
    ray.init(address=cluster.address)
    # NOTE(review): the original indentation is not visible here; the node is
    # assumed to be the only statement started inside the env-var context.
    with monkeypatch.context() as m:
        # defer for 10s for the second node.
        m.setenv(
            "RAY_testing_asio_delay_us",
            "NodeManagerService.grpc_server.GetTasksInfo=10000000:10000000",
        )
        cluster.add_node(num_cpus=2)

    @ray.remote
    def f():
        import time

        time.sleep(30)

    a = [f.remote() for _ in range(4)]  # noqa

    def verify():
        with warnings.catch_warnings(record=True) as record:
            # Record every warning, even ones already shown once in this process.
            warnings.simplefilter("always")
            list_tasks(_explain=True, timeout=5)
        return len(record) == 1

    wait_for_condition(verify)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_format_print(state_api_manager):
|
||||
data_source_client = state_api_manager.data_source_client
|
||||
|
@ -922,6 +1155,7 @@ async def test_cli_format_print(state_api_manager):
|
|||
actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")]
|
||||
)
|
||||
result = await state_api_manager.list_actors(option=list_api_options())
|
||||
result = result.result
|
||||
# If the format is not yaml, it will raise an exception.
|
||||
yaml.load(
|
||||
get_state_api_output_to_print(result, format=AvailableFormat.YAML),
|
||||
|
|
Loading…
Add table
Reference in a new issue