[Core][State Observability] Truncate warning message is incorrect when filter is used (#26801)

Signed-off-by: rickyyx <rickyx@anyscale.com>

# Why are these changes needed?
When we return fewer or incomplete results to users, there can be 3 reasons:

1. Data being truncated at the data source (raylets -> API server)
2. Data being filtered at the API server
3. Data being limited at the API server
We are not distinguishing these 3 scenarios, but we should: this is why we reported data as being truncated when it was actually filtered or limited.

This PR distinguishes these scenarios and prompts warnings accordingly.
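
To make the distinction concrete, the server now reports separate counters per list call (`total`, `num_after_truncation`, and `num_filtered`), and the client compares them against the returned data to decide which warning to print. Below is a minimal sketch of that decision logic, assuming the field names added in this PR; the `print_list_warnings` helper itself is hypothetical and simplified relative to `StateApiClient._print_api_warning`.

```python
import warnings
from typing import Dict, List


def print_list_warnings(
    resource: str,
    total: int,
    num_after_truncation: int,
    num_filtered: int,
    result: List[Dict],
) -> None:
    """Hypothetical sketch of the per-scenario warnings.

    total                -> entries that exist in the cluster
    num_after_truncation -> entries left after truncation at the data source
    num_filtered         -> entries left after server-side filtering
    len(result)          -> entries actually returned (after the limit)
    """
    # Scenario 1: data was truncated at the data source (raylets -> API server).
    if total > num_after_truncation:
        warnings.warn(
            f"{num_after_truncation} ({total} total) {resource} retrieved from "
            f"the data source; {total - num_after_truncation} entries were "
            "truncated to avoid an over-sized payload."
        )
    # Scenario 3: data was limited at the API server. Filtering alone
    # (scenario 2) does not warn, since the user asked for the filter.
    if num_filtered > len(result):
        warnings.warn(
            f"{len(result)}/{num_filtered} {resource} returned. Use `--filter` "
            "to narrow the output or a higher `--limit` to see all data."
        )


# Example: 10,000 actors survive truncation, 3,000 match the filter,
# but only 100 are returned due to the limit -> only the limit warning fires.
print_list_warnings("actors", 10_000, 10_000, 3_000, [{}] * 100)
```

With these counters, truncation at the data source and limiting at the API server produce distinct warnings, while a result that only shrank because of a user-supplied filter no longer triggers the misleading truncation warning.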

# Related issue number
Closes #26570
Closes #26923
Ricky Xu 2022-07-25 23:31:49 -07:00 committed by GitHub
parent 65563e994b
commit 259473c221
6 changed files with 216 additions and 94 deletions


@ -18,6 +18,7 @@ from ray.dashboard.optional_utils import rest_response
from ray.dashboard.state_aggregator import StateAPIManager
from ray.dashboard.utils import Change
from ray.experimental.state.common import (
RAY_MAX_LIMIT_FROM_API_SERVER,
ListApiOptions,
GetLogOptions,
SummaryApiOptions,
@ -166,6 +167,13 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
if req.query.get("limit") is not None
else DEFAULT_LIMIT
)
if limit > RAY_MAX_LIMIT_FROM_API_SERVER:
raise ValueError(
f"Given limit {limit} exceeds the supported "
f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit."
)
timeout = int(req.query.get("timeout"))
filter_keys = req.query.getall("filter_keys", [])
filter_predicates = req.query.getall("filter_predicates", [])


@ -20,7 +20,7 @@ from ray.experimental.state.common import (
PlacementGroupState,
RuntimeEnvState,
SummaryApiResponse,
MAX_LIMIT,
RAY_MAX_LIMIT_FROM_API_SERVER,
SummaryApiOptions,
TaskSummaries,
StateSchema,
@ -51,7 +51,7 @@ GCS_QUERY_FAILURE_WARNING = (
)
NODE_QUERY_FAILURE_WARNING = (
"Failed to query data from {type}. "
"Queryed {total} {type} "
"Queried {total} {type} "
"and {network_failures} {type} failed to reply. It is due to "
"(1) {type} is unexpectedly failed. "
"(2) {type} is overloaded. "
@ -202,14 +202,18 @@ class StateAPIManager:
message=message, fields_to_decode=["actor_id", "owner_id"]
)
result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, ActorState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["actor_id"])
result = list(islice(result, option.limit))
return ListApiResponse(
result=result,
total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
@ -234,15 +238,19 @@ class StateAPIManager:
fields_to_decode=["placement_group_id", "node_id"],
)
result.append(data)
num_after_truncation = len(result)
result = self._filter(
result, option.filters, PlacementGroupState, option.detail
)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["placement_group_id"])
return ListApiResponse(
result=list(islice(result, option.limit)),
total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
@ -263,15 +271,21 @@ class StateAPIManager:
data["node_ip"] = data["node_manager_address"]
result.append(data)
total_nodes = len(result)
# No reason to truncate node because they are usually small.
num_after_truncation = len(result)
result = self._filter(result, option.filters, NodeState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["node_id"])
total_nodes = len(result)
result = list(islice(result, option.limit))
return ListApiResponse(
result=result,
# No reason to truncate node because they are usually small.
total=total_nodes,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
@ -296,13 +310,17 @@ class StateAPIManager:
data["ip"] = data["worker_address"]["ip_address"]
result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, WorkerState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["worker_id"])
result = list(islice(result, option.limit))
return ListApiResponse(
result=result,
total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
@ -320,6 +338,8 @@ class StateAPIManager:
result=result,
# TODO(sang): Support this.
total=len(result),
num_after_truncation=len(result),
num_filtered=len(result),
)
async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
@ -382,8 +402,9 @@ class StateAPIManager:
TaskStatus.RUNNING
].name
result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, TaskState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["task_id"])
result = list(islice(result, option.limit))
@ -391,6 +412,8 @@ class StateAPIManager:
result=result,
partial_failure_warning=partial_failure_warning,
total=total_tasks,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
@ -471,7 +494,9 @@ class StateAPIManager:
"and `ray.init`."
)
num_after_truncation = len(result)
result = self._filter(result, option.filters, ObjectState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["object_id"])
result = list(islice(result, option.limit))
@ -479,6 +504,8 @@ class StateAPIManager:
result=result,
partial_failure_warning=partial_failure_warning,
total=total_objects,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
warnings=callsite_warning,
)
@ -515,7 +542,7 @@ class StateAPIManager:
states = reply.runtime_env_states
for state in states:
data = self._message_to_dict(message=state, fields_to_decode=[])
# Need to deseiralize this field.
# Need to deserialize this field.
data["runtime_env"] = RuntimeEnv.deserialize(
data["runtime_env"]
).to_dict()
@ -535,8 +562,9 @@ class StateAPIManager:
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
num_after_truncation = len(result)
result = self._filter(result, option.filters, RuntimeEnvState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.
def sort_func(entry):
@ -556,12 +584,16 @@ class StateAPIManager:
result=result,
partial_failure_warning=partial_failure_warning,
total=total_runtime_envs,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss.
result = await self.list_tasks(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
@ -573,12 +605,15 @@ class StateAPIManager:
result=summary,
partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
)
async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss.
result = await self.list_actors(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
@ -590,12 +625,15 @@ class StateAPIManager:
result=summary,
partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
)
async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss.
# For summary, try getting as many entries as possible to minimize data loss.
result = await self.list_objects(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
@ -607,6 +645,7 @@ class StateAPIManager:
result=summary,
partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
)
def _message_to_dict(


@ -314,6 +314,12 @@ class StateApiClient(SubmissionClient):
def _print_api_warning(self, resource: StateResource, api_response: dict):
"""Print the API warnings.
We print warnings for users:
1. when some data sources are not available
2. when results were truncated at the data source
3. when results were limited
4. when callsites not enabled for listing objects
Args:
resource: Resource names, i.e. 'jobs', 'actors', 'nodes',
see `StateResource` for details.
@ -324,16 +330,34 @@ class StateApiClient(SubmissionClient):
if warning_msgs:
warnings.warn(warning_msgs)
# Print warnings if data is truncated.
data = api_response["result"]
# Print warnings if data is truncated at the data source.
num_after_truncation = api_response["num_after_truncation"]
total = api_response["total"]
if total > len(data):
if total > num_after_truncation:
# NOTE(rickyyx): For now, there's not much users could do (neither can we),
# with hard truncation. Unless we allow users to set a higher
# `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will always be truncated at the
# data source.
warnings.warn(
(
f"{len(data)} ({total} total) {resource.value} "
f"are returned. {total - len(data)} entries have been truncated. "
"Use `--filter` to reduce the amount of data to return "
"or increase the limit by specifying`--limit`."
f"{num_after_truncation} ({total} total) {resource.value} "
"are retrieved from the data source. "
f"{total - num_after_truncation} entries have been truncated. "
f"Max of {num_after_truncation} entries are retrieved from data "
"source to prevent over-sized payloads."
),
)
# Print warnings if return data is limited at the API server due to
# limit enforced at the server side
num_filtered = api_response["num_filtered"]
data = api_response["result"]
if num_filtered > len(data):
warnings.warn(
(
f"{len(data)}/{num_filtered} {resource.value} returned. "
"Use `--filter` to reduce the amount of data to return or "
"setting a higher limit with `--limit` to see all data. "
),
)


@ -4,6 +4,7 @@ from dataclasses import dataclass, field, fields
from enum import Enum, unique
from typing import Dict, List, Optional, Set, Tuple, Union
from ray._private.ray_constants import env_integer
from ray.core.generated.common_pb2 import TaskType
from ray.dashboard.modules.job.common import JobInfo
@ -12,7 +13,17 @@ logger = logging.getLogger(__name__)
DEFAULT_RPC_TIMEOUT = 30
DEFAULT_LIMIT = 100
DEFAULT_LOG_LIMIT = 1000
MAX_LIMIT = 10000
# Max number of entries from API server to the client
RAY_MAX_LIMIT_FROM_API_SERVER = env_integer(
"RAY_MAX_LIMIT_FROM_API_SERVER", 10 * 1000
) # 10k
# Max number of entries from data sources (rest will be truncated at the
# data source, e.g. raylet)
RAY_MAX_LIMIT_FROM_DATA_SOURCE = env_integer(
"RAY_MAX_LIMIT_FROM_DATA_SOURCE", 10 * 1000
) # 10k
STATE_OBS_ALPHA_FEEDBACK_MSG = [
"\n==========ALPHA PREVIEW, FEEDBACK NEEDED ===============",
@ -85,12 +96,6 @@ class ListApiOptions:
if self.filters is None:
self.filters = []
if self.limit > MAX_LIMIT:
raise ValueError(
f"Given limit {self.limit} exceeds the supported "
f"limit {MAX_LIMIT}. Use a lower limit."
)
for filter in self.filters:
_, filter_predicate, _ = filter
if filter_predicate != "=" and filter_predicate != "!=":
@ -355,10 +360,30 @@ class RuntimeEnvState(StateSchema):
@dataclass(init=True)
class ListApiResponse:
# Total number of the resource from the cluster.
# Note that this value can be larger than `result`
# because `result` can be truncated.
# NOTE(rickyyx): We currently perform hard truncation when querying
# resources which could have a large number (e.g. asking raylets for
# the number of all objects).
# The number of resources seen by the user goes through the
# funnel below:
# - total
# | With truncation at the data source if the number of returned
# | resource exceeds `RAY_MAX_LIMIT_FROM_DATA_SOURCE`
# v
# - num_after_truncation
# | With filtering at the state API server
# v
# - num_filtered
# | With limiting,
# | set by min(`RAY_MAX_LIMIT_FROM_API_SERVER`, <user-supplied limit>)
# v
# - len(result)
# Total number of the available resource from the cluster.
total: int
# Number of resources returned by data sources after truncation
num_after_truncation: int
# Number of resources after filtering
num_filtered: int
# Returned data. None if no data is returned.
result: List[
Union[
@ -602,17 +627,19 @@ class ObjectSummaries:
@dataclass(init=True)
class StateSummary:
# Node ID -> summary per node
# If the data is not required to be orgnized per node, it will contain
# If the data is not required to be organized per node, it will contain
# a single key, "cluster".
node_id_to_summary: Dict[str, Union[TaskSummaries, ActorSummaries, ObjectSummaries]]
@dataclass(init=True)
class SummaryApiResponse:
# Total number of the resource from the cluster.
# Note that this value can be larger than `result`
# because `result` can be truncated.
# Carried over from ListApiResponse
# We currently use list API for listing the resources
total: int
# Carried over from ListApiResponse
# Number of resources returned by data sources after truncation
num_after_truncation: int
result: StateSummary = None
partial_failure_warning: str = ""
# A list of warnings to print.


@ -40,7 +40,7 @@ from ray.core.generated.runtime_env_agent_pb2 import (
)
from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub
from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient
from ray.experimental.state.common import MAX_LIMIT
from ray.experimental.state.common import RAY_MAX_LIMIT_FROM_DATA_SOURCE
from ray.experimental.state.exception import DataSourceUnavailable
logger = logging.getLogger(__name__)
@ -209,7 +209,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None
) -> Optional[GetAllActorInfoReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllActorInfoRequest(limit=limit)
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
@ -222,7 +222,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None
) -> Optional[GetAllPlacementGroupReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllPlacementGroupRequest(limit=limit)
reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
@ -243,7 +243,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None
) -> Optional[GetAllWorkerInfoReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllWorkerInfoRequest(limit=limit)
reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
@ -274,7 +274,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetTasksInfoReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._raylet_stubs.get(node_id)
if not stub:
@ -290,7 +290,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetObjectsInfoReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._raylet_stubs.get(node_id)
if not stub:
@ -307,7 +307,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetRuntimeEnvsInfoReply]:
if not limit:
limit = MAX_LIMIT
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._runtime_env_agent_stub.get(node_id)
if not stub:


@ -81,7 +81,6 @@ from ray.experimental.state.common import (
SupportedFilterType,
TaskState,
WorkerState,
MAX_LIMIT,
StateSchema,
state_column,
)
@ -1764,42 +1763,52 @@ def test_list_actor_tasks(shutdown_only):
tasks = list_tasks()
# Actor.__init__: 1 finished
# Actor.call: 1 running, 9 waiting for execution (queued).
correct_num_tasks = len(tasks) == 11
waiting_for_execution = len(
list(
filter(
lambda task: task["scheduling_state"] == "WAITING_FOR_EXECUTION",
tasks,
assert len(tasks) == 11
assert (
len(
list(
filter(
lambda task: task["scheduling_state"]
== "WAITING_FOR_EXECUTION",
tasks,
)
)
)
== 9
)
scheduled = len(
list(filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks))
)
waiting_for_dep = len(
list(
filter(
lambda task: task["scheduling_state"] == "WAITING_FOR_DEPENDENCIES",
tasks,
assert (
len(
list(
filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks)
)
)
== 0
)
running = len(
list(
filter(
lambda task: task["scheduling_state"] == "RUNNING",
tasks,
assert (
len(
list(
filter(
lambda task: task["scheduling_state"]
== "WAITING_FOR_DEPENDENCIES",
tasks,
)
)
)
== 0
)
assert (
len(
list(
filter(
lambda task: task["scheduling_state"] == "RUNNING",
tasks,
)
)
)
== 1
)
return (
correct_num_tasks
and running == 1
and waiting_for_dep == 0
and waiting_for_execution == 9
and scheduled == 0
)
return True
wait_for_condition(verify)
print(list_tasks())
@ -2104,41 +2113,56 @@ def test_filter(shutdown_only):
assert alive_actor_id in result.output
def test_data_truncate(shutdown_only):
def test_data_truncate(shutdown_only, monkeypatch):
"""
Verify the data is properly truncated when there are too many entries to return.
"""
ray.init(num_cpus=16)
with monkeypatch.context() as m:
max_limit_data_source = 10
max_limit_api_server = 1000
m.setenv("RAY_MAX_LIMIT_FROM_API_SERVER", f"{max_limit_api_server}")
m.setenv("RAY_MAX_LIMIT_FROM_DATA_SOURCE", f"{max_limit_data_source}")
pgs = [ # noqa
ray.util.placement_group(bundles=[{"CPU": 0.001}]) for _ in range(MAX_LIMIT + 1)
]
runner = CliRunner()
with pytest.warns(UserWarning) as record:
result = runner.invoke(cli_list, ["placement-groups"])
assert (
f"{DEFAULT_LIMIT} ({MAX_LIMIT + 1} total) placement_groups are returned. "
f"{MAX_LIMIT + 1 - DEFAULT_LIMIT} entries have been truncated."
in record[0].message.args[0]
)
assert result.exit_code == 0
ray.init(num_cpus=16)
# Make sure users cannot specify higher limit than 10000.
with pytest.raises(ValueError):
list_placement_groups(limit=MAX_LIMIT + 1)
pgs = [ # noqa
ray.util.placement_group(bundles=[{"CPU": 0.001}])
for _ in range(max_limit_data_source + 1)
]
runner = CliRunner()
with pytest.warns(UserWarning) as record:
result = runner.invoke(cli_list, ["placement-groups"])
# result = list_placement_groups()
assert (
f"{max_limit_data_source} ({max_limit_data_source + 1} total) "
"placement_groups are retrieved from the data source. "
"1 entries have been truncated." in record[0].message.args[0]
)
assert result.exit_code == 0
# Make sure warning is not printed when truncation doesn't happen.
@ray.remote
class A:
def ready(self):
pass
# Make sure users cannot specify higher limit than MAX_LIMIT_FROM_API_SERVER
with pytest.raises(RayStateApiException):
list_placement_groups(limit=max_limit_api_server + 1)
a = A.remote()
ray.get(a.ready.remote())
# TODO(rickyyx): We should support error code or more granular errors from
# the server to the client so we could assert the specific type of error.
# assert (
# f"Given limit {max_limit_api_server+1} exceeds the supported "
# f"limit {max_limit_api_server}." in str(e)
# )
with pytest.warns(None) as record:
result = runner.invoke(cli_list, ["actors"])
assert len(record) == 0
# Make sure warning is not printed when truncation doesn't happen.
@ray.remote
class A:
def ready(self):
pass
a = A.remote()
ray.get(a.ready.remote())
with pytest.warns(None) as record:
result = runner.invoke(cli_list, ["actors"])
assert len(record) == 0
def test_detail(shutdown_only):