From a7e759317b93cb95b8cbf990343d2a6c6f20b781 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 24 May 2022 19:56:49 +0900 Subject: [PATCH] [State Observability API] Error handling (#24413) This improves error handling per https://docs.google.com/document/d/1IeEsJOiurg-zctOcBjY-tQVbsCmURFSnUCTkx_4a7Cw/edit#heading=h.pdzl9cil9e8z (the RPC part). Semantics: If all queries to the source fail, raise a RayStateApiException. If only some of the queries fail, report the partial failure with warnings.warn when print_api_stats=True. It is True for the CLI. It is False when the API is used from Python or when json / yaml output is requested. --- dashboard/modules/state/state_head.py | 72 ++-- dashboard/state_aggregator.py | 196 ++++++++-- python/ray/experimental/state/api.py | 49 ++- python/ray/experimental/state/common.py | 37 ++ python/ray/experimental/state/exception.py | 12 + python/ray/experimental/state/state_cli.py | 41 ++- .../ray/experimental/state/state_manager.py | 103 +++--- python/ray/tests/test_state_api.py | 336 +++++++++++++++--- 8 files changed, 677 insertions(+), 169 deletions(-) create mode 100644 python/ray/experimental/state/exception.py diff --git a/dashboard/modules/state/state_head.py b/dashboard/modules/state/state_head.py index 66af66641..130375039 100644 --- a/dashboard/modules/state/state_head.py +++ b/dashboard/modules/state/state_head.py @@ -1,4 +1,5 @@ import logging + import aiohttp.web import dataclasses @@ -10,6 +11,7 @@ import ray.dashboard.optional_utils as dashboard_optional_utils from ray.dashboard.optional_utils import rest_response from ray.dashboard.state_aggregator import StateAPIManager from ray.experimental.state.common import ListApiOptions +from ray.experimental.state.exception import DataSourceUnavailable from ray.experimental.state.state_manager import StateDataSourceClient logger = logging.getLogger(__name__) @@ -33,13 +35,19 @@ class StateHead(dashboard_utils.DashboardHeadModule): def _options_from_req(self, req) -> ListApiOptions: """Obtain 
`ListApiOptions` from the aiohttp request.""" limit = int(req.query.get("limit")) + # Only apply 80% of the timeout so that + # the API will reply before client times out if query to the source fails. timeout = int(req.query.get("timeout")) return ListApiOptions(limit=limit, timeout=timeout) - def _reply(self, success: bool, message: str, result: dict): + def _reply(self, success: bool, error_message: str, result: dict, **kwargs): """Reply to the client.""" return rest_response( - success=success, message=message, result=result, convert_google_style=False + success=success, + message=error_message, + result=result, + convert_google_style=False, + **kwargs, ) async def _update_raylet_stubs(self, change: Change): @@ -85,57 +93,61 @@ class StateHead(dashboard_utils.DashboardHeadModule): int(ports[1]), ) + async def _handle_list_api(self, list_api_fn, req): + try: + result = await list_api_fn(option=self._options_from_req(req)) + return self._reply( + success=True, + error_message="", + result=result.result, + partial_failure_warning=result.partial_failure_warning, + ) + except DataSourceUnavailable as e: + return self._reply(success=False, error_message=str(e), result=None) + @routes.get("/api/v0/actors") async def list_actors(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_actors(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_actors, req) @routes.get("/api/v0/jobs") async def list_jobs(self, req) -> aiohttp.web.Response: - data = self._state_api.list_jobs(option=self._options_from_req(req)) - return self._reply( - success=True, - message="", - result={ - job_id: dataclasses.asdict(job_info) - for job_id, job_info in data.items() - }, - ) + try: + result = self._state_api.list_jobs(option=self._options_from_req(req)) + return self._reply( + success=True, + error_message="", + result={ + job_id: dataclasses.asdict(job_info) + for job_id, 
job_info in result.result.items() + }, + partial_failure_warning=result.partial_failure_warning, + ) + except DataSourceUnavailable as e: + return self._reply(success=False, error_message=str(e), result=None) @routes.get("/api/v0/nodes") async def list_nodes(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_nodes(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_nodes, req) @routes.get("/api/v0/placement_groups") async def list_placement_groups(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_placement_groups( - option=self._options_from_req(req) - ) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_placement_groups, req) @routes.get("/api/v0/workers") async def list_workers(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_workers(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_workers, req) @routes.get("/api/v0/tasks") async def list_tasks(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_tasks(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_tasks, req) @routes.get("/api/v0/objects") async def list_objects(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_objects(option=self._options_from_req(req)) - return self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_objects, req) @routes.get("/api/v0/runtime_envs") - @dashboard_optional_utils.aiohttp_cache async def list_runtime_envs(self, req) -> aiohttp.web.Response: - data = await self._state_api.list_runtime_envs( - option=self._options_from_req(req) - ) - return 
self._reply(success=True, message="", result=data) + return await self._handle_list_api(self._state_api.list_runtime_envs, req) async def run(self, server): gcs_channel = self._dashboard_head.aiogrpc_gcs_channel diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index ec4d751ee..65bc16130 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -1,13 +1,12 @@ import asyncio import logging -from typing import List, Dict from itertools import islice +from typing import List from ray.core.generated.common_pb2 import TaskStatus import ray.dashboard.utils as dashboard_utils import ray.dashboard.memory_utils as memory_utils -from ray.dashboard.modules.job.common import JobInfo from ray.experimental.state.common import ( filter_fields, @@ -19,13 +18,34 @@ from ray.experimental.state.common import ( ObjectState, RuntimeEnvState, ListApiOptions, + ListApiResponse, +) +from ray.experimental.state.state_manager import ( + StateDataSourceClient, + DataSourceUnavailable, ) -from ray.experimental.state.state_manager import StateDataSourceClient from ray.runtime_env import RuntimeEnv from ray._private.utils import binary_to_hex logger = logging.getLogger(__name__) +GCS_QUERY_FAILURE_WARNING = ( + "Failed to query data from GCS. It is due to " + "(1) GCS is unexpectedly failed. " + "(2) GCS is overloaded. " + "(3) There's an unexpected network issue. " + "Please check the gcs_server.out log to find the root cause." +) +NODE_QUERY_FAILURE_WARNING = ( + "Failed to query data from {type}. " + "Queryed {total} {type} " + "and {network_failures} {type} failed to reply. It is due to " + "(1) {type} is unexpectedly failed. " + "(2) {type} is overloaded. " + "(3) There's an unexpected network issue. Please check the " + "{log_command} to find the root cause." +) + # TODO(sang): Move the class to state/state_manager.py. # TODO(sang): Remove *State and replaces with Pydantic or protobuf. 
@@ -42,14 +62,19 @@ class StateAPIManager: def data_source_client(self): return self._client - async def list_actors(self, *, option: ListApiOptions) -> dict: + async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: """List all actor information from the cluster. Returns: {actor_id -> actor_data_in_dict} actor_data_in_dict's schema is in ActorState + """ - reply = await self._client.get_all_actor_info(timeout=option.timeout) + try: + reply = await self._client.get_all_actor_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.actor_table_data: data = self._message_to_dict(message=message, fields_to_decode=["actor_id"]) @@ -58,16 +83,24 @@ class StateAPIManager: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["actor_id"]) - return {d["actor_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["actor_id"]: d for d in islice(result, option.limit)} + ) - async def list_placement_groups(self, *, option: ListApiOptions) -> dict: + async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: """List all placement group information from the cluster. Returns: {pg_id -> pg_data_in_dict} pg_data_in_dict's schema is in PlacementGroupState """ - reply = await self._client.get_all_placement_group_info(timeout=option.timeout) + try: + reply = await self._client.get_all_placement_group_info( + timeout=option.timeout + ) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.placement_group_table_data: @@ -80,16 +113,22 @@ class StateAPIManager: # Sort to make the output deterministic. 
result.sort(key=lambda entry: entry["placement_group_id"]) - return {d["placement_group_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["placement_group_id"]: d for d in islice(result, option.limit)} + ) - async def list_nodes(self, *, option: ListApiOptions) -> dict: + async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: """List all node information from the cluster. Returns: {node_id -> node_data_in_dict} node_data_in_dict's schema is in NodeState """ - reply = await self._client.get_all_node_info(timeout=option.timeout) + try: + reply = await self._client.get_all_node_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.node_info_list: data = self._message_to_dict(message=message, fields_to_decode=["node_id"]) @@ -98,16 +137,22 @@ class StateAPIManager: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["node_id"]) - return {d["node_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["node_id"]: d for d in islice(result, option.limit)} + ) - async def list_workers(self, *, option: ListApiOptions) -> dict: + async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: """List all worker information from the cluster. Returns: {worker_id -> worker_data_in_dict} worker_data_in_dict's schema is in WorkerState """ - reply = await self._client.get_all_worker_info(timeout=option.timeout) + try: + reply = await self._client.get_all_worker_info(timeout=option.timeout) + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + result = [] for message in reply.worker_table_data: data = self._message_to_dict( @@ -119,34 +164,65 @@ class StateAPIManager: # Sort to make the output deterministic. 
result.sort(key=lambda entry: entry["worker_id"]) - return {d["worker_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["worker_id"]: d for d in islice(result, option.limit)} + ) - def list_jobs(self, *, option: ListApiOptions) -> Dict[str, JobInfo]: + def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: # TODO(sang): Support limit & timeout & async calls. - return self._client.get_job_info() + try: + result = self._client.get_job_info() + except DataSourceUnavailable: + raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) + return ListApiResponse(result=result) - async def list_tasks(self, *, option: ListApiOptions) -> dict: + async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: """List all task information from the cluster. Returns: {task_id -> task_data_in_dict} task_data_in_dict's schema is in TaskState """ + raylet_ids = self._client.get_all_registered_raylet_ids() replies = await asyncio.gather( *[ self._client.get_task_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_raylet_ids() - ] + for node_id in raylet_ids + ], + return_exceptions=True, ) + unresponsive_nodes = 0 running_task_id = set() + successful_replies = [] for reply in replies: + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + + successful_replies.append(reply) for task_id in reply.running_task_ids: running_task_id.add(binary_to_hex(task_id)) + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. 
{warning_msg}" + ) + result = [] - for reply in replies: - logger.info(reply) + for reply in successful_replies: + assert not isinstance(reply, Exception) tasks = reply.owned_task_info_entries for task in tasks: data = self._message_to_dict( @@ -162,24 +238,36 @@ class StateAPIManager: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["task_id"]) - return {d["task_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["task_id"]: d for d in islice(result, option.limit)}, + partial_failure_warning=partial_failure_warning, + ) - async def list_objects(self, *, option: ListApiOptions) -> dict: + async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: """List all object information from the cluster. Returns: {object_id -> object_data_in_dict} object_data_in_dict's schema is in ObjectState """ + raylet_ids = self._client.get_all_registered_raylet_ids() replies = await asyncio.gather( *[ self._client.get_object_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_raylet_ids() - ] + for node_id in raylet_ids + ], + return_exceptions=True, ) + unresponsive_nodes = 0 worker_stats = [] - for reply in replies: + for reply, node_id in zip(replies, raylet_ids): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + for core_worker_stat in reply.core_workers_stats: # NOTE: Set preserving_proto_field_name=False here because # `construct_memory_table` requires a dictionary that has @@ -193,6 +281,20 @@ class StateAPIManager: ) ) + partial_failure_warning = None + if len(raylet_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="raylet", + total=len(raylet_ids), + network_failures=unresponsive_nodes, + log_command="raylet.out", + ) + if unresponsive_nodes == len(raylet_ids): + raise DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + 
f"The returned data may contain incomplete result. {warning_msg}" + ) + result = [] memory_table = memory_utils.construct_memory_table(worker_stats) for entry in memory_table.table: @@ -207,9 +309,12 @@ class StateAPIManager: # Sort to make the output deterministic. result.sort(key=lambda entry: entry["object_id"]) - return {d["object_id"]: d for d in islice(result, option.limit)} + return ListApiResponse( + result={d["object_id"]: d for d in islice(result, option.limit)}, + partial_failure_warning=partial_failure_warning, + ) - async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]: + async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse: """List all runtime env information from the cluster. Returns: @@ -219,14 +324,24 @@ class StateAPIManager: We don't have id -> data mapping like other API because runtime env doesn't have unique ids. """ + agent_ids = self._client.get_all_registered_agent_ids() replies = await asyncio.gather( *[ self._client.get_runtime_envs_info(node_id, timeout=option.timeout) - for node_id in self._client.get_all_registered_agent_ids() - ] + for node_id in agent_ids + ], + return_exceptions=True, ) + result = [] + unresponsive_nodes = 0 for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies): + if isinstance(reply, DataSourceUnavailable): + unresponsive_nodes += 1 + continue + elif isinstance(reply, Exception): + raise reply + states = reply.runtime_env_states for state in states: data = self._message_to_dict(message=state, fields_to_decode=[]) @@ -238,6 +353,20 @@ class StateAPIManager: data = filter_fields(data, RuntimeEnvState) result.append(data) + partial_failure_warning = None + if len(agent_ids) > 0 and unresponsive_nodes > 0: + warning_msg = NODE_QUERY_FAILURE_WARNING.format( + type="agent", + total=len(agent_ids), + network_failures=unresponsive_nodes, + log_command="dashboard_agent.log", + ) + if unresponsive_nodes == len(agent_ids): + raise 
DataSourceUnavailable(warning_msg) + partial_failure_warning = ( + f"The returned data may contain incomplete result. {warning_msg}" + ) + # Sort to make the output deterministic. def sort_func(entry): # If creation time is not there yet (runtime env is failed @@ -251,7 +380,10 @@ class StateAPIManager: return float(entry["creation_time_ms"]) result.sort(key=sort_func, reverse=True) - return list(islice(result, option.limit)) + return ListApiResponse( + result=list(islice(result, option.limit)), + partial_failure_warning=partial_failure_warning, + ) def _message_to_dict( self, diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index 4fe83f612..eea31f4b5 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -1,4 +1,5 @@ import requests +import warnings from dataclasses import fields @@ -8,18 +9,26 @@ from ray.experimental.state.common import ( DEFAULT_RPC_TIMEOUT, DEFAULT_LIMIT, ) +from ray.experimental.state.exception import RayStateApiException # TODO(sang): Replace it with auto-generated methods. -def _list(resource_name: str, options: ListApiOptions, api_server_url: str = None): +def _list( + resource_name: str, + options: ListApiOptions, + api_server_url: str = None, + _explain: bool = False, +): """Query the API server in address to list "resource_name" states. Args: resource_name: The name of the resource. E.g., actor, task. options: The options for the REST API that are translated to query strings. - address: The address of API server. If it is not give, it assumes the ray + api_server_url: The address of API server. If it is not give, it assumes the ray is already connected and obtains the API server address using Ray API. + explain: Print the API information such as API + latency or failed query information. 
""" if api_server_url is None: assert ray.is_initialized() @@ -40,10 +49,18 @@ def _list(resource_name: str, options: ListApiOptions, api_server_url: str = Non r.raise_for_status() response = r.json() - if not response["result"]: - raise ValueError( - "API server internal error. See dashboard.log file for more details." + if response["result"] is False: + raise RayStateApiException( + "API server internal error. See dashboard.log file for more details. " + f"Error: {response['msg']}" ) + + if _explain: + # Print warnings if anything was given. + warning_msg = response["data"].get("partial_failure_warning", None) + if warning_msg is not None: + warnings.warn(warning_msg, RuntimeWarning) + return r.json()["data"]["result"] @@ -51,11 +68,13 @@ def list_actors( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "actors", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -63,11 +82,13 @@ def list_placement_groups( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "placement_groups", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -75,11 +96,13 @@ def list_nodes( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "nodes", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -87,11 +110,13 @@ def list_jobs( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "jobs", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -99,11 +124,13 @@ def list_workers( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = 
DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "workers", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -111,11 +138,13 @@ def list_tasks( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "tasks", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) @@ -123,17 +152,25 @@ def list_objects( api_server_url: str = None, limit: int = DEFAULT_LIMIT, timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, ): return _list( "objects", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) -def list_runtime_envs(api_server_url: str = None, limit: int = 1000, timeout: int = 30): +def list_runtime_envs( + api_server_url: str = None, + limit: int = DEFAULT_LIMIT, + timeout: int = DEFAULT_RPC_TIMEOUT, + _explain: bool = False, +): return _list( "runtime_envs", ListApiOptions(limit=limit, timeout=timeout), api_server_url=api_server_url, + _explain=_explain, ) diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index 789faae72..7b1190838 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass, fields +from typing import List, Dict, Union from ray.dashboard.modules.job.common import JobInfo @@ -23,11 +24,20 @@ def filter_fields(data: dict, state_dataclass) -> dict: class ListApiOptions: limit: int timeout: int + # When the request is processed on the server side, + # we should apply multiplier so that server side can finish + # processing a request within timeout. Otherwise, + # timeout will always lead Http timeout. + _server_timeout_multiplier: float = 0.8 # TODO(sang): Use Pydantic instead. 
def __post_init__(self): assert isinstance(self.limit, int) assert isinstance(self.timeout, int) + # To return the data to users, when there's a partial failure + # we need to have a timeout that's smaller than the users' timeout. + # 80% is configured arbitrarily. + self.timeout = int(self.timeout * self._server_timeout_multiplier) # TODO(sang): Replace it with Pydantic or gRPC schema (once interface is finalized). @@ -94,3 +104,30 @@ class RuntimeEnvState: error: str creation_time_ms: float node_id: str + + +@dataclass(init=True) +class ListApiResponse: + # Returned data. None if no data is returned. + result: Union[ + Dict[ + str, + Union[ + ActorState, + PlacementGroupState, + NodeState, + JobInfo, + WorkerState, + TaskState, + ObjectState, + ], + ], + List[RuntimeEnvState], + ] = None + # List API can have a partial failure if queries to + # all sources fail. For example, getting object states + # require to ping all raylets, and it is possible some of + # them fails. Note that it is impossible to guarantee high + # availability of data because ray's state information is + # not replicated. + partial_failure_warning: str = "" diff --git a/python/ray/experimental/state/exception.py b/python/ray/experimental/state/exception.py new file mode 100644 index 000000000..e91b8d931 --- /dev/null +++ b/python/ray/experimental/state/exception.py @@ -0,0 +1,12 @@ +"""Internal Error""" + + +class DataSourceUnavailable(Exception): + pass + + +"""User-facing Error""" + + +class RayStateApiException(Exception): + pass diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index d221ec57f..24773092a 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -59,6 +59,12 @@ def get_state_api_output_to_print( ) +def _should_explain(format: AvailableFormat): + # If the format is json or yaml, it should not print stats because + # users don't want additional strings. 
+ return format == AvailableFormat.DEFAULT or format == AvailableFormat.TABLE + + @click.group("list") @click.pass_context def list_state_cli_group(ctx): @@ -71,6 +77,7 @@ def list_state_cli_group(ctx): namespace=ray_constants.KV_NAMESPACE_DASHBOARD, num_retries=20, ) + if api_server_url is None: raise ValueError( ( @@ -92,9 +99,11 @@ def list_state_cli_group(ctx): @click.pass_context def actors(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_actors(api_server_url=url), format=AvailableFormat(format) + list_actors(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -106,10 +115,11 @@ def actors(ctx, format: str): @click.pass_context def placement_groups(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_placement_groups(api_server_url=url), - format=AvailableFormat(format), + list_placement_groups(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -121,9 +131,11 @@ def placement_groups(ctx, format: str): @click.pass_context def nodes(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_nodes(api_server_url=url), format=AvailableFormat(format) + list_nodes(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -135,9 +147,11 @@ def nodes(ctx, format: str): @click.pass_context def jobs(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_jobs(api_server_url=url), format=AvailableFormat(format) + list_jobs(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -149,9 +163,11 @@ def jobs(ctx, format: str): @click.pass_context def workers(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( 
get_state_api_output_to_print( - list_workers(api_server_url=url), format=AvailableFormat(format) + list_workers(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -163,9 +179,11 @@ def workers(ctx, format: str): @click.pass_context def tasks(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_tasks(api_server_url=url), format=AvailableFormat(format) + list_tasks(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -177,9 +195,11 @@ def tasks(ctx, format: str): @click.pass_context def objects(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_objects(api_server_url=url), format=AvailableFormat(format) + list_objects(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) @@ -191,9 +211,10 @@ def objects(ctx, format: str): @click.pass_context def runtime_envs(ctx, format: str): url = ctx.obj["api_server_url"] + format = AvailableFormat(format) print( get_state_api_output_to_print( - list_runtime_envs(api_server_url=url), - format=AvailableFormat(format), + list_runtime_envs(api_server_url=url, _explain=_should_explain(format)), + format=format, ) ) diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py index ce471ed3d..656f2ae24 100644 --- a/python/ray/experimental/state/state_manager.py +++ b/python/ray/experimental/state/state_manager.py @@ -6,7 +6,7 @@ from functools import wraps import grpc import ray -from typing import Dict, List +from typing import Dict, List, Optional from ray import ray_constants from ray.core.generated.gcs_service_pb2 import ( @@ -33,19 +33,13 @@ from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub from 
ray.dashboard.modules.job.common import JobInfoStorageClient, JobInfo +from ray.experimental.state.exception import DataSourceUnavailable logger = logging.getLogger(__name__) -class StateSourceNetworkException(Exception): - """Exceptions raised when there's a network error from data source query.""" - - pass - - -def handle_network_errors(func): - """Apply the network error handling logic to each APIs, - such as retry or exception policies. +def handle_grpc_network_errors(func): + """Decorator to add a network handling logic. It is a helper method for `StateDataSourceClient`. The method can only be used for async methods. @@ -54,20 +48,32 @@ def handle_network_errors(func): @wraps(func) async def api_with_network_error_handler(*args, **kwargs): + """Apply the network error handling logic to each APIs, + such as retry or exception policies. + + Returns: + If RPC succeeds, it returns what the original function returns. + If RPC fails, it raises exceptions. + Exceptions: + DataSourceUnavailable: if the source is unavailable because it is down + or there's a slow network issue causing timeout. + Otherwise, the raw network exceptions (e.g., gRPC) will be raised. + """ # TODO(sang): Add a retry policy. try: return await func(*args, **kwargs) - except ( - # https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc-exceptions - grpc.aio.AioRpcError, - grpc.aio.InternalError, - grpc.aio.AbortError, - grpc.aio.BaseError, - grpc.aio.UsageError, - ) as e: - raise StateSourceNetworkException( - f"Failed to query the data source, {func}" - ) from e + except grpc.aio.AioRpcError as e: + if ( + e.code() == grpc.StatusCode.DEADLINE_EXCEEDED + or e.code() == grpc.StatusCode.UNAVAILABLE + ): + raise DataSourceUnavailable( + "Failed to query the data source. " + "It is either there's a network issue, or the source is down." 
+ ) + else: + logger.exception(e) + raise e return api_with_network_error_handler @@ -81,8 +87,9 @@ class StateDataSourceClient: finding services and register stubs through `register*` APIs. Non `register*` APIs + - Return the protobuf directly if it succeeds to query the source. + - Raises an exception if there's any network issue. - throw a ValueError if it cannot find the source. - - throw `StateSourceNetworkException` if there's any network errors. """ def __init__(self, gcs_channel: grpc.aio.Channel): @@ -132,50 +139,66 @@ class StateDataSourceClient: def get_all_registered_agent_ids(self) -> List[str]: return self._agent_stubs.keys() - @handle_network_errors - async def get_all_actor_info(self, timeout: int = None) -> GetAllActorInfoReply: + @handle_grpc_network_errors + async def get_all_actor_info( + self, timeout: int = None + ) -> Optional[GetAllActorInfoReply]: request = GetAllActorInfoRequest() reply = await self._gcs_actor_info_stub.GetAllActorInfo( request, timeout=timeout ) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_all_placement_group_info( self, timeout: int = None - ) -> GetAllPlacementGroupReply: + ) -> Optional[GetAllPlacementGroupReply]: request = GetAllPlacementGroupRequest() reply = await self._gcs_pg_info_stub.GetAllPlacementGroup( request, timeout=timeout ) return reply - @handle_network_errors - async def get_all_node_info(self, timeout: int = None) -> GetAllNodeInfoReply: + @handle_grpc_network_errors + async def get_all_node_info( + self, timeout: int = None + ) -> Optional[GetAllNodeInfoReply]: request = GetAllNodeInfoRequest() reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout) return reply - @handle_network_errors - async def get_all_worker_info(self, timeout: int = None) -> GetAllWorkerInfoReply: + @handle_grpc_network_errors + async def get_all_worker_info( + self, timeout: int = None + ) -> Optional[GetAllWorkerInfoReply]: request = GetAllWorkerInfoRequest() reply = 
await self._gcs_worker_info_stub.GetAllWorkerInfo( request, timeout=timeout ) return reply - def get_job_info(self) -> Dict[str, JobInfo]: - # Cannot use @handle_network_errors because async def is not supported yet. + def get_job_info(self) -> Optional[Dict[str, JobInfo]]: + # Cannot use @handle_grpc_network_errors because async def is not supported yet. # TODO(sang): Support timeout & make it async try: return self._job_client.get_all_jobs() - except Exception as e: - raise StateSourceNetworkException("Failed to query the job info.") from e + except grpc.aio.AioRpcError as e: + if ( + e.code == grpc.StatusCode.DEADLINE_EXCEEDED + or e.code == grpc.StatusCode.UNAVAILABLE + ): + raise DataSourceUnavailable( + "Failed to query the data source. " + "It is either there's a network issue, or the source is down." + ) + else: + logger.exception(e) + raise e - @handle_network_errors + @handle_grpc_network_errors async def get_task_info( self, node_id: str, timeout: int = None - ) -> GetTasksInfoReply: + ) -> Optional[GetTasksInfoReply]: stub = self._raylet_stubs.get(node_id) if not stub: raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.") @@ -183,10 +206,10 @@ class StateDataSourceClient: reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_object_info( self, node_id: str, timeout: int = None - ) -> GetNodeStatsReply: + ) -> Optional[GetNodeStatsReply]: stub = self._raylet_stubs.get(node_id) if not stub: raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.") @@ -197,10 +220,10 @@ class StateDataSourceClient: ) return reply - @handle_network_errors + @handle_grpc_network_errors async def get_runtime_envs_info( self, node_id: str, timeout: int = None - ) -> GetRuntimeEnvsInfoReply: + ) -> Optional[GetRuntimeEnvsInfoReply]: stub = self._agent_stubs.get(node_id) if not stub: raise ValueError(f"Agent for a node id, {node_id} doesn't exist.") diff 
--git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 1a1a0a8ed..2c31f33de 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -8,7 +8,10 @@ from dataclasses import fields from unittest.mock import MagicMock -from asyncmock import AsyncMock +if sys.version_info > (3, 7, 0): + from unittest.mock import AsyncMock +else: + from asyncmock import AsyncMock import ray import ray.ray_constants as ray_constants @@ -41,7 +44,11 @@ from ray.core.generated.runtime_env_common_pb2 import ( ) from ray.core.generated.runtime_env_agent_pb2 import GetRuntimeEnvsInfoReply import ray.dashboard.consts as dashboard_consts -from ray.dashboard.state_aggregator import StateAPIManager +from ray.dashboard.state_aggregator import ( + StateAPIManager, + GCS_QUERY_FAILURE_WARNING, + NODE_QUERY_FAILURE_WARNING, +) from ray.experimental.state.api import ( list_actors, list_placement_groups, @@ -64,9 +71,9 @@ from ray.experimental.state.common import ( DEFAULT_RPC_TIMEOUT, DEFAULT_LIMIT, ) +from ray.experimental.state.exception import DataSourceUnavailable, RayStateApiException from ray.experimental.state.state_manager import ( StateDataSourceClient, - StateSourceNetworkException, ) from ray.experimental.state.state_cli import ( list_state_cli_group, @@ -188,7 +195,7 @@ def generate_runtime_env_info(runtime_env, creation_time=None): def list_api_options(timeout: int = DEFAULT_RPC_TIMEOUT, limit: int = DEFAULT_LIMIT): - return ListApiOptions(limit=limit, timeout=timeout) + return ListApiOptions(limit=limit, timeout=timeout, _server_timeout_multiplier=1.0) @pytest.mark.asyncio @@ -199,15 +206,25 @@ async def test_api_manager_list_actors(state_api_manager): actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")] ) result = await state_api_manager.list_actors(option=list_api_options()) - actor_data = list(result.values())[0] + data = result.result + actor_data = list(data.values())[0] 
verify_schema(ActorState, actor_data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_actors(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_actor_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_actors(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -223,17 +240,31 @@ async def test_api_manager_list_pgs(state_api_manager): ) ) result = await state_api_manager.list_placement_groups(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(PlacementGroupState, data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_placement_groups( option=list_api_options(limit=1) ) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_placement_group_info.side_effect = ( + DataSourceUnavailable() + ) + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_placement_groups( + option=list_api_options(limit=1) + ) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -244,15 +275,25 @@ async def test_api_manager_list_nodes(state_api_manager): node_info_list=[generate_node_data(id), generate_node_data(b"12345")] ) result = await state_api_manager.list_nodes(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(NodeState, data) """ Test limit """ - assert len(result) == 2 + assert len(data) == 2 result = await state_api_manager.list_nodes(option=list_api_options(limit=1)) - assert len(result) 
== 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_node_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_nodes(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING @pytest.mark.asyncio @@ -266,19 +307,30 @@ async def test_api_manager_list_workers(state_api_manager): ] ) result = await state_api_manager.list_workers(option=list_api_options()) - data = list(result.values())[0] + data = result.result + data = list(data.values())[0] verify_schema(WorkerState, data) """ Test limit """ - assert len(result) == 2 + assert len(result.result) == 2 result = await state_api_manager.list_workers(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_all_worker_info.side_effect = DataSourceUnavailable() + with pytest.raises(DataSourceUnavailable) as exc_info: + result = await state_api_manager.list_workers(option=list_api_options(limit=1)) + assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. 
Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_tasks(state_api_manager): @@ -288,17 +340,19 @@ async def test_api_manager_list_tasks(state_api_manager): first_task_name = "1" second_task_name = "2" + data_source_client.get_task_info = AsyncMock() data_source_client.get_task_info.side_effect = [ generate_task_data(b"1234", first_task_name), generate_task_data(b"2345", second_task_name), ] result = await state_api_manager.list_tasks(option=list_api_options()) - data_source_client.get_task_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT) - data_source_client.get_task_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT) - result = list(result.values()) - assert len(result) == 2 - verify_schema(TaskState, result[0]) - verify_schema(TaskState, result[1]) + data_source_client.get_task_info.assert_any_await("1", timeout=DEFAULT_RPC_TIMEOUT) + data_source_client.get_task_info.assert_any_await("2", timeout=DEFAULT_RPC_TIMEOUT) + data = result.result + data = list(data.values()) + assert len(data) == 2 + verify_schema(TaskState, data[0]) + verify_schema(TaskState, data[1]) """ Test limit @@ -308,11 +362,38 @@ async def test_api_manager_list_tasks(state_api_manager): generate_task_data(b"2345", second_task_name), ] result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_task_info.side_effect = [ + DataSourceUnavailable(), + generate_task_data(b"2345", second_task_name), + ] + result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="raylet", total=2, network_failures=1, log_command="raylet.out" + ) + in warning + ) + + # Test if all RPCs fail, it will raise an exception. 
+ data_source_client.get_task_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_tasks(option=list_api_options(limit=1)) -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_objects(state_api_manager): @@ -322,17 +403,23 @@ async def test_api_manager_list_objects(state_api_manager): data_source_client.get_all_registered_raylet_ids = MagicMock() data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"] + data_source_client.get_object_info = AsyncMock() data_source_client.get_object_info.side_effect = [ GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_1_id)]), GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), ] result = await state_api_manager.list_objects(option=list_api_options()) - data_source_client.get_object_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT) - data_source_client.get_object_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT) - result = list(result.values()) - assert len(result) == 2 - verify_schema(ObjectState, result[0]) - verify_schema(ObjectState, result[1]) + data = result.result + data_source_client.get_object_info.assert_any_await( + "1", timeout=DEFAULT_RPC_TIMEOUT + ) + data_source_client.get_object_info.assert_any_await( + "2", timeout=DEFAULT_RPC_TIMEOUT + ) + data = list(data.values()) + assert len(data) == 2 + verify_schema(ObjectState, data[0]) + verify_schema(ObjectState, data[1]) """ Test limit @@ -342,11 +429,38 @@ async def test_api_manager_list_objects(state_api_manager): GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), ] result = await state_api_manager.list_objects(option=list_api_options(limit=1)) - 
assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_object_info.side_effect = [ + DataSourceUnavailable(), + GetNodeStatsReply(core_workers_stats=[generate_object_info(obj_2_id)]), + ] + result = await state_api_manager.list_objects(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="raylet", total=2, network_failures=1, log_command="raylet.out" + ) + in warning + ) + + # Test if all RPCs fail, it will raise an exception. + data_source_client.get_object_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_objects(option=list_api_options(limit=1)) -@pytest.mark.skip( - reason=("Not passing in CI although it works locally. Will handle it later.") +@pytest.mark.skipif( + sys.version_info <= (3, 7, 0), + reason=("Not passing in CI although it works locally. 
Will handle it later."), ) @pytest.mark.asyncio async def test_api_manager_list_runtime_envs(state_api_manager): @@ -354,6 +468,7 @@ async def test_api_manager_list_runtime_envs(state_api_manager): data_source_client.get_all_registered_agent_ids = MagicMock() data_source_client.get_all_registered_agent_ids.return_value = ["1", "2", "3"] + data_source_client.get_runtime_envs_info = AsyncMock() data_source_client.get_runtime_envs_info.side_effect = [ generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})), generate_runtime_env_info( @@ -362,23 +477,25 @@ async def test_api_manager_list_runtime_envs(state_api_manager): generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}), creation_time=10), ] result = await state_api_manager.list_runtime_envs(option=list_api_options()) - data_source_client.get_runtime_envs_info.assert_any_call( + data = result.result + data_source_client.get_runtime_envs_info.assert_any_await( "1", timeout=DEFAULT_RPC_TIMEOUT ) - data_source_client.get_runtime_envs_info.assert_any_call( + data_source_client.get_runtime_envs_info.assert_any_await( "2", timeout=DEFAULT_RPC_TIMEOUT ) - data_source_client.get_runtime_envs_info.assert_any_call( + + data_source_client.get_runtime_envs_info.assert_any_await( "3", timeout=DEFAULT_RPC_TIMEOUT ) - assert len(result) == 3 - verify_schema(RuntimeEnvState, result[0]) - verify_schema(RuntimeEnvState, result[1]) - verify_schema(RuntimeEnvState, result[2]) + assert len(data) == 3 + verify_schema(RuntimeEnvState, data[0]) + verify_schema(RuntimeEnvState, data[1]) + verify_schema(RuntimeEnvState, data[2]) # Make sure the higher creation time is sorted first. 
- assert "creation_time_ms" not in result[0] - result[1]["creation_time_ms"] > result[2]["creation_time_ms"] + assert "creation_time_ms" not in data[0] + data[1]["creation_time_ms"] > data[2]["creation_time_ms"] """ Test limit @@ -391,7 +508,38 @@ async def test_api_manager_list_runtime_envs(state_api_manager): generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), ] result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1)) - assert len(result) == 1 + data = result.result + assert len(data) == 1 + + """ + Test error handling + """ + data_source_client.get_runtime_envs_info.side_effect = [ + DataSourceUnavailable(), + generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), + generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})), + ] + result = await state_api_manager.list_runtime_envs(option=list_api_options(limit=1)) + # Make sure warnings are returned. + warning = result.partial_failure_warning + print(warning) + assert ( + NODE_QUERY_FAILURE_WARNING.format( + type="agent", total=3, network_failures=1, log_command="dashboard_agent.log" + ) + in warning + ) + + # Test if all RPCs fail, it will raise an exception. + data_source_client.get_runtime_envs_info.side_effect = [ + DataSourceUnavailable(), + DataSourceUnavailable(), + DataSourceUnavailable(), + ] + with pytest.raises(DataSourceUnavailable): + result = await state_api_manager.list_runtime_envs( + option=list_api_options(limit=1) + ) """ @@ -495,7 +643,7 @@ async def test_state_data_source_client(ray_start_cluster): """ with pytest.raises(ValueError): # Since we didn't register this node id, it should raise an exception. 
- result = await client.get_object_info("1234") + result = await client.get_runtime_envs_info("1234") wait_for_condition(lambda: len(ray.nodes()) == 2) for node in ray.nodes(): node_id = node["NodeID"] @@ -527,10 +675,9 @@ async def test_state_data_source_client(ray_start_cluster): if node["Alive"]: continue - # Querying to the dead node raises gRPC error, which should be - # translated into `StateSourceNetworkException` - with pytest.raises(StateSourceNetworkException): - result = await client.get_object_info(node_id) + # Querying to the dead node raises gRPC error, which should raise an exception. + with pytest.raises(DataSourceUnavailable): + await client.get_object_info(node_id) # Make sure unregister API works as expected. client.unregister_raylet_client(node_id) @@ -685,6 +832,7 @@ def test_list_jobs(shutdown_only): def verify(): job_data = list(list_jobs().values())[0] + print(job_data) job_id_from_api = list(list_jobs().keys())[0] correct_state = job_data["status"] == "SUCCEEDED" correct_id = job_id == job_id_from_api @@ -914,6 +1062,91 @@ def test_limit(shutdown_only): assert output == list_actors(limit=2) +def test_network_failure(shutdown_only): + """When the request fails due to network failure, + verifies it raises an exception.""" + ray.init() + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + wait_for_condition(lambda: len(list_tasks()) == 4) + + # Kill raylet so that list_tasks will have network error on querying raylets. 
+ ray.worker._global_node.kill_raylet() + + with pytest.raises(RayStateApiException): + list_tasks(_explain=True) + + +def test_network_partial_failures(ray_start_cluster): + """When the request fails due to network failure, + verifies it prints proper warning.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + n = cluster.add_node(num_cpus=2) + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + wait_for_condition(lambda: len(list_tasks()) == 4) + + # Make sure when there's 0 node failure, it doesn't print the error. + with pytest.warns(None) as record: + list_tasks(_explain=True) + assert len(record) == 0 + + # Kill raylet so that list_tasks will have network error on querying raylets. + cluster.remove_node(n, allow_graceful=False) + + with pytest.warns(RuntimeWarning): + list_tasks(_explain=True) + + # Make sure when _explain == False, warning is not printed. + with pytest.warns(None) as record: + list_tasks(_explain=False) + assert len(record) == 0 + + +def test_network_partial_failures_timeout(monkeypatch, ray_start_cluster): + """When the request fails due to network timeout, + verifies it prints proper warning.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + with monkeypatch.context() as m: + # defer for 10s for the second node. 
+ m.setenv( + "RAY_testing_asio_delay_us", + "NodeManagerService.grpc_server.GetTasksInfo=10000000:10000000", + ) + cluster.add_node(num_cpus=2) + + @ray.remote + def f(): + import time + + time.sleep(30) + + a = [f.remote() for _ in range(4)] # noqa + + def verify(): + with pytest.warns(None) as record: + list_tasks(_explain=True, timeout=5) + return len(record) == 1 + + wait_for_condition(verify) + + @pytest.mark.asyncio async def test_cli_format_print(state_api_manager): data_source_client = state_api_manager.data_source_client @@ -922,6 +1155,7 @@ async def test_cli_format_print(state_api_manager): actor_table_data=[generate_actor_data(actor_id), generate_actor_data(b"12345")] ) result = await state_api_manager.list_actors(option=list_api_options()) + result = result.result # If the format is not yaml, it will raise an exception. yaml.load( get_state_api_output_to_print(result, format=AvailableFormat.YAML),