ray/dashboard/state_aggregator.py

import asyncio
import logging
from typing import List, Dict
from itertools import islice

import ray.dashboard.utils as dashboard_utils
import ray.dashboard.memory_utils as memory_utils
from ray.dashboard.modules.job.common import JobInfo
from ray.experimental.state.common import (
    filter_fields,
    ActorState,
    PlacementGroupState,
    NodeState,
    WorkerState,
    TaskState,
    ObjectState,
    RuntimeEnvState,
    ListApiOptions,
)
from ray.experimental.state.state_manager import StateDataSourceClient
from ray.runtime_env import RuntimeEnv

logger = logging.getLogger(__name__)


# TODO(sang): Move the class to state/state_manager.py.
# TODO(sang): Remove *State and replace them with Pydantic or protobuf
# (depending on API interface standardization).
class StateAPIManager:
    """A class that queries states from the data source, and caches and
    post-processes the entries.
    """

    def __init__(self, state_data_source_client: StateDataSourceClient):
        self._client = state_data_source_client

    @property
    def data_source_client(self):
        return self._client
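
    # Hypothetical wiring sketch (illustration only, not part of this module):
    # the dashboard head is expected to build the client and hand it in here.
    # `StateDataSourceClient(gcs_channel)` is an assumption about that
    # constructor, not something defined in this file.
    #
    #   client = StateDataSourceClient(gcs_channel)
    #   manager = StateAPIManager(client)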

    async def list_actors(self, *, option: ListApiOptions) -> dict:
        """List all actor information from the cluster.

        Returns:
            {actor_id -> actor_data_in_dict}
            actor_data_in_dict's schema is in ActorState
        """
        reply = await self._client.get_all_actor_info(timeout=option.timeout)

        result = []
        for message in reply.actor_table_data:
            data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
            data = filter_fields(data, ActorState)
            result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["actor_id"])
        return {d["actor_id"]: d for d in islice(result, option.limit)}

    async def list_placement_groups(self, *, option: ListApiOptions) -> dict:
        """List all placement group information from the cluster.

        Returns:
            {pg_id -> pg_data_in_dict}
            pg_data_in_dict's schema is in PlacementGroupState
        """
        reply = await self._client.get_all_placement_group_info(timeout=option.timeout)

        result = []
        for message in reply.placement_group_table_data:
            data = self._message_to_dict(
                message=message,
                fields_to_decode=["placement_group_id"],
            )
            data = filter_fields(data, PlacementGroupState)
            result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["placement_group_id"])
        return {d["placement_group_id"]: d for d in islice(result, option.limit)}

    async def list_nodes(self, *, option: ListApiOptions) -> dict:
        """List all node information from the cluster.

        Returns:
            {node_id -> node_data_in_dict}
            node_data_in_dict's schema is in NodeState
        """
        reply = await self._client.get_all_node_info(timeout=option.timeout)

        result = []
        for message in reply.node_info_list:
            data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
            data = filter_fields(data, NodeState)
            result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["node_id"])
        return {d["node_id"]: d for d in islice(result, option.limit)}

    async def list_workers(self, *, option: ListApiOptions) -> dict:
        """List all worker information from the cluster.

        Returns:
            {worker_id -> worker_data_in_dict}
            worker_data_in_dict's schema is in WorkerState
        """
        reply = await self._client.get_all_worker_info(timeout=option.timeout)

        result = []
        for message in reply.worker_table_data:
            data = self._message_to_dict(
                message=message, fields_to_decode=["worker_id"]
            )
            data["worker_id"] = data["worker_address"]["worker_id"]
            data = filter_fields(data, WorkerState)
            result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["worker_id"])
        return {d["worker_id"]: d for d in islice(result, option.limit)}

    def list_jobs(self, *, option: ListApiOptions) -> Dict[str, JobInfo]:
        # TODO(sang): Support limit & timeout & async calls.
        return self._client.get_job_info()

    async def list_tasks(self, *, option: ListApiOptions) -> dict:
        """List all task information from the cluster.

        Returns:
            {task_id -> task_data_in_dict}
            task_data_in_dict's schema is in TaskState
        """
        replies = await asyncio.gather(
            *[
                self._client.get_task_info(node_id, timeout=option.timeout)
                for node_id in self._client.get_all_registered_raylet_ids()
            ]
        )

        result = []
        for reply in replies:
            tasks = reply.task_info_entries
            for task in tasks:
                data = self._message_to_dict(
                    message=task,
                    fields_to_decode=["task_id"],
                )
                data = filter_fields(data, TaskState)
                result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["task_id"])
        return {d["task_id"]: d for d in islice(result, option.limit)}

    async def list_objects(self, *, option: ListApiOptions) -> dict:
        """List all object information from the cluster.

        Returns:
            {object_id -> object_data_in_dict}
            object_data_in_dict's schema is in ObjectState
        """
        replies = await asyncio.gather(
            *[
                self._client.get_object_info(node_id, timeout=option.timeout)
                for node_id in self._client.get_all_registered_raylet_ids()
            ]
        )

        worker_stats = []
        for reply in replies:
            for core_worker_stat in reply.core_workers_stats:
                # NOTE: Set preserving_proto_field_name=False here because
                # `construct_memory_table` requires a dictionary keyed by the
                # camelCase protobuf field names
                # (e.g., workerId instead of worker_id).
                worker_stats.append(
                    self._message_to_dict(
                        message=core_worker_stat,
                        fields_to_decode=["object_id"],
                        preserving_proto_field_name=False,
                    )
                )

        result = []
        memory_table = memory_utils.construct_memory_table(worker_stats)
        for entry in memory_table.table:
            data = entry.as_dict()
            # `construct_memory_table` returns an `object_ref` field that is
            # actually the object_id, so rename it here.
            # TODO(sang): Refactor `construct_memory_table`.
            data["object_id"] = data["object_ref"]
            del data["object_ref"]
            data = filter_fields(data, ObjectState)
            result.append(data)
        # Sort to make the output deterministic.
        result.sort(key=lambda entry: entry["object_id"])
        return {d["object_id"]: d for d in islice(result, option.limit)}

    async def list_runtime_envs(self, *, option: ListApiOptions) -> List[dict]:
        """List all runtime env information from the cluster.

        Returns:
            A list of runtime env information in the cluster.
            The schema of each returned dict is equivalent to the
            `RuntimeEnvState` protobuf message.
            Unlike the other APIs, this doesn't return an id -> data mapping
            because runtime envs don't have unique ids.
        """
        replies = await asyncio.gather(
            *[
                self._client.get_runtime_envs_info(node_id, timeout=option.timeout)
                for node_id in self._client.get_all_registered_agent_ids()
            ]
        )

        result = []
        for node_id, reply in zip(
            self._client.get_all_registered_agent_ids(), replies
        ):
            states = reply.runtime_env_states
            for state in states:
                data = self._message_to_dict(message=state, fields_to_decode=[])
                # Need to deserialize this field.
                data["runtime_env"] = RuntimeEnv.deserialize(
                    data["runtime_env"]
                ).to_dict()
                data["node_id"] = node_id
                data = filter_fields(data, RuntimeEnvState)
                result.append(data)

        # Sort to make the output deterministic.
        def sort_func(entry):
            # Entries with no creation time (the runtime env failed to be
            # created, or hasn't been created yet) get the highest priority.
            # Otherwise, entries with a later creation time come first.
            if "creation_time_ms" not in entry:
                return float("inf")
            elif entry["creation_time_ms"] is None:
                return float("inf")
            else:
                return float(entry["creation_time_ms"])

        result.sort(key=sort_func, reverse=True)
        return list(islice(result, option.limit))
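
    # Worked example of the ordering above (creation_time_ms values are
    # illustrative): given entries whose creation_time_ms are None, 200, and
    # 100, sort_func maps them to inf, 200.0, and 100.0, so the reverse sort
    # yields [None, 200, 100]: failed/not-yet-created envs first, then
    # newest to oldest.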

    def _message_to_dict(
        self,
        *,
        message,
        fields_to_decode: List[str],
        preserving_proto_field_name: bool = True,
    ) -> dict:
        return dashboard_utils.message_to_dict(
            message,
            fields_to_decode,
            including_default_value_fields=True,
            preserving_proto_field_name=preserving_proto_field_name,
        )