[Usage stats] Add tags & number of nodes to the report. (#25852)

This PR adds the RAY_EXTRA_USAGE_TAGS environment variable for attaching additional tag metadata to the usage report, and also adds the number of alive nodes to the report.
SangBin Cho 2022-07-08 00:31:04 +09:00 committed by GitHub
parent 9b49417a72
commit 2dd5fdfdf1
4 changed files with 157 additions and 11 deletions
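
For illustration, a minimal sketch of how the new variable would be exercised (the tag names and values here are arbitrary examples, not part of this PR):

```python
# Hypothetical usage: attach extra key=value tags (separated by ";") to the
# periodic usage report via the new RAY_EXTRA_USAGE_TAGS environment variable.
import os

# Set before the cluster starts so the dashboard process inherits it.
os.environ["RAY_EXTRA_USAGE_TAGS"] = "team=ml;env=staging"

import ray

ray.init()
# The generated report now carries
#   extra_usage_tags={"team": "ml", "env": "staging"}
# and total_num_nodes set to the number of alive nodes (1 for a local cluster).
```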


@@ -1,19 +1,24 @@
import os
import asyncio
import logging
import os
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import ray
import ray.dashboard.utils as dashboard_utils
import ray._private.usage.usage_lib as ray_usage_lib
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.utils import async_loop_forever
from ray.experimental.state.state_manager import StateDataSourceClient
from ray.dashboard.consts import env_integer
logger = logging.getLogger(__name__)
def gcs_query_timeout():
return env_integer("GCS_QUERY_TIMEOUT_DEFAULT", 10)
class UsageStatsHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
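
The new gcs_query_timeout() reads its value through env_integer, imported above from ray.dashboard.consts. A minimal sketch of the assumed semantics (the helper itself is not part of this diff):

```python
# Assumed behavior of env_integer: read an integer environment variable,
# falling back to the given default when the variable is unset.
import os

def env_integer(key: str, default: int) -> int:
    if key in os.environ:
        return int(os.environ[key])
    return default

# With this, gcs_query_timeout() returns GCS_QUERY_TIMEOUT_DEFAULT seconds,
# defaulting to 10 when the variable is not set.
```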
@@ -34,6 +39,7 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule):
self.total_failed = 0
# The seq number of report. It increments whenever a new report is sent.
self.seq_no = 0
self._state_api_data_source_client = None
if ray._private.utils.check_dashboard_dependencies_installed():
import aiohttp
@@ -49,7 +55,7 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule):
usage_stats_prompt_enabled=self.usage_stats_prompt_enabled,
)
def _report_usage_sync(self):
def _report_usage_sync(self, total_num_nodes: Optional[int]):
"""
- Always write usage_stats.json regardless of report success/failure.
- If report fails, the error message should be written to usage_stats.json
@@ -66,6 +72,7 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule):
self.total_success,
self.total_failed,
self.seq_no,
total_num_nodes,
)
error = None
@@ -94,7 +101,19 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule):
loop = asyncio.get_event_loop()
with ThreadPoolExecutor(max_workers=1) as executor:
await loop.run_in_executor(executor, self._report_usage_sync)
# Find the number of nodes.
total_num_nodes = None
try:
result = await self._state_api_data_source_client.get_all_node_info(
timeout=gcs_query_timeout()
)
total_num_nodes = len(result.node_info_list)
except Exception as e:
logger.info(f"Faile to query number of nodes in the cluster: {e}")
await loop.run_in_executor(
executor, lambda: self._report_usage_sync(total_num_nodes)
)
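
The hunk above first queries the node list on the event loop, then offloads the blocking report call to a thread pool. A self-contained sketch of that pattern, where report_sync stands in for _report_usage_sync:

```python
# Sketch: fetch data asynchronously, then run the blocking report call in a
# ThreadPoolExecutor so the event loop is not stalled.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

def report_sync(total_num_nodes: Optional[int]) -> None:
    # Stand-in for the blocking _report_usage_sync call.
    print(f"reporting usage, total_num_nodes={total_num_nodes}")

async def report_usage(total_num_nodes: Optional[int]) -> None:
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor(max_workers=1) as executor:
        # The lambda forwards the queried node count to the sync function.
        await loop.run_in_executor(executor, lambda: report_sync(total_num_nodes))

asyncio.run(report_usage(2))
```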
@async_loop_forever(ray_usage_lib._usage_stats_report_interval_s())
async def periodically_report_usage(self):
@@ -106,6 +125,8 @@ class UsageStatsHead(dashboard_utils.DashboardHeadModule):
return
else:
logger.info("Usage reporting is enabled.")
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._state_api_data_source_client = StateDataSourceClient(gcs_channel)
# Wait for 1 minute before sending the first report
# so the autoscaler has a chance to set DEBUG_AUTOSCALING_STATUS.
await asyncio.sleep(min(60, ray_usage_lib._usage_stats_report_interval_s()))


@@ -34,3 +34,5 @@ USAGE_STATS_CONFIRMATION_MESSAGE = (
LIBRARY_USAGE_PREFIX = "library_usage_"
USAGE_STATS_NAMESPACE = "usage_stats"
EXTRA_USAGE_TAGS = "RAY_EXTRA_USAGE_TAGS"


@@ -52,7 +52,7 @@ import uuid
from dataclasses import asdict, dataclass
from enum import Enum, auto
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional
import requests
import yaml
@@ -134,6 +134,11 @@ class UsageStatsToReport:
total_failed: int
#: The sequence number of the report.
seq_number: int
#: The extra tags to report, given by the
#: RAY_EXTRA_USAGE_TAGS environment variable.
extra_usage_tags: Optional[Dict[str, str]]
#: The number of alive nodes when the report is generated.
total_num_nodes: Optional[int]
@dataclass(init=True)
@@ -477,6 +482,31 @@ def get_library_usages_to_report(gcs_client, num_retries: int) -> List[str]:
return []
def _parse_extra_usage_tags() -> Optional[Dict[str, str]]:
"""Parse the extra usage tags given by the environment variable.
The env var should be given in the format key=value;key=value.
If parsing fails, None is returned.
Returns:
Dictionary of the parsed key-value pairs, or None.
"""
extra_tags = os.getenv("RAY_EXTRA_USAGE_TAGS", None)
if not extra_tags:
return None
try:
result = {}
kvs = extra_tags.strip(";").split(";")
for kv in kvs:
k, v = kv.split("=")
result[k] = v
return result
except Exception as e:
logger.debug(f"Failed to parse extra usage tags. Error: {e}")
return None
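
The parsing rules of _parse_extra_usage_tags can be exercised standalone; a sketch that re-implements the same logic outside Ray, for illustration:

```python
# Standalone re-implementation of the parsing rules above, for illustration.
from typing import Dict, Optional

def parse_tags(raw: Optional[str]) -> Optional[Dict[str, str]]:
    if not raw:
        return None
    try:
        result = {}
        for kv in raw.strip(";").split(";"):
            k, v = kv.split("=")  # raises if a pair doesn't have exactly one "="
            result[k] = v
        return result
    except Exception:
        return None

assert parse_tags("key=val;key2=val2;") == {"key": "val", "key2": "val2"}
# "," is not a separator, so this is one pair containing two "=" and parsing fails.
assert parse_tags("key=val,key2=val2") is None
assert parse_tags(None) is None
```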
def get_cluster_status_to_report(gcs_client, num_retries: int) -> ClusterStatusToReport:
"""Get the current status of this cluster.
@@ -637,6 +667,7 @@ def generate_report_data(
total_success: int,
total_failed: int,
seq_number: int,
total_num_nodes: Optional[int],
) -> UsageStatsToReport:
"""Generate the report data.
@@ -645,12 +676,13 @@ def generate_report_data(
`_generate_cluster_metadata`.
cluster_config_to_report: The cluster (autoscaler)
config generated by `get_cluster_config_to_report`.
total_success(int): The total number of successful report
total_success: The total number of successful reports
for the lifetime of the cluster.
total_failed(int): The total number of failed report
total_failed: The total number of failed reports
for the lifetime of the cluster.
seq_number(int): The sequence number that's incremented whenever
seq_number: The sequence number that's incremented whenever
a new report is sent.
total_num_nodes: The number of currently alive nodes in the cluster.
Returns:
UsageStatsToReport
@@ -663,6 +695,7 @@ def generate_report_data(
ray.experimental.internal_kv.internal_kv_get_gcs_client(),
num_retries=20,
)
data = UsageStatsToReport(
ray_version=cluster_metadata["ray_version"],
python_version=cluster_metadata["python_version"],
@@ -686,6 +719,8 @@ def generate_report_data(
total_success=total_success,
total_failed=total_failed,
seq_number=seq_number,
extra_usage_tags=_parse_extra_usage_tags(),
total_num_nodes=total_num_nodes,
)
return data
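
The two new fields then surface in the serialized payload. A minimal sketch with a trimmed-down dataclass (the real UsageStatsToReport carries many more fields):

```python
# Simplified model of how extra_usage_tags and total_num_nodes appear in the
# serialized report (real reports include many additional fields).
from dataclasses import asdict, dataclass
from typing import Dict, Optional

@dataclass(init=True)
class MiniReport:
    seq_number: int
    extra_usage_tags: Optional[Dict[str, str]]
    total_num_nodes: Optional[int]

report = MiniReport(seq_number=1,
                    extra_usage_tags={"key": "val", "key2": "val2"},
                    total_num_nodes=2)
print(asdict(report))
# {'seq_number': 1, 'extra_usage_tags': {'key': 'val', 'key2': 'val2'},
#  'total_num_nodes': 2}
```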


@@ -99,6 +99,35 @@ def reset_lib_usage():
ray_usage_lib._recorded_library_usages.clear()
def test_parse_extra_usage_tags(monkeypatch):
with monkeypatch.context() as m:
# Test a normal case.
m.setenv("RAY_EXTRA_USAGE_TAGS", "key=val;key2=val2")
result = ray_usage_lib._parse_extra_usage_tags()
assert result["key"] == "val"
assert result["key2"] == "val2"
m.setenv("RAY_EXTRA_USAGE_TAGS", "key=val;key2=val2;")
result = ray_usage_lib._parse_extra_usage_tags()
assert result["key"] == "val"
assert result["key2"] == "val2"
# Test that the env var is not given.
m.delenv("RAY_EXTRA_USAGE_TAGS")
result = ray_usage_lib._parse_extra_usage_tags()
assert result is None
# Test the parsing failure.
m.setenv("RAY_EXTRA_USAGE_TAGS", "key=val,key2=val2")
result = ray_usage_lib._parse_extra_usage_tags()
assert result is None
# Test different types of parsing failures.
m.setenv("RAY_EXTRA_USAGE_TAGS", "key=v=al,key2=val2")
result = ray_usage_lib._parse_extra_usage_tags()
assert result is None
def test_usage_stats_enabledness(monkeypatch, tmp_path, reset_lib_usage):
with monkeypatch.context() as m:
m.setenv("RAY_USAGE_STATS_ENABLED", "1")
@@ -586,7 +615,7 @@ provider:
cluster_config_file_path
)
d = ray_usage_lib.generate_report_data(
cluster_metadata, cluster_config_to_report, 2, 2, 2
cluster_metadata, cluster_config_to_report, 2, 2, 2, 2
)
validate(instance=asdict(d), schema=schema)
@@ -743,6 +772,8 @@ provider:
assert payload["total_num_gpus"] is None
assert payload["total_memory_gb"] > 0
assert payload["total_object_store_memory_gb"] > 0
assert payload["extra_usage_tags"] is None
assert payload["total_num_nodes"] == 1
if os.environ.get("RAY_MINIMAL") == "1":
# Since we start a serve actor for mocking a server using runtime env.
assert set(payload["library_usages"]) == {"serve"}
@@ -1031,6 +1062,63 @@ ray.init()
wait_for_condition(verify)
def test_usage_stats_tags(monkeypatch, ray_start_cluster, reset_lib_usage):
"""
Test usage tags are correctly reported.
"""
with monkeypatch.context() as m:
m.setenv("RAY_USAGE_STATS_ENABLED", "1")
m.setenv("RAY_USAGE_STATS_REPORT_URL", "http://127.0.0.1:8000/usage")
m.setenv("RAY_USAGE_STATS_REPORT_INTERVAL_S", "1")
m.setenv("RAY_EXTRA_USAGE_TAGS", "key=val;key2=val2")
cluster = ray_start_cluster
cluster.add_node(num_cpus=3)
cluster.add_node(num_cpus=3)
context = ray.init(address=cluster.address)
"""
Verify usage_stats.json contains the extra usage tags and the total number of nodes.
"""
temp_dir = pathlib.Path(context.address_info["session_dir"])
wait_for_condition(lambda: file_exists(temp_dir), timeout=30)
def verify():
tags = read_file(temp_dir, "usage_stats")["extra_usage_tags"]
num_nodes = read_file(temp_dir, "usage_stats")["total_num_nodes"]
assert tags == {"key": "val", "key2": "val2"}
assert num_nodes == 2
return True
wait_for_condition(verify)
def test_usage_stats_gcs_query_failure(monkeypatch, ray_start_cluster, reset_lib_usage):
"""Test None data is reported when the GCS query is failed."""
with monkeypatch.context() as m:
m.setenv("RAY_USAGE_STATS_ENABLED", "1")
m.setenv("RAY_USAGE_STATS_REPORT_URL", "http://127.0.0.1:8000/usage")
m.setenv("RAY_USAGE_STATS_REPORT_INTERVAL_S", "1")
m.setenv("GCS_QUERY_TIMEOUT_DEFAULT", "1")
m.setenv(
"RAY_testing_asio_delay_us",
"NodeInfoGcsService.grpc_server.GetAllNodeInfo=2000000:2000000",
)
cluster = ray_start_cluster
cluster.add_node(num_cpus=3)
context = ray.init(address=cluster.address)
temp_dir = pathlib.Path(context.address_info["session_dir"])
wait_for_condition(lambda: file_exists(temp_dir), timeout=30)
def verify():
num_nodes = read_file(temp_dir, "usage_stats")["total_num_nodes"]
return num_nodes is None
wait_for_condition(verify)
if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))