Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[Core][State Observability] Truncate warning message is incorrect when filter is used (#26801)
Signed-off-by: rickyyx <rickyx@anyscale.com>

# Why are these changes needed?

When we return fewer or incomplete results to users, there can be three reasons:

1. Data is truncated at the data source (raylets -> API server).
2. Data is filtered at the API server.
3. Data is limited at the API server.

We were not distinguishing those three scenarios, but we should: this is why data was reported as truncated when it was actually filtered or limited. This PR distinguishes these scenarios and prompts warnings accordingly.

# Related issue number

Closes #26570
Closes #26923
Parent: 65563e994b
Commit: 259473c221
6 changed files with 216 additions and 94 deletions
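To make the distinction concrete, here is a minimal client-side sketch (not the code from this PR; the helper name is made up) of how the three cases can be told apart from the counters the PR adds to each API response: total, num_after_truncation, num_filtered, and the returned result.

import warnings

def print_api_warning_sketch(resource: str, api_response: dict) -> None:
    # Illustrative only: mirrors the warning logic this PR introduces.
    data = api_response["result"]
    total = api_response["total"]
    num_after_truncation = api_response["num_after_truncation"]
    num_filtered = api_response["num_filtered"]

    # Case 1: results were truncated at the data source (raylet/GCS -> API server).
    if total > num_after_truncation:
        warnings.warn(
            f"{num_after_truncation} ({total} total) {resource} retrieved from the "
            f"data source; {total - num_after_truncation} entries were truncated."
        )

    # Cases 2 and 3: filtering shrinks num_filtered relative to num_after_truncation,
    # while limiting shrinks the returned result relative to num_filtered, so only
    # the latter warrants a warning here.
    if num_filtered > len(data):
        warnings.warn(
            f"{len(data)}/{num_filtered} {resource} returned. "
            "Use `--filter` or a higher `--limit` to see more data."
        )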
@@ -18,6 +18,7 @@ from ray.dashboard.optional_utils import rest_response
 from ray.dashboard.state_aggregator import StateAPIManager
 from ray.dashboard.utils import Change
 from ray.experimental.state.common import (
+    RAY_MAX_LIMIT_FROM_API_SERVER,
     ListApiOptions,
     GetLogOptions,
     SummaryApiOptions,

@@ -166,6 +167,13 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
             if req.query.get("limit") is not None
             else DEFAULT_LIMIT
         )
+
+        if limit > RAY_MAX_LIMIT_FROM_API_SERVER:
+            raise ValueError(
+                f"Given limit {limit} exceeds the supported "
+                f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit."
+            )
+
         timeout = int(req.query.get("timeout"))
         filter_keys = req.query.getall("filter_keys", [])
         filter_predicates = req.query.getall("filter_predicates", [])
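As a usage sketch (assuming the default cap of 10000, and that the client surfaces the server-side error as the exception type used in the test further below), a request whose limit exceeds RAY_MAX_LIMIT_FROM_API_SERVER is now rejected up front instead of being silently truncated:

# Illustrative usage; assumes the default RAY_MAX_LIMIT_FROM_API_SERVER of 10000.
from ray.experimental.state.api import list_placement_groups
from ray.experimental.state.exception import RayStateApiException

try:
    list_placement_groups(limit=10_001)  # exceeds the API-server cap
except RayStateApiException as e:
    # The server reports that the given limit exceeds the supported limit.
    print(e)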
@@ -20,7 +20,7 @@ from ray.experimental.state.common import (
     PlacementGroupState,
     RuntimeEnvState,
     SummaryApiResponse,
-    MAX_LIMIT,
+    RAY_MAX_LIMIT_FROM_API_SERVER,
     SummaryApiOptions,
     TaskSummaries,
     StateSchema,

@@ -51,7 +51,7 @@ GCS_QUERY_FAILURE_WARNING = (
 )
 NODE_QUERY_FAILURE_WARNING = (
     "Failed to query data from {type}. "
-    "Queryed {total} {type} "
+    "Queried {total} {type} "
     "and {network_failures} {type} failed to reply. It is due to "
     "(1) {type} is unexpectedly failed. "
     "(2) {type} is overloaded. "

@@ -202,14 +202,18 @@ class StateAPIManager:
                 message=message, fields_to_decode=["actor_id", "owner_id"]
             )
             result.append(data)
 
+        num_after_truncation = len(result)
         result = self._filter(result, option.filters, ActorState, option.detail)
+        num_filtered = len(result)
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["actor_id"])
         result = list(islice(result, option.limit))
         return ListApiResponse(
             result=result,
             total=reply.total,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -234,15 +238,19 @@ class StateAPIManager:
                 fields_to_decode=["placement_group_id", "node_id"],
             )
             result.append(data)
+        num_after_truncation = len(result)
 
         result = self._filter(
             result, option.filters, PlacementGroupState, option.detail
         )
+        num_filtered = len(result)
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["placement_group_id"])
         return ListApiResponse(
             result=list(islice(result, option.limit)),
             total=reply.total,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -263,15 +271,21 @@ class StateAPIManager:
             data["node_ip"] = data["node_manager_address"]
             result.append(data)
 
+        total_nodes = len(result)
+        # No reason to truncate node because they are usually small.
+        num_after_truncation = len(result)
+
         result = self._filter(result, option.filters, NodeState, option.detail)
+        num_filtered = len(result)
+
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["node_id"])
-        total_nodes = len(result)
         result = list(islice(result, option.limit))
         return ListApiResponse(
             result=result,
-            # No reason to truncate node because they are usually small.
             total=total_nodes,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -296,13 +310,17 @@ class StateAPIManager:
             data["ip"] = data["worker_address"]["ip_address"]
             result.append(data)
 
+        num_after_truncation = len(result)
         result = self._filter(result, option.filters, WorkerState, option.detail)
+        num_filtered = len(result)
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["worker_id"])
         result = list(islice(result, option.limit))
         return ListApiResponse(
             result=result,
             total=reply.total,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -320,6 +338,8 @@ class StateAPIManager:
             result=result,
             # TODO(sang): Support this.
             total=len(result),
+            num_after_truncation=len(result),
+            num_filtered=len(result),
         )
 
     async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -382,8 +402,9 @@ class StateAPIManager:
                     TaskStatus.RUNNING
                 ].name
             result.append(data)
 
+        num_after_truncation = len(result)
         result = self._filter(result, option.filters, TaskState, option.detail)
+        num_filtered = len(result)
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["task_id"])
         result = list(islice(result, option.limit))

@@ -391,6 +412,8 @@ class StateAPIManager:
             result=result,
             partial_failure_warning=partial_failure_warning,
             total=total_tasks,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:

@@ -471,7 +494,9 @@ class StateAPIManager:
                 "and `ray.init`."
             )
 
+        num_after_truncation = len(result)
         result = self._filter(result, option.filters, ObjectState, option.detail)
+        num_filtered = len(result)
         # Sort to make the output deterministic.
         result.sort(key=lambda entry: entry["object_id"])
         result = list(islice(result, option.limit))

@@ -479,6 +504,8 @@ class StateAPIManager:
             result=result,
             partial_failure_warning=partial_failure_warning,
             total=total_objects,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
             warnings=callsite_warning,
         )
 

@@ -515,7 +542,7 @@ class StateAPIManager:
         states = reply.runtime_env_states
         for state in states:
             data = self._message_to_dict(message=state, fields_to_decode=[])
-            # Need to deseiralize this field.
+            # Need to deserialize this field.
             data["runtime_env"] = RuntimeEnv.deserialize(
                 data["runtime_env"]
             ).to_dict()

@@ -535,8 +562,9 @@ class StateAPIManager:
         partial_failure_warning = (
             f"The returned data may contain incomplete result. {warning_msg}"
         )
 
+        num_after_truncation = len(result)
         result = self._filter(result, option.filters, RuntimeEnvState, option.detail)
+        num_filtered = len(result)
 
         # Sort to make the output deterministic.
         def sort_func(entry):

@@ -556,12 +584,16 @@ class StateAPIManager:
             result=result,
             partial_failure_warning=partial_failure_warning,
             total=total_runtime_envs,
+            num_after_truncation=num_after_truncation,
+            num_filtered=num_filtered,
         )
 
     async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
         # For summary, try getting as many entries as possible to minimze data loss.
         result = await self.list_tasks(
-            option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
+            option=ListApiOptions(
+                timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
+            )
         )
         summary = StateSummary(
             node_id_to_summary={

@@ -573,12 +605,15 @@ class StateAPIManager:
             result=summary,
             partial_failure_warning=result.partial_failure_warning,
             warnings=result.warnings,
+            num_after_truncation=result.num_after_truncation,
         )
 
     async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
         # For summary, try getting as many entries as possible to minimze data loss.
         result = await self.list_actors(
-            option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
+            option=ListApiOptions(
+                timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
+            )
         )
         summary = StateSummary(
             node_id_to_summary={

@@ -590,12 +625,15 @@ class StateAPIManager:
             result=summary,
             partial_failure_warning=result.partial_failure_warning,
             warnings=result.warnings,
+            num_after_truncation=result.num_after_truncation,
         )
 
     async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
-        # For summary, try getting as many entries as possible to minimze data loss.
+        # For summary, try getting as many entries as possible to minimize data loss.
         result = await self.list_objects(
-            option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[])
+            option=ListApiOptions(
+                timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
+            )
         )
         summary = StateSummary(
             node_id_to_summary={

@@ -607,6 +645,7 @@ class StateAPIManager:
             result=summary,
             partial_failure_warning=result.partial_failure_warning,
             warnings=result.warnings,
+            num_after_truncation=result.num_after_truncation,
         )
 
     def _message_to_dict(
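Each list_* handler above now follows the same counting pattern. The sketch below condenses that shared truncate -> filter -> limit flow; apply_filters is a hypothetical stand-in for StateAPIManager._filter and the field names are simplified.

from itertools import islice

def apply_filters(entries, filters):
    # Hypothetical stand-in for StateAPIManager._filter: keep entries matching
    # (key, predicate, value) filters, where the predicate is "=" or "!=".
    def matches(entry):
        for key, predicate, value in filters:
            if predicate == "=" and entry.get(key) != value:
                return False
            if predicate == "!=" and entry.get(key) == value:
                return False
        return True
    return [e for e in entries if matches(e)]

def list_resources_sketch(reply_entries, total, filters, limit):
    result = list(reply_entries)                # already truncated at the data source
    num_after_truncation = len(result)          # count before API-server filtering
    result = apply_filters(result, filters)     # filtering at the API server
    num_filtered = len(result)                  # count after filtering, before limiting
    result.sort(key=lambda entry: entry["id"])  # sort for deterministic output
    result = list(islice(result, limit))        # enforce the limit
    return dict(
        result=result,
        total=total,
        num_after_truncation=num_after_truncation,
        num_filtered=num_filtered,
    )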
@@ -314,6 +314,12 @@ class StateApiClient(SubmissionClient):
     def _print_api_warning(self, resource: StateResource, api_response: dict):
         """Print the API warnings.
 
+        We print warnings for users:
+            1. when some data sources are not available
+            2. when results were truncated at the data source
+            3. when results were limited
+            4. when callsites not enabled for listing objects
+
         Args:
             resource: Resource names, i.e. 'jobs', 'actors', 'nodes',
                 see `StateResource` for details.

@@ -324,16 +330,34 @@ class StateApiClient(SubmissionClient):
         if warning_msgs:
             warnings.warn(warning_msgs)
 
-        # Print warnings if data is truncated.
-        data = api_response["result"]
+        # Print warnings if data is truncated at the data source.
+        num_after_truncation = api_response["num_after_truncation"]
         total = api_response["total"]
-        if total > len(data):
+        if total > num_after_truncation:
+            # NOTE(rickyyx): For now, there's not much users could do (neither can we),
+            # with hard truncation. Unless we allow users to set a higher
+            # `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will always be truncated at the
+            # data source.
             warnings.warn(
                 (
-                    f"{len(data)} ({total} total) {resource.value} "
-                    f"are returned. {total - len(data)} entries have been truncated. "
-                    "Use `--filter` to reduce the amount of data to return "
-                    "or increase the limit by specifying`--limit`."
+                    f"{num_after_truncation} ({total} total) {resource.value} "
+                    "are retrieved from the data source. "
+                    f"{total - num_after_truncation} entries have been truncated. "
+                    f"Max of {num_after_truncation} entries are retrieved from data "
+                    "source to prevent over-sized payloads."
                 ),
             )
 
+        # Print warnings if return data is limited at the API server due to
+        # limit enforced at the server side
+        num_filtered = api_response["num_filtered"]
+        data = api_response["result"]
+        if num_filtered > len(data):
+            warnings.warn(
+                (
+                    f"{len(data)}/{num_filtered} {resource.value} returned. "
+                    "Use `--filter` to reduce the amount of data to return or "
+                    "setting a higher limit with `--limit` to see all data. "
+                ),
+            )
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, fields
 from enum import Enum, unique
 from typing import Dict, List, Optional, Set, Tuple, Union
 
+from ray._private.ray_constants import env_integer
 from ray.core.generated.common_pb2 import TaskType
 from ray.dashboard.modules.job.common import JobInfo
 

@@ -12,7 +13,17 @@ logger = logging.getLogger(__name__)
 DEFAULT_RPC_TIMEOUT = 30
 DEFAULT_LIMIT = 100
 DEFAULT_LOG_LIMIT = 1000
-MAX_LIMIT = 10000
+
+# Max number of entries from API server to the client
+RAY_MAX_LIMIT_FROM_API_SERVER = env_integer(
+    "RAY_MAX_LIMIT_FROM_API_SERVER", 10 * 1000
+)  # 10k
+
+# Max number of entries from data sources (rest will be truncated at the
+# data source, e.g. raylet)
+RAY_MAX_LIMIT_FROM_DATA_SOURCE = env_integer(
+    "RAY_MAX_LIMIT_FROM_DATA_SOURCE", 10 * 1000
+)  # 10k
 
 STATE_OBS_ALPHA_FEEDBACK_MSG = [
     "\n==========ALPHA PREVIEW, FEEDBACK NEEDED ===============",
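Because both caps are read with env_integer, they can be raised or lowered per process through environment variables. A minimal sketch for a single-node cluster started in the same process (the same approach the test below takes via monkeypatch.setenv); the values shown are only examples:

# Illustrative only: the variables must be set before the processes that read them
# (dashboard/API server, raylets) start, i.e. before ray.init() / ray start.
import os

os.environ["RAY_MAX_LIMIT_FROM_API_SERVER"] = "20000"   # API server -> client cap
os.environ["RAY_MAX_LIMIT_FROM_DATA_SOURCE"] = "20000"  # data source -> API server cap

import ray

ray.init()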
@@ -85,12 +96,6 @@ class ListApiOptions:
         if self.filters is None:
             self.filters = []
 
-        if self.limit > MAX_LIMIT:
-            raise ValueError(
-                f"Given limit {self.limit} exceeds the supported "
-                f"limit {MAX_LIMIT}. Use a lower limit."
-            )
-
         for filter in self.filters:
             _, filter_predicate, _ = filter
             if filter_predicate != "=" and filter_predicate != "!=":

@@ -355,10 +360,30 @@ class RuntimeEnvState(StateSchema):
 
 @dataclass(init=True)
 class ListApiResponse:
-    # Total number of the resource from the cluster.
-    # Note that this value can be larger than `result`
-    # because `result` can be truncated.
+    # NOTE(rickyyx): We currently perform hard truncation when querying
+    # resources which could have a large number (e.g. asking raylets for
+    # the number of all objects).
+    # The returned of resources seen by the user will go through from the
+    # below funnel:
+    # - total
+    #      |  With truncation at the data source if the number of returned
+    #      |  resource exceeds `RAY_MAX_LIMIT_FROM_DATA_SOURCE`
+    #      v
+    # - num_after_truncation
+    #      |  With filtering at the state API server
+    #      v
+    # - num_filtered
+    #      |  With limiting,
+    #      |  set by min(`RAY_MAX_LIMIT_FROM_API_SERER`, <user-supplied limit>)
+    #      v
+    # - len(result)
+
+    # Total number of the available resource from the cluster.
     total: int
+    # Number of resources returned by data sources after truncation
+    num_after_truncation: int
+    # Number of resources after filtering
+    num_filtered: int
     # Returned data. None if no data is returned.
     result: List[
         Union[
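A worked example of this funnel with illustrative numbers, assuming the default caps and a user-supplied limit of 100:

# total                = 12000  # matching resources that exist in the cluster
# num_after_truncation = 10000  # the data source truncated 2000 entries (10k cap)
# num_filtered         = 3000   # API-server filters matched 3000 of those
# len(result)          = 100    # limited to the requested 100 entries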
@@ -602,17 +627,19 @@ class ObjectSummaries:
 @dataclass(init=True)
 class StateSummary:
     # Node ID -> summary per node
-    # If the data is not required to be orgnized per node, it will contain
+    # If the data is not required to be organized per node, it will contain
     # a single key, "cluster".
     node_id_to_summary: Dict[str, Union[TaskSummaries, ActorSummaries, ObjectSummaries]]
 
 
 @dataclass(init=True)
 class SummaryApiResponse:
-    # Total number of the resource from the cluster.
-    # Note that this value can be larger than `result`
-    # because `result` can be truncated.
+    # Carried over from ListApiResponse
+    # We currently use list API for listing the resources
     total: int
+    # Carried over from ListApiResponse
+    # Number of resources returned by data sources after truncation
+    num_after_truncation: int
     result: StateSummary = None
     partial_failure_warning: str = ""
     # A list of warnings to print.
@@ -40,7 +40,7 @@ from ray.core.generated.runtime_env_agent_pb2 import (
 )
 from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub
 from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient
-from ray.experimental.state.common import MAX_LIMIT
+from ray.experimental.state.common import RAY_MAX_LIMIT_FROM_DATA_SOURCE
 from ray.experimental.state.exception import DataSourceUnavailable
 
 logger = logging.getLogger(__name__)

@@ -209,7 +209,7 @@ class StateDataSourceClient:
         self, timeout: int = None, limit: int = None
     ) -> Optional[GetAllActorInfoReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         request = GetAllActorInfoRequest(limit=limit)
         reply = await self._gcs_actor_info_stub.GetAllActorInfo(

@@ -222,7 +222,7 @@ class StateDataSourceClient:
         self, timeout: int = None, limit: int = None
    ) -> Optional[GetAllPlacementGroupReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         request = GetAllPlacementGroupRequest(limit=limit)
         reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(

@@ -243,7 +243,7 @@ class StateDataSourceClient:
         self, timeout: int = None, limit: int = None
     ) -> Optional[GetAllWorkerInfoReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         request = GetAllWorkerInfoRequest(limit=limit)
         reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(

@@ -274,7 +274,7 @@ class StateDataSourceClient:
         self, node_id: str, timeout: int = None, limit: int = None
     ) -> Optional[GetTasksInfoReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         stub = self._raylet_stubs.get(node_id)
         if not stub:

@@ -290,7 +290,7 @@ class StateDataSourceClient:
         self, node_id: str, timeout: int = None, limit: int = None
     ) -> Optional[GetObjectsInfoReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         stub = self._raylet_stubs.get(node_id)
         if not stub:

@@ -307,7 +307,7 @@ class StateDataSourceClient:
         self, node_id: str, timeout: int = None, limit: int = None
     ) -> Optional[GetRuntimeEnvsInfoReply]:
         if not limit:
-            limit = MAX_LIMIT
+            limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
 
         stub = self._runtime_env_agent_stub.get(node_id)
         if not stub:
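Design note: defaulting a missing limit to RAY_MAX_LIMIT_FROM_DATA_SOURCE (rather than the removed MAX_LIMIT) bounds every GCS/raylet query at the data-source cap, which is exactly what produces the total > num_after_truncation condition the client now warns about. The defaulting pattern, shown in isolation as a sketch:

def resolve_limit(limit=None, cap=10_000):
    # Illustrative only: a missing or zero limit falls back to the data-source cap
    # (RAY_MAX_LIMIT_FROM_DATA_SOURCE in the code above).
    return limit if limit else cap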
@@ -81,7 +81,6 @@ from ray.experimental.state.common import (
     SupportedFilterType,
     TaskState,
     WorkerState,
-    MAX_LIMIT,
     StateSchema,
     state_column,
 )

@@ -1764,42 +1763,52 @@ def test_list_actor_tasks(shutdown_only):
         tasks = list_tasks()
         # Actor.__init__: 1 finished
         # Actor.call: 1 running, 9 waiting for execution (queued).
-        correct_num_tasks = len(tasks) == 11
-        waiting_for_execution = len(
-            list(
-                filter(
-                    lambda task: task["scheduling_state"] == "WAITING_FOR_EXECUTION",
-                    tasks,
+        assert len(tasks) == 11
+        assert (
+            len(
+                list(
+                    filter(
+                        lambda task: task["scheduling_state"]
+                        == "WAITING_FOR_EXECUTION",
+                        tasks,
+                    )
+                )
+            )
+            == 9
+        )
-        scheduled = len(
-            list(filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks))
-        )
-        waiting_for_dep = len(
-            list(
-                filter(
-                    lambda task: task["scheduling_state"] == "WAITING_FOR_DEPENDENCIES",
-                    tasks,
+        assert (
+            len(
+                list(
+                    filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks)
+                )
+            )
+            == 0
+        )
-        running = len(
-            list(
-                filter(
-                    lambda task: task["scheduling_state"] == "RUNNING",
-                    tasks,
+        assert (
+            len(
+                list(
+                    filter(
+                        lambda task: task["scheduling_state"]
+                        == "WAITING_FOR_DEPENDENCIES",
+                        tasks,
+                    )
+                )
+            )
+            == 0
+        )
+        assert (
+            len(
+                list(
+                    filter(
+                        lambda task: task["scheduling_state"] == "RUNNING",
+                        tasks,
+                    )
+                )
+            )
+            == 1
+        )
 
-        return (
-            correct_num_tasks
-            and running == 1
-            and waiting_for_dep == 0
-            and waiting_for_execution == 9
-            and scheduled == 0
-        )
+        return True
 
     wait_for_condition(verify)
     print(list_tasks())

@@ -2104,41 +2113,56 @@ def test_filter(shutdown_only):
     assert alive_actor_id in result.output
 
 
-def test_data_truncate(shutdown_only):
+def test_data_truncate(shutdown_only, monkeypatch):
     """
     Verify the data is properly truncated when there are too many entries to return.
     """
-    ray.init(num_cpus=16)
+    with monkeypatch.context() as m:
+        max_limit_data_source = 10
+        max_limit_api_server = 1000
+        m.setenv("RAY_MAX_LIMIT_FROM_API_SERVER", f"{max_limit_api_server}")
+        m.setenv("RAY_MAX_LIMIT_FROM_DATA_SOURCE", f"{max_limit_data_source}")
 
-    pgs = [  # noqa
-        ray.util.placement_group(bundles=[{"CPU": 0.001}]) for _ in range(MAX_LIMIT + 1)
-    ]
-    runner = CliRunner()
-    with pytest.warns(UserWarning) as record:
-        result = runner.invoke(cli_list, ["placement-groups"])
-    assert (
-        f"{DEFAULT_LIMIT} ({MAX_LIMIT + 1} total) placement_groups are returned. "
-        f"{MAX_LIMIT + 1 - DEFAULT_LIMIT} entries have been truncated."
-        in record[0].message.args[0]
-    )
-    assert result.exit_code == 0
+        ray.init(num_cpus=16)
 
-    # Make sure users cannot specify higher limit than 10000.
-    with pytest.raises(ValueError):
-        list_placement_groups(limit=MAX_LIMIT + 1)
+        pgs = [  # noqa
+            ray.util.placement_group(bundles=[{"CPU": 0.001}])
+            for _ in range(max_limit_data_source + 1)
+        ]
+        runner = CliRunner()
+        with pytest.warns(UserWarning) as record:
+            result = runner.invoke(cli_list, ["placement-groups"])
+            # result = list_placement_groups()
+        assert (
+            f"{max_limit_data_source} ({max_limit_data_source + 1} total) "
+            "placement_groups are retrieved from the data source. "
+            "1 entries have been truncated." in record[0].message.args[0]
+        )
+        assert result.exit_code == 0
 
-    # Make sure warning is not printed when truncation doesn't happen.
-    @ray.remote
-    class A:
-        def ready(self):
-            pass
+        # Make sure users cannot specify higher limit than MAX_LIMIT_FROM_API_SERVER
+        with pytest.raises(RayStateApiException):
+            list_placement_groups(limit=max_limit_api_server + 1)
 
-    a = A.remote()
-    ray.get(a.ready.remote())
+        # TODO(rickyyx): We should support error code or more granular errors from
+        # the server to the client so we could assert the specific type of error.
+        # assert (
+        #     f"Given limit {max_limit_api_server+1} exceeds the supported "
+        #     f"limit {max_limit_api_server}." in str(e)
+        # )
 
-    with pytest.warns(None) as record:
-        result = runner.invoke(cli_list, ["actors"])
-    assert len(record) == 0
+        # Make sure warning is not printed when truncation doesn't happen.
+        @ray.remote
+        class A:
+            def ready(self):
+                pass
 
+        a = A.remote()
+        ray.get(a.ready.remote())
 
+        with pytest.warns(None) as record:
+            result = runner.invoke(cli_list, ["actors"])
+        assert len(record) == 0
 
 
 def test_detail(shutdown_only):