diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py index 1f42ad614..c46be0ba3 100644 --- a/dashboard/dashboard.py +++ b/dashboard/dashboard.py @@ -16,6 +16,7 @@ import ray.dashboard.head as dashboard_head import ray.dashboard.utils as dashboard_utils from ray._private.gcs_pubsub import GcsPublisher from ray._private.ray_logging import setup_component_logger +from typing import Optional, Set # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray provides a default configuration at @@ -40,14 +41,15 @@ class Dashboard: def __init__( self, - host, - port, - port_retries, - gcs_address, - log_dir=None, - temp_dir=None, - session_dir=None, - minimal=False, + host: str, + port: int, + port_retries: int, + gcs_address: str, + log_dir: str = None, + temp_dir: str = None, + session_dir: str = None, + minimal: bool = False, + modules_to_load: Optional[Set[str]] = None, ): self.dashboard_head = dashboard_head.DashboardHead( http_host=host, @@ -58,6 +60,7 @@ class Dashboard: temp_dir=temp_dir, session_dir=session_dir, minimal=minimal, + modules_to_load=modules_to_load, ) async def run(self): @@ -154,6 +157,16 @@ if __name__ == "__main__": "by `pip install ray[default]`." ), ) + parser.add_argument( + "--modules-to-load", + required=False, + default=None, + help=( + "Specify the list of module names in [module_1],[module_2] format." + "E.g., JobHead,StateHead... " + "If nothing is specified, all modules are loaded." + ), + ) args = parser.parse_args() @@ -167,6 +180,12 @@ if __name__ == "__main__": backup_count=args.logging_rotate_backup_count, ) + if args.modules_to_load: + modules_to_load = set(args.modules_to_load.strip(" ,").split(",")) + else: + # None == default. + modules_to_load = None + dashboard = Dashboard( args.host, args.port, @@ -176,6 +195,7 @@ if __name__ == "__main__": temp_dir=args.temp_dir, session_dir=args.session_dir, minimal=args.minimal, + modules_to_load=modules_to_load, ) loop = asyncio.get_event_loop() diff --git a/dashboard/head.py b/dashboard/head.py index 67cc13886..f22373723 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -11,11 +11,14 @@ import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils import ray.experimental.internal_kv as internal_kv from ray._private import ray_constants +from ray.dashboard.utils import DashboardHeadModule from ray._private.gcs_pubsub import GcsAioErrorSubscriber, GcsAioLogSubscriber from ray._private.gcs_utils import GcsClient, GcsAioClient, check_health from ray.dashboard.datacenter import DataOrganizer from ray.dashboard.utils import async_loop_forever +from typing import Optional, Set + try: from grpc import aio as aiogrpc except ImportError: @@ -59,15 +62,31 @@ class GCSHealthCheckThread(threading.Thread): class DashboardHead: def __init__( self, - http_host, - http_port, - http_port_retries, - gcs_address, - log_dir, - temp_dir, - session_dir, - minimal, + http_host: str, + http_port: int, + http_port_retries: int, + gcs_address: str, + log_dir: str, + temp_dir: str, + session_dir: str, + minimal: bool, + modules_to_load: Optional[Set[str]] = None, ): + """ + Args: + http_host: The host address for the Http server. + http_port: The port for the Http server. + http_port_retries: The maximum retry to bind ports for the Http server. + gcs_address: The GCS address in the {address}:{port} format. + log_dir: The log directory. E.g., /tmp/session_latest/logs. + temp_dir: The temp directory. E.g., /tmp. + session_dir: The session directory. E.g., tmp/session_latest. + minimal: Whether or not it will load the minimal modules. + modules_to_load: A set of module name in string to load. + By default (None), it loads all available modules. + Note that available modules could be changed depending on + minimal flags. + """ self.minimal = minimal self.health_check_thread: GCSHealthCheckThread = None self._gcs_rpc_error_counter = 0 @@ -76,6 +95,7 @@ class DashboardHead: self.http_host = "127.0.0.1" if http_host == "localhost" else http_host self.http_port = http_port self.http_port_retries = http_port_retries + self._modules_to_load = modules_to_load self.gcs_address = None assert gcs_address is not None @@ -84,6 +104,7 @@ class DashboardHead: self.temp_dir = temp_dir self.session_dir = session_dir self.aiogrpc_gcs_channel = None + self.gcs_aio_client = None self.gcs_error_subscriber = None self.gcs_log_subscriber = None self.ip = ray.util.get_node_ip_address() @@ -148,19 +169,35 @@ class DashboardHead: # https://github.com/ray-project/ray/issues/16328 os._exit(-1) - def _load_modules(self): - """Load dashboard head modules.""" + def _load_modules(self, modules_to_load: Optional[Set[str]] = None): + """Load dashboard head modules. + + Args: + modules: A list of module names to load. By default (None), + it loads all modules. + """ modules = [] - head_cls_list = dashboard_utils.get_all_modules( - dashboard_utils.DashboardHeadModule - ) + head_cls_list = dashboard_utils.get_all_modules(DashboardHeadModule) + + # Select modules to load. + modules_to_load = modules_to_load or {m.__name__ for m in head_cls_list} + logger.info("Modules to load: %s", modules_to_load) + for cls in head_cls_list: - logger.info( - "Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls + logger.info("Loading %s: %s", DashboardHeadModule.__name__, cls) + if cls.__name__ in modules_to_load: + c = cls(self) + modules.append(c) + + # Verify modules are loaded as expected. + loaded_modules = {type(m).__name__ for m in modules} + if loaded_modules != modules_to_load: + assert False, ( + "Actual loaded modules, {}, doesn't match the requested modules " + "to load, {}".format(loaded_modules, modules_to_load) ) - c = cls(self) - modules.append(c) - logger.info("Loaded %d modules.", len(modules)) + + logger.info("Loaded %d modules. %s", len(modules), modules) return modules async def run(self): @@ -192,7 +229,7 @@ class DashboardHead: except Exception: logger.exception(f"Error notifying coroutine {co}") - modules = self._load_modules() + modules = self._load_modules(self._modules_to_load) http_host, http_port = self.http_host, self.http_port if not self.minimal: diff --git a/dashboard/modules/usage_stats/usage_stats_head.py b/dashboard/modules/usage_stats/usage_stats_head.py index fe7359b2a..932c0ce0f 100644 --- a/dashboard/modules/usage_stats/usage_stats_head.py +++ b/dashboard/modules/usage_stats/usage_stats_head.py @@ -17,9 +17,7 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule): super().__init__(dashboard_head) self.usage_stats_enabled = ray_usage_lib.usage_stats_enabled() self.usage_stats_prompt_enabled = ray_usage_lib.usage_stats_prompt_enabled() - self.cluster_config_to_report = ray_usage_lib.get_cluster_config_to_report( - os.path.expanduser("~/ray_bootstrap_config.yaml") - ) + self.cluster_config_to_report = None self.session_dir = dashboard_head.session_dir self.client = ray_usage_lib.UsageReportClient() # The total number of report succeeded. @@ -95,6 +93,9 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule): await self._report_usage_async() async def run(self, server): + self.cluster_config_to_report = ray_usage_lib.get_cluster_config_to_report( + os.path.expanduser("~/ray_bootstrap_config.yaml") + ) if not self.usage_stats_enabled: logger.info("Usage reporting is disabled.") return diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index 7c0123f95..3dc47448b 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -33,10 +33,14 @@ from ray._private.test_utils import ( wait_until_succeeded_without_exception, ) from ray.dashboard import dashboard +from ray.dashboard.head import DashboardHead from ray.dashboard.modules.dashboard_sdk import DEFAULT_DASHBOARD_ADDRESS from ray.experimental.state.api import StateApiClient from ray.experimental.state.common import ListApiOptions, StateResource from ray.experimental.state.exception import ServerUnavailable +from ray.experimental.internal_kv import _initialize_internal_kv +from unittest.mock import MagicMock +from ray.dashboard.utils import DashboardHeadModule import psutil @@ -962,5 +966,45 @@ def test_dashboard_requests_fail_on_missing_deps(ray_start_with_dashboard): assert response is None +@pytest.mark.skipif( + os.environ.get("RAY_DEFAULT") != "1", + reason="This test only works for default installation.", +) +def test_dashboard_module_load(tmpdir): + """Verify if the head module can load only selected modules.""" + head = DashboardHead( + "127.0.0.1", + 8265, + 1, + "127.0.0.1:6379", + str(tmpdir), + str(tmpdir), + str(tmpdir), + False, + ) + + # Test basic. + loaded_modules_expected = {"UsageStatsHead", "JobHead"} + loaded_modules = head._load_modules(modules_to_load=loaded_modules_expected) + loaded_modules_actual = {type(m).__name__ for m in loaded_modules} + assert loaded_modules_actual == loaded_modules_expected + + # Test modules that don't exist. + loaded_modules_expected = {"StateHea"} + with pytest.raises(AssertionError): + loaded_modules = head._load_modules(modules_to_load=loaded_modules_expected) + + # Test the base case. + # It is needed to pass assertion check from one of modules. + gcs_client = MagicMock() + _initialize_internal_kv(gcs_client) + loaded_modules_expected = { + m.__name__ for m in dashboard_utils.get_all_modules(DashboardHeadModule) + } + loaded_modules = head._load_modules() + loaded_modules_actual = {type(m).__name__ for m in loaded_modules} + assert loaded_modules_actual == loaded_modules_expected + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py index 23fc4af0d..2d4239d83 100644 --- a/python/ray/_private/node.py +++ b/python/ray/_private/node.py @@ -855,16 +855,20 @@ class Node: process_info, ] - def start_dashboard(self, require_dashboard: bool): + def start_api_server(self, *, include_dashboard: bool, raise_on_failure: bool): """Start the dashboard. Args: - require_dashboard: If true, this will raise an exception - if we fail to start the dashboard. Otherwise it will print - a warning if we fail to start the dashboard. + include_dashboard: If true, this will load all dashboard-related modules + when starting the API server. Otherwise, it will only + start the modules that are not relevant to the dashboard. + raise_on_failure: If true, this will raise an exception + if we fail to start the API server. Otherwise it will print + a warning if we fail to start the API server. """ - self._webui_url, process_info = ray._private.services.start_dashboard( - require_dashboard, + self._webui_url, process_info = ray._private.services.start_api_server( + include_dashboard, + raise_on_failure, self._ray_params.dashboard_host, self.gcs_address, self._temp_dir, @@ -1060,10 +1064,21 @@ class Node: if self._ray_params.ray_client_server_port: self.start_ray_client_server() - if self._ray_params.include_dashboard: - self.start_dashboard(require_dashboard=True) - elif self._ray_params.include_dashboard is None: - self.start_dashboard(require_dashboard=False) + if self._ray_params.include_dashboard is None: + # Default + include_dashboard = True + raise_on_api_server_failure = False + elif self._ray_params.include_dashboard is False: + include_dashboard = False + raise_on_api_server_failure = False + else: + include_dashboard = True + raise_on_api_server_failure = True + + self.start_api_server( + include_dashboard=include_dashboard, + raise_on_failure=raise_on_api_server_failure, + ) def start_ray_processes(self): """Start all of the processes on the node.""" diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 9696e41e3..3485dc338 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1185,8 +1185,9 @@ def start_log_monitor( return process_info -def start_dashboard( - require_dashboard: bool, +def start_api_server( + include_dashboard: bool, + raise_on_failure: bool, host: str, gcs_address: str, temp_dir: str, @@ -1198,12 +1199,15 @@ def start_dashboard( backup_count: int = 0, redirect_logging: bool = True, ): - """Start a dashboard process. + """Start a API server process. Args: - require_dashboard: If true, this will raise an exception if we - fail to start the dashboard. Otherwise it will print a warning if - we fail to start the dashboard. + include_dashboard: If true, this will load all dashboard-related modules + when starting the API server. Otherwise, it will only + start the modules that are not relevant to the dashboard. + raise_on_failure: If true, this will raise an exception + if we fail to start the API server. Otherwise it will print + a warning if we fail to start the API server. host: The host to bind the dashboard web server to. gcs_address: The gcs address the dashboard should connect to temp_dir: The temporary directory used for log files and @@ -1293,6 +1297,13 @@ def start_dashboard( if minimal: command.append("--minimal") + if not include_dashboard: + # If dashboard is not included, load modules + # that are irrelevant to the dashboard. + # TODO(sang): Modules like job or state APIs should be + # loaded although dashboard is disabled. Fix it. + command.append("--modules-to-load=UsageStatsHead") + process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_DASHBOARD, @@ -1326,6 +1337,7 @@ def start_dashboard( if dashboard_returncode is not None else "" ) + # TODO(sang): Change it to the API server. err_msg = "Failed to start the dashboard" + returncode_str if logdir: dashboard_log = os.path.join(logdir, "dashboard.log") @@ -1357,9 +1369,10 @@ def start_dashboard( dashboard_url = "" return dashboard_url, process_info except Exception as e: - if require_dashboard: + if raise_on_failure: raise e from e else: + # TODO(sang): Change it to the API server. logger.error(f"Failed to start the dashboard: {e}") logger.exception(e) return None, None diff --git a/python/ray/tests/test_usage_stats.py b/python/ray/tests/test_usage_stats.py index c5c9c0f6b..11e62f226 100644 --- a/python/ray/tests/test_usage_stats.py +++ b/python/ray/tests/test_usage_stats.py @@ -1300,6 +1300,33 @@ def test_usage_stats_gcs_query_failure( ) +def test_usages_stats_available_when_dashboard_not_included( + monkeypatch, ray_start_cluster, reset_usage_stats +): + """ + Test library usage is correctly reported when they are imported from + workers. + """ + with monkeypatch.context() as m: + m.setenv("RAY_USAGE_STATS_ENABLED", "1") + m.setenv("RAY_USAGE_STATS_REPORT_URL", "http://127.0.0.1:8000/usage") + m.setenv("RAY_USAGE_STATS_REPORT_INTERVAL_S", "1") + cluster = ray_start_cluster + cluster.add_node(num_cpus=1, include_dashboard=False) + ray.init(address=cluster.address) + + """ + Verify the usage_stats.json contains the lib usage. + """ + temp_dir = pathlib.Path(cluster.head_node.get_session_dir_path()) + wait_for_condition(lambda: file_exists(temp_dir), timeout=30) + + def verify(): + return read_file(temp_dir, "usage_stats")["seq_number"] > 2 + + wait_for_condition(verify) + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))