ray/dashboard/state_aggregator.py

import asyncio
import logging
from dataclasses import asdict, fields
from itertools import islice
from typing import List, Tuple
import ray.dashboard.memory_utils as memory_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.utils import binary_to_hex
from ray.core.generated.common_pb2 import TaskStatus
from ray.experimental.state.common import (
ActorState,
ListApiOptions,
ListApiResponse,
NodeState,
ObjectState,
PlacementGroupState,
RuntimeEnvState,
SummaryApiResponse,
DEFAULT_LIMIT,
SummaryApiOptions,
TaskSummaries,
StateSchema,
SupportedFilterType,
TaskState,
WorkerState,
filter_fields,
StateSummary,
ActorSummaries,
ObjectSummaries,
PredicateType,
)
from ray.experimental.state.state_manager import (
DataSourceUnavailable,
StateDataSourceClient,
)
from ray.runtime_env import RuntimeEnv
logger = logging.getLogger(__name__)
GCS_QUERY_FAILURE_WARNING = (
"Failed to query data from GCS. This is due to one of the following: "
"(1) GCS failed unexpectedly, "
"(2) GCS is overloaded, or "
"(3) there's an unexpected network issue. "
"Please check the gcs_server.out log to find the root cause."
)
NODE_QUERY_FAILURE_WARNING = (
"Failed to query data from {type}. "
"Queried {total} {type} "
"and {network_failures} {type} failed to reply. This is due to one of the following: "
"(1) {type} failed unexpectedly, "
"(2) {type} is overloaded, or "
"(3) there's an unexpected network issue. Please check the "
"{log_command} to find the root cause."
)
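# Illustrative sketch (editor's addition, not part of the original module):
# how the template above is typically filled in by the list_* endpoints below.
# The concrete numbers are hypothetical.
def _example_node_query_warning() -> str:
    return NODE_QUERY_FAILURE_WARNING.format(
        type="raylet",
        total=4,
        network_failures=1,
        log_command="raylet.out",
    )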
def _convert_filters_type(
filter: List[Tuple[str, PredicateType, SupportedFilterType]],
schema: StateSchema,
) -> List[Tuple[str, SupportedFilterType]]:
"""Convert the given filter's type to SupportedFilterType.
This method is necessary because click can only accept a single type
for its tuple (which is string in this case).
Args:
filter: A list of filter which is a tuple of (key, val).
schema: The state schema. It is used to infer the type of the column for filter.
Returns:
A new list of filters with correctly types that match the schema.
"""
new_filter = []
schema = {field.name: field.type for field in fields(schema)}
for col, predicate, val in filter:
if col in schema:
column_type = schema[col]
if isinstance(val, column_type):
# Do nothing.
pass
elif column_type is int:
try:
val = int(val)
except ValueError:
raise ValueError(
f"Invalid filter `--filter {col} {val}` for a int type "
"column. Please provide an integer filter "
f"`--filter {col} [int]`"
)
elif column_type is float:
try:
val = float(val)
except ValueError:
raise ValueError(
f"Invalid filter `--filter {col} {val}` for a float "
"type column. Please provide an integer filter "
f"`--filter {col} [float]`"
)
elif column_type is bool:
# Without this, "False" will become True.
if val == "False" or val == "false" or val == "0":
val = False
elif val == "True" or val == "true" or val == "1":
val = True
else:
raise ValueError(
f"Invalid filter `--filter {col} {val}` for a boolean "
"type column. Please provide "
f"`--filter {col} [True|true|1]` for True or "
f"`--filter {col} [False|false|0]` for False."
)
new_filter.append((col, predicate, val))
return new_filter
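# Illustrative sketch (editor's addition, not part of the original module):
# _convert_filters_type coerces the string values coming from the CLI into the
# types declared on the schema dataclass. ExampleSchema below is hypothetical
# and used only for demonstration.
def _example_convert_filters() -> None:
    from dataclasses import dataclass

    @dataclass
    class ExampleSchema:
        pid: int
        alive: bool

    converted = _convert_filters_type(
        [("pid", "=", "1234"), ("alive", "=", "false")], ExampleSchema
    )
    assert converted == [("pid", "=", 1234), ("alive", "=", False)]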
# TODO(sang): Move the class to state/state_manager.py.
# TODO(sang): Remove *State and replaces with Pydantic or protobuf.
# (depending on API interface standardization).
class StateAPIManager:
"""A class to query states from data source, caches, and post-processes
the entries.
"""
def __init__(self, state_data_source_client: StateDataSourceClient):
self._client = state_data_source_client
@property
def data_source_client(self):
return self._client
def _filter(
self,
data: List[dict],
filters: List[Tuple[str, SupportedFilterType]],
state_dataclass: StateSchema,
) -> List[dict]:
"""Return the filtered data given filters.
Args:
data: A list of state data.
filters: A list of KV tuple to filter data (key, val). The data is filtered
if data[key] != val.
state_dataclass: The state schema.
Returns:
A list of filtered state data in dictionary. Each state data's
unncessary columns are filtered by the given state_dataclass schema.
"""
filters = _convert_filters_type(filters, state_dataclass)
result = []
for datum in data:
match = True
for filter_column, filter_predicate, filter_value in filters:
filterable_columns = state_dataclass.filterable_columns()
if filter_column not in filterable_columns:
raise ValueError(
f"The given filter column {filter_column} is not supported. "
f"Supported filter columns: {filterable_columns}"
)
if filter_predicate == "=":
match = datum[filter_column] == filter_value
elif filter_predicate == "!=":
match = datum[filter_column] != filter_value
else:
raise ValueError(
f"Unsupported filter predicate {filter_predicate} is given. "
"Available predicates: =, !=."
)
if not match:
break
if match:
result.append(filter_fields(datum, state_dataclass))
return result
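# Illustrative sketch (editor's addition): given a StateAPIManager `manager`,
# _filter keeps only the entries matching every (column, predicate, value)
# tuple and trims each dict down to the schema's fields. The "state" column
# used here is assumed to be a filterable ActorState column.
def _example_filter(manager: "StateAPIManager", actors: List[dict]) -> List[dict]:
    return manager._filter(actors, [("state", "=", "ALIVE")], ActorState)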
async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all actor information from the cluster.
Returns:
{actor_id -> actor_data_in_dict}
actor_data_in_dict's schema is in ActorState
"""
try:
reply = await self._client.get_all_actor_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.actor_table_data:
data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
result.append(data)
result = self._filter(result, option.filters, ActorState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["actor_id"])
return ListApiResponse(result=list(islice(result, option.limit)))
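# Illustrative sketch (editor's addition): querying actors through the manager
# with a filter. The timeout value and the "state" filter column are
# assumptions for illustration.
async def _example_list_alive_actors(manager: "StateAPIManager") -> ListApiResponse:
    return await manager.list_actors(
        option=ListApiOptions(
            timeout=30, limit=DEFAULT_LIMIT, filters=[("state", "=", "ALIVE")]
        )
    )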
async def list_placement_groups(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all placement group information from the cluster.
Returns:
{pg_id -> pg_data_in_dict}
pg_data_in_dict's schema is in PlacementGroupState
"""
try:
reply = await self._client.get_all_placement_group_info(
timeout=option.timeout
)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.placement_group_table_data:
data = self._message_to_dict(
message=message,
fields_to_decode=["placement_group_id"],
)
result.append(data)
result = self._filter(result, option.filters, PlacementGroupState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["placement_group_id"])
return ListApiResponse(result=list(islice(result, option.limit)))
async def list_nodes(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all node information from the cluster.
Returns:
{node_id -> node_data_in_dict}
node_data_in_dict's schema is in NodeState
"""
try:
reply = await self._client.get_all_node_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.node_info_list:
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
data["node_ip"] = data["node_manager_address"]
data = filter_fields(data, NodeState)
result.append(data)
result = self._filter(result, option.filters, NodeState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["node_id"])
return ListApiResponse(result=list(islice(result, option.limit)))
async def list_workers(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all worker information from the cluster.
Returns:
{worker_id -> worker_data_in_dict}
worker_data_in_dict's schema is in WorkerState
"""
try:
reply = await self._client.get_all_worker_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
result = []
for message in reply.worker_table_data:
data = self._message_to_dict(
message=message, fields_to_decode=["worker_id"]
)
data["worker_id"] = data["worker_address"]["worker_id"]
result.append(data)
result = self._filter(result, option.filters, WorkerState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["worker_id"])
return ListApiResponse(result=list(islice(result, option.limit)))
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
# TODO(sang): Support limit & timeout & async calls.
try:
result = []
job_info = self._client.get_job_info()
for job_id, data in job_info.items():
data = asdict(data)
data["job_id"] = job_id
result.append(data)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)
return ListApiResponse(result=result)
async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all task information from the cluster.
Returns:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
raylet_ids = self._client.get_all_registered_raylet_ids()
replies = await asyncio.gather(
*[
self._client.get_task_info(node_id, timeout=option.timeout)
for node_id in raylet_ids
],
return_exceptions=True,
)
unresponsive_nodes = 0
running_task_id = set()
successful_replies = []
for reply in replies:
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
successful_replies.append(reply)
for task_id in reply.running_task_ids:
running_task_id.add(binary_to_hex(task_id))
partial_failure_warning = None
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="raylet",
total=len(raylet_ids),
network_failures=unresponsive_nodes,
log_command="raylet.out",
)
if unresponsive_nodes == len(raylet_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
result = []
for reply in successful_replies:
assert not isinstance(reply, Exception)
tasks = reply.owned_task_info_entries
for task in tasks:
data = self._message_to_dict(
message=task,
fields_to_decode=["task_id"],
)
if data["task_id"] in running_task_id:
data["scheduling_state"] = TaskStatus.DESCRIPTOR.values_by_number[
TaskStatus.RUNNING
].name
result.append(data)
result = self._filter(result, option.filters, TaskState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["task_id"])
return ListApiResponse(
result=list(islice(result, option.limit)),
partial_failure_warning=partial_failure_warning,
)
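# Illustrative sketch (editor's addition): the per-node fan-out pattern used by
# list_tasks and list_objects above, shown in isolation. `fetch` stands in for
# a hypothetical per-node query coroutine.
async def _example_fan_out(node_ids: List[str], fetch) -> Tuple[List, int]:
    replies = await asyncio.gather(
        *[fetch(node_id) for node_id in node_ids], return_exceptions=True
    )
    successful, unresponsive = [], 0
    for reply in replies:
        if isinstance(reply, DataSourceUnavailable):
            # Tolerate unresponsive nodes; count them for the partial warning.
            unresponsive += 1
        elif isinstance(reply, Exception):
            # Any other exception is unexpected and re-raised.
            raise reply
        else:
            successful.append(reply)
    return successful, unresponsive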
async def list_objects(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all object information from the cluster.
Returns:
{object_id -> object_data_in_dict}
object_data_in_dict's schema is in ObjectState
"""
raylet_ids = self._client.get_all_registered_raylet_ids()
replies = await asyncio.gather(
*[
self._client.get_object_info(node_id, timeout=option.timeout)
for node_id in raylet_ids
],
return_exceptions=True,
)
unresponsive_nodes = 0
worker_stats = []
for reply, _ in zip(replies, raylet_ids):
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
for core_worker_stat in reply.core_workers_stats:
# NOTE: Set preserving_proto_field_name=False here because
# `construct_memory_table` requires a dictionary keyed by the
# camelCase protobuf field names
# (e.g., workerId instead of worker_id).
worker_stats.append(
self._message_to_dict(
message=core_worker_stat,
fields_to_decode=["object_id"],
preserving_proto_field_name=False,
)
)
partial_failure_warning = None
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="raylet",
total=len(raylet_ids),
network_failures=unresponsive_nodes,
log_command="raylet.out",
)
if unresponsive_nodes == len(raylet_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
result = []
memory_table = memory_utils.construct_memory_table(worker_stats)
for entry in memory_table.table:
data = entry.as_dict()
# `construct_memory_table` returns an `object_ref` field that is actually
# the object_id. Rename it here.
# TODO(sang): Refactor `construct_memory_table`.
data["object_id"] = data["object_ref"]
del data["object_ref"]
result.append(data)
result = self._filter(result, option.filters, ObjectState)
# Sort to make the output deterministic.
result.sort(key=lambda entry: entry["object_id"])
return ListApiResponse(
result=list(islice(result, option.limit)),
partial_failure_warning=partial_failure_warning,
)
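# Illustrative sketch (editor's addition): the object_ref -> object_id rename
# applied to each memory table entry above, shown on a hypothetical dict.
def _example_rename_object_ref(entry: dict) -> dict:
    data = dict(entry)
    data["object_id"] = data.pop("object_ref")
    return data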
async def list_runtime_envs(self, *, option: ListApiOptions) -> ListApiResponse:
"""List all runtime env information from the cluster.
Returns:
A list of runtime env information in the cluster.
The schema of each returned dict is equivalent to the
`RuntimeEnvState` protobuf message.
We don't have an id -> data mapping like the other APIs because runtime
envs don't have unique ids.
"""
agent_ids = self._client.get_all_registered_agent_ids()
replies = await asyncio.gather(
*[
self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
for node_id in agent_ids
],
return_exceptions=True,
)
result = []
unresponsive_nodes = 0
for node_id, reply in zip(self._client.get_all_registered_agent_ids(), replies):
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply
states = reply.runtime_env_states
for state in states:
data = self._message_to_dict(message=state, fields_to_decode=[])
# Need to deserialize this field.
data["runtime_env"] = RuntimeEnv.deserialize(
data["runtime_env"]
).to_dict()
data["node_id"] = node_id
result.append(data)
partial_failure_warning = None
if len(agent_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="agent",
total=len(agent_ids),
network_failures=unresponsive_nodes,
log_command="dashboard_agent.log",
)
if unresponsive_nodes == len(agent_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
result = self._filter(result, option.filters, RuntimeEnvState)
# Sort to make the output deterministic.
def sort_func(entry):
# If the creation time is missing (the runtime env failed to be
# created or has not been created yet), give the entry the highest
# priority. Otherwise, newer creation times come first.
if "creation_time_ms" not in entry:
return float("inf")
elif entry["creation_time_ms"] is None:
return float("inf")
else:
return float(entry["creation_time_ms"])
result.sort(key=sort_func, reverse=True)
return ListApiResponse(
result=list(islice(result, option.limit)),
partial_failure_warning=partial_failure_warning,
)
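# Illustrative sketch (editor's addition): the ordering produced by the sort
# above on hypothetical entries. Entries without a creation time come first,
# then newer creation times before older ones.
def _example_runtime_env_ordering() -> List[dict]:
    entries = [
        {"creation_time_ms": 100},
        {"creation_time_ms": None},
        {"creation_time_ms": 200},
    ]
    entries.sort(
        key=lambda e: float("inf")
        if e.get("creation_time_ms") is None
        else float(e["creation_time_ms"]),
        reverse=True,
    )
    # Result order by creation_time_ms: None, 200, 100.
    return entries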
async def summarize_tasks(self, option: SummaryApiOptions) -> SummaryApiResponse:
result = await self.list_tasks(
option=ListApiOptions(
timeout=option.timeout, limit=DEFAULT_LIMIT, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
"cluster": TaskSummaries.to_summary(tasks=result.result)
}
)
return SummaryApiResponse(
result=summary, partial_failure_warning=result.partial_failure_warning
)
async def summarize_actors(self, option: SummaryApiOptions) -> SummaryApiResponse:
result = await self.list_actors(
option=ListApiOptions(
timeout=option.timeout, limit=DEFAULT_LIMIT, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
"cluster": ActorSummaries.to_summary(actors=result.result)
}
)
return SummaryApiResponse(
result=summary, partial_failure_warning=result.partial_failure_warning
)
async def summarize_objects(self, option: SummaryApiOptions) -> SummaryApiResponse:
result = await self.list_objects(
option=ListApiOptions(
timeout=option.timeout, limit=DEFAULT_LIMIT, filters=[]
)
)
summary = StateSummary(
node_id_to_summary={
"cluster": ObjectSummaries.to_summary(objects=result.result)
}
)
return SummaryApiResponse(
result=summary, partial_failure_warning=result.partial_failure_warning
)
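# Illustrative sketch (editor's addition): obtaining a cluster-wide task
# summary through the manager. Constructing SummaryApiOptions with only a
# timeout is an assumption for illustration.
async def _example_summarize_tasks(manager: "StateAPIManager") -> SummaryApiResponse:
    return await manager.summarize_tasks(SummaryApiOptions(timeout=30))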
def _message_to_dict(
self,
*,
message,
fields_to_decode: List[str],
preserving_proto_field_name: bool = True,
) -> dict:
return dashboard_utils.message_to_dict(
message,
fields_to_decode,
including_default_value_fields=True,
preserving_proto_field_name=preserving_proto_field_name,
)
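# Illustrative sketch (editor's addition): wiring the manager to a data source
# client and listing nodes. How the StateDataSourceClient itself is constructed
# (GCS channel setup, etc.) is outside the scope of this module, and the
# timeout and limit values here are hypothetical.
async def _example_end_to_end(client: StateDataSourceClient) -> ListApiResponse:
    manager = StateAPIManager(client)
    return await manager.list_nodes(
        option=ListApiOptions(timeout=30, limit=100, filters=[])
    )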