diff --git a/dashboard/modules/node/node_consts.py b/dashboard/modules/node/node_consts.py index 23ee41d49..023d45e8f 100644 --- a/dashboard/modules/node/node_consts.py +++ b/dashboard/modules/node/node_consts.py @@ -1,5 +1,12 @@ NODE_STATS_UPDATE_INTERVAL_SECONDS = 1 UPDATE_NODES_INTERVAL_SECONDS = 5 +# Until the head node is registered, +# the API server is doing more frequent update +# with this interval. +FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS = 0.1 +# If the head node is not updated within +# this timeout, it will stop frequent update. +FREQUENT_UPDATE_TIMEOUT_SECONDS = 10 MAX_COUNT_OF_GCS_RPC_ERROR = 10 MAX_LOGS_TO_CACHE = 10000 LOG_PRUNE_THREASHOLD = 1.25 diff --git a/dashboard/modules/node/node_head.py b/dashboard/modules/node/node_head.py index 7804be9d1..b0562992c 100644 --- a/dashboard/modules/node/node_head.py +++ b/dashboard/modules/node/node_head.py @@ -2,6 +2,7 @@ import asyncio import json import logging import re +import time import aiohttp.web @@ -22,6 +23,8 @@ from ray.dashboard.modules.node import node_consts from ray.dashboard.modules.node.node_consts import ( LOG_PRUNE_THREASHOLD, MAX_LOGS_TO_CACHE, + FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS, + FREQUENT_UPDATE_TIMEOUT_SECONDS, ) from ray.dashboard.utils import async_loop_forever @@ -70,6 +73,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule): self._gcs_node_info_stub = None self._collect_memory_info = False DataSource.nodes.signal.append(self._update_stubs) + # Total number of node updates happened. + self._node_update_cnt = 0 + # The time where the module is started. + self._module_start_time = time.time() + # The time it takes until the head node is registered. None means + # head node hasn't been registered. + self._head_node_registration_time_s = None async def _update_stubs(self, change): if change.old: @@ -88,6 +98,15 @@ class NodeHead(dashboard_utils.DashboardHeadModule): stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) self._stubs[node_id] = stub + def get_internal_states(self): + return { + "head_node_registration_time_s": self._head_node_registration_time_s, + "registered_nodes": len(DataSource.nodes), + "registered_agents": len(DataSource.agents), + "node_update_count": self._node_update_cnt, + "module_lifetime_s": time.time() - self._module_start_time, + } + async def _get_nodes(self): """Read the client table. @@ -120,6 +139,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule): node_id = node["nodeId"] ip = node["nodeManagerAddress"] hostname = node["nodeManagerHostname"] + if ( + ip == self._dashboard_head.ip + and not self._head_node_registration_time_s + ): + self._head_node_registration_time_s = ( + time.time() - self._module_start_time + ) node_id_to_ip[node_id] = ip node_id_to_hostname[node_id] = hostname assert node["state"] in ["ALIVE", "DEAD"] @@ -146,7 +172,40 @@ class NodeHead(dashboard_utils.DashboardHeadModule): except Exception: logger.exception("Error updating nodes.") finally: - await asyncio.sleep(node_consts.UPDATE_NODES_INTERVAL_SECONDS) + self._node_update_cnt += 1 + # _head_node_registration_time_s == None if head node is not + # registered. + head_node_not_registered = not self._head_node_registration_time_s + # Until the head node is registered, we update the + # node status more frequently. + # If the head node is not updated after 10 seconds, it just stops + # doing frequent update to avoid unexpected edge case. + if ( + head_node_not_registered + and self._node_update_cnt * FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS + < FREQUENT_UPDATE_TIMEOUT_SECONDS + ): + logger.info("SANG-TODO a") + await asyncio.sleep(FREQUENTY_UPDATE_NODES_INTERVAL_SECONDS) + else: + logger.info("SANG-TODO b") + if head_node_not_registered: + logger.warning( + "Head node is not registered even after " + f"{FREQUENT_UPDATE_TIMEOUT_SECONDS} seconds. " + "The API server might not work correctly. Please " + "report a Github issue. Internal states :" + f"{self.get_internal_states()}" + ) + await asyncio.sleep(node_consts.UPDATE_NODES_INTERVAL_SECONDS) + + @routes.get("/internal/node_module") + async def get_node_module_internal_state(self, req) -> aiohttp.web.Response: + return dashboard_optional_utils.rest_response( + success=True, + message="", + **self.get_internal_states(), + ) @routes.get("/nodes") @dashboard_optional_utils.aiohttp_cache diff --git a/dashboard/modules/node/tests/test_node.py b/dashboard/modules/node/tests/test_node.py index 80c9ad270..ff300025d 100644 --- a/dashboard/modules/node/tests/test_node.py +++ b/dashboard/modules/node/tests/test_node.py @@ -10,6 +10,7 @@ import ray import threading from datetime import datetime, timedelta from ray.cluster_utils import Cluster +from ray.dashboard.modules.node.node_consts import UPDATE_NODES_INTERVAL_SECONDS from ray.dashboard.tests.conftest import * # noqa from ray._private.test_utils import ( format_web_url, @@ -321,5 +322,31 @@ def test_multi_node_churn( time.sleep(2) +@pytest.mark.parametrize( + "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True +) +def test_frequent_node_update( + enable_test_module, disable_aiohttp_cache, ray_start_cluster_head +): + cluster: Cluster = ray_start_cluster_head + assert wait_until_server_available(cluster.webui_url) + webui_url = cluster.webui_url + webui_url = format_web_url(webui_url) + + def verify(): + response = requests.get(webui_url + "/internal/node_module") + response.raise_for_status() + result = response.json() + data = result["data"] + head_node_registration_time = data["headNodeRegistrationTimeS"] + # If the head node is not registered, it is None. + assert head_node_registration_time is not None + # Head node should be registered before the node update interval + # because we do frequent until the head node is registered. + return head_node_registration_time < UPDATE_NODES_INTERVAL_SECONDS + + wait_for_condition(verify, timeout=15) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index 44adb8ebb..9fe1cc8d2 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -537,6 +537,10 @@ def list( _explain=_should_explain(format), ) + # If --detail is given, the default formatting is yaml. + if detail and format == AvailableFormat.DEFAULT: + format = AvailableFormat.YAML + # Print data to console. print( format_list_api_output( diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index a231991cf..dcb31dc80 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -1920,38 +1920,45 @@ def test_network_failure(shutdown_only): list_tasks(_explain=True) -def test_network_partial_failures(ray_start_cluster): +def test_network_partial_failures(monkeypatch, ray_start_cluster): """When the request fails due to network failure, verifies it prints proper warning.""" - cluster = ray_start_cluster - cluster.add_node(num_cpus=2) - ray.init(address=cluster.address) - n = cluster.add_node(num_cpus=2) + with monkeypatch.context() as m: + # defer for 5s for the second node. + # This will help the API not return until the node is killed. + m.setenv( + "RAY_testing_asio_delay_us", + "NodeManagerService.grpc_server.GetTasksInfo=5000000:5000000", + ) + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + n = cluster.add_node(num_cpus=2) - @ray.remote - def f(): - import time + @ray.remote + def f(): + import time - time.sleep(30) + time.sleep(30) - a = [f.remote() for _ in range(4)] # noqa - wait_for_condition(lambda: len(list_tasks()) == 4) + a = [f.remote() for _ in range(4)] # noqa + wait_for_condition(lambda: len(list_tasks()) == 4) - # Make sure when there's 0 node failure, it doesn't print the error. - with pytest.warns(None) as record: - list_tasks(_explain=True) - assert len(record) == 0 + # Make sure when there's 0 node failure, it doesn't print the error. + with pytest.warns(None) as record: + list_tasks(_explain=True) + assert len(record) == 0 - # Kill raylet so that list_tasks will have network error on querying raylets. - cluster.remove_node(n, allow_graceful=False) + # Kill raylet so that list_tasks will have network error on querying raylets. + cluster.remove_node(n, allow_graceful=False) - with pytest.warns(UserWarning): - list_tasks(raise_on_missing_output=False, _explain=True) + with pytest.warns(UserWarning): + list_tasks(raise_on_missing_output=False, _explain=True) - # Make sure when _explain == False, warning is not printed. - with pytest.warns(None) as record: - list_tasks(raise_on_missing_output=False, _explain=False) - assert len(record) == 0 + # Make sure when _explain == False, warning is not printed. + with pytest.warns(None) as record: + list_tasks(raise_on_missing_output=False, _explain=False) + assert len(record) == 0 def test_network_partial_failures_timeout(monkeypatch, ray_start_cluster): @@ -2193,8 +2200,23 @@ def test_detail(shutdown_only): assert "test_detail" in result.output # Columns are upper case in the default formatting (table). - assert "serialized_runtime_env".upper() in result.output - assert "actor_id".upper() in result.output + assert "serialized_runtime_env" in result.output + assert "actor_id" in result.output + + # Make sure when the --detail option is specified, the default formatting + # is yaml. If the format is not yaml, the below line will raise an yaml exception. + print( + yaml.load( + result.output, + Loader=yaml.FullLoader, + ) + ) + + # When the format is given, it should respect that formatting. + result = runner.invoke(cli_list, ["actors", "--detail", "--format=table"]) + assert result.exit_code == 0 + with pytest.raises(yaml.YAMLError): + yaml.load(result.output, Loader=yaml.FullLoader) def _try_state_query_expect_rate_limit(api_func, res_q, start_q=None):