[Core][State Observability] Truncate warning message is incorrect when filter is used (#26801)

Signed-off-by: rickyyx rickyx@anyscale.com

# Why are these changes needed?
When we returned less/incomplete results to users, there could be 3 reasons:

Data being truncated at the data source (raylets -> API server)
Data being filtered at the API server
Data being limited at the API server
We are not distinguishing the those 3 scenarios, but we should. This is why we thought data being truncated when it's actually filtered/limited.

This PR distinguishes these scenarios and prompt warnings accordingly.

# Related issue number
Closes #26570
Closes #26923
This commit is contained in:
Ricky Xu 2022-07-25 23:31:49 -07:00 committed by GitHub
parent 65563e994b
commit 259473c221
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 216 additions and 94 deletions

View file

@ -18,6 +18,7 @@ from ray.dashboard.optional_utils import rest_response
from ray.dashboard.state_aggregator import StateAPIManager from ray.dashboard.state_aggregator import StateAPIManager
from ray.dashboard.utils import Change from ray.dashboard.utils import Change
from ray.experimental.state.common import ( from ray.experimental.state.common import (
RAY_MAX_LIMIT_FROM_API_SERVER,
ListApiOptions, ListApiOptions,
GetLogOptions, GetLogOptions,
SummaryApiOptions, SummaryApiOptions,
@ -166,6 +167,13 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
if req.query.get("limit") is not None if req.query.get("limit") is not None
else DEFAULT_LIMIT else DEFAULT_LIMIT
) )
if limit > RAY_MAX_LIMIT_FROM_API_SERVER:
raise ValueError(
f"Given limit {limit} exceeds the supported "
f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit."
)
timeout = int(req.query.get("timeout")) timeout = int(req.query.get("timeout"))
filter_keys = req.query.getall("filter_keys", []) filter_keys = req.query.getall("filter_keys", [])
filter_predicates = req.query.getall("filter_predicates", []) filter_predicates = req.query.getall("filter_predicates", [])

View file

@ -20,7 +20,7 @@ from ray.experimental.state.common import (
PlacementGroupState, PlacementGroupState,
RuntimeEnvState, RuntimeEnvState,
SummaryApiResponse, SummaryApiResponse,
MAX_LIMIT, RAY_MAX_LIMIT_FROM_API_SERVER,
SummaryApiOptions, SummaryApiOptions,
TaskSummaries, TaskSummaries,
StateSchema, StateSchema,
@ -51,7 +51,7 @@ GCS_QUERY_FAILURE_WARNING = (
) )
NODE_QUERY_FAILURE_WARNING = ( NODE_QUERY_FAILURE_WARNING = (
"Failed to query data from {type}. " "Failed to query data from {type}. "
"Queryed {total} {type} " "Queried {total} {type} "
"and {network_failures} {type} failed to reply. It is due to " "and {network_failures} {type} failed to reply. It is due to "
"(1) {type} is unexpectedly failed. " "(1) {type} is unexpectedly failed. "
"(2) {type} is overloaded. " "(2) {type} is overloaded. "
@ -202,14 +202,18 @@ class StateAPIManager:
message=message, fields_to_decode=["actor_id", "owner_id"] message=message, fields_to_decode=["actor_id", "owner_id"]
) )
result.append(data) result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, ActorState, option.detail) result = self._filter(result, option.filters, ActorState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["actor_id"]) result.sort(key=lambda entry: entry["actor_id"])
result = list(islice(result, option.limit)) result = list(islice(result, option.limit))
return ListApiResponse( return ListApiResponse(
result=result, result=result,
total=reply.total, total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse: async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
@ -234,15 +238,19 @@ class StateAPIManager:
fields_to_decode=["placement_group_id", "node_id"], fields_to_decode=["placement_group_id", "node_id"],
) )
result.append(data) result.append(data)
num_after_truncation = len(result)
result = self._filter( result = self._filter(
result, option.filters, PlacementGroupState, option.detail result, option.filters, PlacementGroupState, option.detail
) )
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["placement_group_id"]) result.sort(key=lambda entry: entry["placement_group_id"])
return ListApiResponse( return ListApiResponse(
result=list(islice(result, option.limit)), result=list(islice(result, option.limit)),
total=reply.total, total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse: async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
@ -263,15 +271,21 @@ class StateAPIManager:
data["node_ip"] = data["node_manager_address"] data["node_ip"] = data["node_manager_address"]
result.append(data) result.append(data)
total_nodes = len(result)
# No reason to truncate node because they are usually small.
num_after_truncation = len(result)
result = self._filter(result, option.filters, NodeState, option.detail) result = self._filter(result, option.filters, NodeState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["node_id"]) result.sort(key=lambda entry: entry["node_id"])
total_nodes = len(result)
result = list(islice(result, option.limit)) result = list(islice(result, option.limit))
return ListApiResponse( return ListApiResponse(
result=result, result=result,
# No reason to truncate node because they are usually small.
total=total_nodes, total=total_nodes,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse: async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
@ -296,13 +310,17 @@ class StateAPIManager:
data["ip"] = data["worker_address"]["ip_address"] data["ip"] = data["worker_address"]["ip_address"]
result.append(data) result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, WorkerState, option.detail) result = self._filter(result, option.filters, WorkerState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["worker_id"]) result.sort(key=lambda entry: entry["worker_id"])
result = list(islice(result, option.limit)) result = list(islice(result, option.limit))
return ListApiResponse( return ListApiResponse(
result=result, result=result,
total=reply.total, total=reply.total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
@ -320,6 +338,8 @@ class StateAPIManager:
result=result, result=result,
# TODO(sang): Support this. # TODO(sang): Support this.
total=len(result), total=len(result),
num_after_truncation=len(result),
num_filtered=len(result),
) )
async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse: async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
@ -382,8 +402,9 @@ class StateAPIManager:
TaskStatus.RUNNING TaskStatus.RUNNING
].name ].name
result.append(data) result.append(data)
num_after_truncation = len(result)
result = self._filter(result, option.filters, TaskState, option.detail) result = self._filter(result, option.filters, TaskState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["task_id"]) result.sort(key=lambda entry: entry["task_id"])
result = list(islice(result, option.limit)) result = list(islice(result, option.limit))
@ -391,6 +412,8 @@ class StateAPIManager:
result=result, result=result,
partial_failure_warning=partial_failure_warning, partial_failure_warning=partial_failure_warning,
total=total_tasks, total=total_tasks,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse: async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
@ -471,7 +494,9 @@ class StateAPIManager:
"and `ray.init`." "and `ray.init`."
) )
num_after_truncation = len(result)
result = self._filter(result, option.filters, ObjectState, option.detail) result = self._filter(result, option.filters, ObjectState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
result.sort(key=lambda entry: entry["object_id"]) result.sort(key=lambda entry: entry["object_id"])
result = list(islice(result, option.limit)) result = list(islice(result, option.limit))
@ -479,6 +504,8 @@ class StateAPIManager:
result=result, result=result,
partial_failure_warning=partial_failure_warning, partial_failure_warning=partial_failure_warning,
total=total_objects, total=total_objects,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
warnings=callsite_warning, warnings=callsite_warning,
) )
@ -515,7 +542,7 @@ class StateAPIManager:
states = reply.runtime_env_states states = reply.runtime_env_states
for state in states: for state in states:
data = self._message_to_dict(message=state, fields_to_decode=[]) data = self._message_to_dict(message=state, fields_to_decode=[])
# Need to deseiralize this field. # Need to deserialize this field.
data["runtime_env"] = RuntimeEnv.deserialize( data["runtime_env"] = RuntimeEnv.deserialize(
data["runtime_env"] data["runtime_env"]
).to_dict() ).to_dict()
@ -535,8 +562,9 @@ class StateAPIManager:
partial_failure_warning = ( partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}" f"The returned data may contain incomplete result. {warning_msg}"
) )
num_after_truncation = len(result)
result = self._filter(result, option.filters, RuntimeEnvState, option.detail) result = self._filter(result, option.filters, RuntimeEnvState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic. # Sort to make the output deterministic.
def sort_func(entry): def sort_func(entry):
@ -556,12 +584,16 @@ class StateAPIManager:
result=result, result=result,
partial_failure_warning=partial_failure_warning, partial_failure_warning=partial_failure_warning,
total=total_runtime_envs, total=total_runtime_envs,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
) )
async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse: async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss. # For summary, try getting as many entries as possible to minimze data loss.
result = await self.list_tasks( result = await self.list_tasks(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[]) option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
) )
summary = StateSummary( summary = StateSummary(
node_id_to_summary={ node_id_to_summary={
@ -573,12 +605,15 @@ class StateAPIManager:
result=summary, result=summary,
partial_failure_warning=result.partial_failure_warning, partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings, warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
) )
async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse: async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss. # For summary, try getting as many entries as possible to minimze data loss.
result = await self.list_actors( result = await self.list_actors(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[]) option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
) )
summary = StateSummary( summary = StateSummary(
node_id_to_summary={ node_id_to_summary={
@ -590,12 +625,15 @@ class StateAPIManager:
result=summary, result=summary,
partial_failure_warning=result.partial_failure_warning, partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings, warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
) )
async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse: async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
# For summary, try getting as many entries as possible to minimze data loss. # For summary, try getting as many entries as possible to minimize data loss.
result = await self.list_objects( result = await self.list_objects(
option=ListApiOptions(timeout=option.timeout, limit=MAX_LIMIT, filters=[]) option=ListApiOptions(
timeout=option.timeout, limit=RAY_MAX_LIMIT_FROM_API_SERVER, filters=[]
)
) )
summary = StateSummary( summary = StateSummary(
node_id_to_summary={ node_id_to_summary={
@ -607,6 +645,7 @@ class StateAPIManager:
result=summary, result=summary,
partial_failure_warning=result.partial_failure_warning, partial_failure_warning=result.partial_failure_warning,
warnings=result.warnings, warnings=result.warnings,
num_after_truncation=result.num_after_truncation,
) )
def _message_to_dict( def _message_to_dict(

View file

@ -314,6 +314,12 @@ class StateApiClient(SubmissionClient):
def _print_api_warning(self, resource: StateResource, api_response: dict): def _print_api_warning(self, resource: StateResource, api_response: dict):
"""Print the API warnings. """Print the API warnings.
We print warnings for users:
1. when some data sources are not available
2. when results were truncated at the data source
3. when results were limited
4. when callsites not enabled for listing objects
Args: Args:
resource: Resource names, i.e. 'jobs', 'actors', 'nodes', resource: Resource names, i.e. 'jobs', 'actors', 'nodes',
see `StateResource` for details. see `StateResource` for details.
@ -324,16 +330,34 @@ class StateApiClient(SubmissionClient):
if warning_msgs: if warning_msgs:
warnings.warn(warning_msgs) warnings.warn(warning_msgs)
# Print warnings if data is truncated. # Print warnings if data is truncated at the data source.
data = api_response["result"] num_after_truncation = api_response["num_after_truncation"]
total = api_response["total"] total = api_response["total"]
if total > len(data): if total > num_after_truncation:
# NOTE(rickyyx): For now, there's not much users could do (neither can we),
# with hard truncation. Unless we allow users to set a higher
# `RAY_MAX_LIMIT_FROM_DATA_SOURCE`, the data will always be truncated at the
# data source.
warnings.warn( warnings.warn(
( (
f"{len(data)} ({total} total) {resource.value} " f"{num_after_truncation} ({total} total) {resource.value} "
f"are returned. {total - len(data)} entries have been truncated. " "are retrieved from the data source. "
"Use `--filter` to reduce the amount of data to return " f"{total - num_after_truncation} entries have been truncated. "
"or increase the limit by specifying`--limit`." f"Max of {num_after_truncation} entries are retrieved from data "
"source to prevent over-sized payloads."
),
)
# Print warnings if return data is limited at the API server due to
# limit enforced at the server side
num_filtered = api_response["num_filtered"]
data = api_response["result"]
if num_filtered > len(data):
warnings.warn(
(
f"{len(data)}/{num_filtered} {resource.value} returned. "
"Use `--filter` to reduce the amount of data to return or "
"setting a higher limit with `--limit` to see all data. "
), ),
) )

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass, field, fields
from enum import Enum, unique from enum import Enum, unique
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Set, Tuple, Union
from ray._private.ray_constants import env_integer
from ray.core.generated.common_pb2 import TaskType from ray.core.generated.common_pb2 import TaskType
from ray.dashboard.modules.job.common import JobInfo from ray.dashboard.modules.job.common import JobInfo
@ -12,7 +13,17 @@ logger = logging.getLogger(__name__)
DEFAULT_RPC_TIMEOUT = 30 DEFAULT_RPC_TIMEOUT = 30
DEFAULT_LIMIT = 100 DEFAULT_LIMIT = 100
DEFAULT_LOG_LIMIT = 1000 DEFAULT_LOG_LIMIT = 1000
MAX_LIMIT = 10000
# Max number of entries from API server to the client
RAY_MAX_LIMIT_FROM_API_SERVER = env_integer(
"RAY_MAX_LIMIT_FROM_API_SERVER", 10 * 1000
) # 10k
# Max number of entries from data sources (rest will be truncated at the
# data source, e.g. raylet)
RAY_MAX_LIMIT_FROM_DATA_SOURCE = env_integer(
"RAY_MAX_LIMIT_FROM_DATA_SOURCE", 10 * 1000
) # 10k
STATE_OBS_ALPHA_FEEDBACK_MSG = [ STATE_OBS_ALPHA_FEEDBACK_MSG = [
"\n==========ALPHA PREVIEW, FEEDBACK NEEDED ===============", "\n==========ALPHA PREVIEW, FEEDBACK NEEDED ===============",
@ -85,12 +96,6 @@ class ListApiOptions:
if self.filters is None: if self.filters is None:
self.filters = [] self.filters = []
if self.limit > MAX_LIMIT:
raise ValueError(
f"Given limit {self.limit} exceeds the supported "
f"limit {MAX_LIMIT}. Use a lower limit."
)
for filter in self.filters: for filter in self.filters:
_, filter_predicate, _ = filter _, filter_predicate, _ = filter
if filter_predicate != "=" and filter_predicate != "!=": if filter_predicate != "=" and filter_predicate != "!=":
@ -355,10 +360,30 @@ class RuntimeEnvState(StateSchema):
@dataclass(init=True) @dataclass(init=True)
class ListApiResponse: class ListApiResponse:
# Total number of the resource from the cluster. # NOTE(rickyyx): We currently perform hard truncation when querying
# Note that this value can be larger than `result` # resources which could have a large number (e.g. asking raylets for
# because `result` can be truncated. # the number of all objects).
# The returned of resources seen by the user will go through from the
# below funnel:
# - total
# | With truncation at the data source if the number of returned
# | resource exceeds `RAY_MAX_LIMIT_FROM_DATA_SOURCE`
# v
# - num_after_truncation
# | With filtering at the state API server
# v
# - num_filtered
# | With limiting,
# | set by min(`RAY_MAX_LIMIT_FROM_API_SERER`, <user-supplied limit>)
# v
# - len(result)
# Total number of the available resource from the cluster.
total: int total: int
# Number of resources returned by data sources after truncation
num_after_truncation: int
# Number of resources after filtering
num_filtered: int
# Returned data. None if no data is returned. # Returned data. None if no data is returned.
result: List[ result: List[
Union[ Union[
@ -602,17 +627,19 @@ class ObjectSummaries:
@dataclass(init=True) @dataclass(init=True)
class StateSummary: class StateSummary:
# Node ID -> summary per node # Node ID -> summary per node
# If the data is not required to be orgnized per node, it will contain # If the data is not required to be organized per node, it will contain
# a single key, "cluster". # a single key, "cluster".
node_id_to_summary: Dict[str, Union[TaskSummaries, ActorSummaries, ObjectSummaries]] node_id_to_summary: Dict[str, Union[TaskSummaries, ActorSummaries, ObjectSummaries]]
@dataclass(init=True) @dataclass(init=True)
class SummaryApiResponse: class SummaryApiResponse:
# Total number of the resource from the cluster. # Carried over from ListApiResponse
# Note that this value can be larger than `result` # We currently use list API for listing the resources
# because `result` can be truncated.
total: int total: int
# Carried over from ListApiResponse
# Number of resources returned by data sources after truncation
num_after_truncation: int
result: StateSummary = None result: StateSummary = None
partial_failure_warning: str = "" partial_failure_warning: str = ""
# A list of warnings to print. # A list of warnings to print.

View file

@ -40,7 +40,7 @@ from ray.core.generated.runtime_env_agent_pb2 import (
) )
from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub from ray.core.generated.runtime_env_agent_pb2_grpc import RuntimeEnvServiceStub
from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient
from ray.experimental.state.common import MAX_LIMIT from ray.experimental.state.common import RAY_MAX_LIMIT_FROM_DATA_SOURCE
from ray.experimental.state.exception import DataSourceUnavailable from ray.experimental.state.exception import DataSourceUnavailable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -209,7 +209,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None self, timeout: int = None, limit: int = None
) -> Optional[GetAllActorInfoReply]: ) -> Optional[GetAllActorInfoReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllActorInfoRequest(limit=limit) request = GetAllActorInfoRequest(limit=limit)
reply = await self._gcs_actor_info_stub.GetAllActorInfo( reply = await self._gcs_actor_info_stub.GetAllActorInfo(
@ -222,7 +222,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None self, timeout: int = None, limit: int = None
) -> Optional[GetAllPlacementGroupReply]: ) -> Optional[GetAllPlacementGroupReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllPlacementGroupRequest(limit=limit) request = GetAllPlacementGroupRequest(limit=limit)
reply = await self._gcs_pg_info_stub.GetAllPlacementGroup( reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
@ -243,7 +243,7 @@ class StateDataSourceClient:
self, timeout: int = None, limit: int = None self, timeout: int = None, limit: int = None
) -> Optional[GetAllWorkerInfoReply]: ) -> Optional[GetAllWorkerInfoReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetAllWorkerInfoRequest(limit=limit) request = GetAllWorkerInfoRequest(limit=limit)
reply = await self._gcs_worker_info_stub.GetAllWorkerInfo( reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
@ -274,7 +274,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetTasksInfoReply]: ) -> Optional[GetTasksInfoReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._raylet_stubs.get(node_id) stub = self._raylet_stubs.get(node_id)
if not stub: if not stub:
@ -290,7 +290,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetObjectsInfoReply]: ) -> Optional[GetObjectsInfoReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._raylet_stubs.get(node_id) stub = self._raylet_stubs.get(node_id)
if not stub: if not stub:
@ -307,7 +307,7 @@ class StateDataSourceClient:
self, node_id: str, timeout: int = None, limit: int = None self, node_id: str, timeout: int = None, limit: int = None
) -> Optional[GetRuntimeEnvsInfoReply]: ) -> Optional[GetRuntimeEnvsInfoReply]:
if not limit: if not limit:
limit = MAX_LIMIT limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
stub = self._runtime_env_agent_stub.get(node_id) stub = self._runtime_env_agent_stub.get(node_id)
if not stub: if not stub:

View file

@ -81,7 +81,6 @@ from ray.experimental.state.common import (
SupportedFilterType, SupportedFilterType,
TaskState, TaskState,
WorkerState, WorkerState,
MAX_LIMIT,
StateSchema, StateSchema,
state_column, state_column,
) )
@ -1764,42 +1763,52 @@ def test_list_actor_tasks(shutdown_only):
tasks = list_tasks() tasks = list_tasks()
# Actor.__init__: 1 finished # Actor.__init__: 1 finished
# Actor.call: 1 running, 9 waiting for execution (queued). # Actor.call: 1 running, 9 waiting for execution (queued).
correct_num_tasks = len(tasks) == 11 assert len(tasks) == 11
waiting_for_execution = len( assert (
list( len(
filter( list(
lambda task: task["scheduling_state"] == "WAITING_FOR_EXECUTION", filter(
tasks, lambda task: task["scheduling_state"]
== "WAITING_FOR_EXECUTION",
tasks,
)
) )
) )
== 9
) )
scheduled = len( assert (
list(filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks)) len(
) list(
waiting_for_dep = len( filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks)
list(
filter(
lambda task: task["scheduling_state"] == "WAITING_FOR_DEPENDENCIES",
tasks,
) )
) )
== 0
) )
running = len( assert (
list( len(
filter( list(
lambda task: task["scheduling_state"] == "RUNNING", filter(
tasks, lambda task: task["scheduling_state"]
== "WAITING_FOR_DEPENDENCIES",
tasks,
)
) )
) )
== 0
)
assert (
len(
list(
filter(
lambda task: task["scheduling_state"] == "RUNNING",
tasks,
)
)
)
== 1
) )
return ( return True
correct_num_tasks
and running == 1
and waiting_for_dep == 0
and waiting_for_execution == 9
and scheduled == 0
)
wait_for_condition(verify) wait_for_condition(verify)
print(list_tasks()) print(list_tasks())
@ -2104,41 +2113,56 @@ def test_filter(shutdown_only):
assert alive_actor_id in result.output assert alive_actor_id in result.output
def test_data_truncate(shutdown_only): def test_data_truncate(shutdown_only, monkeypatch):
""" """
Verify the data is properly truncated when there are too many entries to return. Verify the data is properly truncated when there are too many entries to return.
""" """
ray.init(num_cpus=16) with monkeypatch.context() as m:
max_limit_data_source = 10
max_limit_api_server = 1000
m.setenv("RAY_MAX_LIMIT_FROM_API_SERVER", f"{max_limit_api_server}")
m.setenv("RAY_MAX_LIMIT_FROM_DATA_SOURCE", f"{max_limit_data_source}")
pgs = [ # noqa ray.init(num_cpus=16)
ray.util.placement_group(bundles=[{"CPU": 0.001}]) for _ in range(MAX_LIMIT + 1)
]
runner = CliRunner()
with pytest.warns(UserWarning) as record:
result = runner.invoke(cli_list, ["placement-groups"])
assert (
f"{DEFAULT_LIMIT} ({MAX_LIMIT + 1} total) placement_groups are returned. "
f"{MAX_LIMIT + 1 - DEFAULT_LIMIT} entries have been truncated."
in record[0].message.args[0]
)
assert result.exit_code == 0
# Make sure users cannot specify higher limit than 10000. pgs = [ # noqa
with pytest.raises(ValueError): ray.util.placement_group(bundles=[{"CPU": 0.001}])
list_placement_groups(limit=MAX_LIMIT + 1) for _ in range(max_limit_data_source + 1)
]
runner = CliRunner()
with pytest.warns(UserWarning) as record:
result = runner.invoke(cli_list, ["placement-groups"])
# result = list_placement_groups()
assert (
f"{max_limit_data_source} ({max_limit_data_source + 1} total) "
"placement_groups are retrieved from the data source. "
"1 entries have been truncated." in record[0].message.args[0]
)
assert result.exit_code == 0
# Make sure warning is not printed when truncation doesn't happen. # Make sure users cannot specify higher limit than MAX_LIMIT_FROM_API_SERVER
@ray.remote with pytest.raises(RayStateApiException):
class A: list_placement_groups(limit=max_limit_api_server + 1)
def ready(self):
pass
a = A.remote() # TODO(rickyyx): We should support error code or more granular errors from
ray.get(a.ready.remote()) # the server to the client so we could assert the specific type of error.
# assert (
# f"Given limit {max_limit_api_server+1} exceeds the supported "
# f"limit {max_limit_api_server}." in str(e)
# )
with pytest.warns(None) as record: # Make sure warning is not printed when truncation doesn't happen.
result = runner.invoke(cli_list, ["actors"]) @ray.remote
assert len(record) == 0 class A:
def ready(self):
pass
a = A.remote()
ray.get(a.ready.remote())
with pytest.warns(None) as record:
result = runner.invoke(cli_list, ["actors"])
assert len(record) == 0
def test_detail(shutdown_only): def test_detail(shutdown_only):