ray/dashboard/state_aggregator.py
SangBin Cho 51a4a1a802
[State Observability] Basic functionality for centralized data (#23744)
Support listing actor/pg/job/node/workers

Design doc: https://docs.google.com/document/d/1IeEsJOiurg-zctOcBjY-tQVbsCmURFSnUCTkx_4a7Cw/edit#heading=h.9ub9e6yvu9p2

Note that this PR does not output anything except IDs; the remaining output fields will be added in follow-up PRs.
2022-04-14 07:33:18 -07:00

96 lines
3.3 KiB
Python

import logging
import ray.dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.experimental.state.common import (
filter_fields,
ActorState,
PlacementGroupState,
NodeState,
WorkerState,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Passed as ``timeout=`` to every GCS RPC issued by GcsStateAggregator.
# NOTE(review): presumably seconds, following the gRPC Python convention --
# confirm.
DEFAULT_RPC_TIMEOUT = 30
# TODO(sang): Add error handling.
class GcsStateAggregator:
    """Fetch state tables from GCS and return them as plain dictionaries.

    Each ``get_*`` coroutine issues a single "get all" RPC through the
    matching GCS service stub, converts every returned protobuf message to a
    dict, passes it through ``filter_fields`` with the corresponding state
    schema, and indexes the result by the entry's id.

    RPC failures are not handled here; any exception raised by a stub call
    propagates to the caller.
    """

    def __init__(self, gcs_channel):
        """Create the GCS service stubs on top of ``gcs_channel``.

        Args:
            gcs_channel: Channel handed to each generated ``*GcsServiceStub``
                constructor. The ``get_*`` coroutines await the stub calls,
                so this is presumably an asyncio gRPC channel -- confirm
                against the caller.
        """
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            gcs_channel
        )
        self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
            gcs_channel
        )
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel
        )
        self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
            gcs_channel
        )

    async def get_actors(self) -> dict:
        """Return every entry of the ``GetAllActorInfo`` reply, keyed by id.

        Returns:
            A dict mapping each decoded ``actor_id`` to the actor's data after
            it has been passed through ``filter_fields`` with ``ActorState``.
        """
        request = gcs_service_pb2.GetAllActorInfoRequest()
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(
            request, timeout=DEFAULT_RPC_TIMEOUT
        )
        result = {}
        for message in reply.actor_table_data:
            data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
            data = filter_fields(data, ActorState)
            result[data["actor_id"]] = data
        return result

    async def get_placement_groups(self) -> dict:
        """Return every entry of the ``GetAllPlacementGroup`` reply, keyed by id.

        Returns:
            A dict mapping each decoded ``placement_group_id`` to the
            placement group's data after it has been passed through
            ``filter_fields`` with ``PlacementGroupState``.
        """
        request = gcs_service_pb2.GetAllPlacementGroupRequest()
        reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
            request, timeout=DEFAULT_RPC_TIMEOUT
        )
        result = {}
        for message in reply.placement_group_table_data:
            data = self._message_to_dict(
                message=message,
                fields_to_decode=["placement_group_id"],
            )
            data = filter_fields(data, PlacementGroupState)
            result[data["placement_group_id"]] = data
        return result

    async def get_nodes(self) -> dict:
        """Return every entry of the ``GetAllNodeInfo`` reply, keyed by id.

        Returns:
            A dict mapping each decoded ``node_id`` to the node's data after
            it has been passed through ``filter_fields`` with ``NodeState``.
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(
            request, timeout=DEFAULT_RPC_TIMEOUT
        )
        result = {}
        for message in reply.node_info_list:
            data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
            data = filter_fields(data, NodeState)
            result[data["node_id"]] = data
        return result

    async def get_workers(self) -> dict:
        """Return every entry of the ``GetAllWorkerInfo`` reply, keyed by id.

        Returns:
            A dict mapping each ``worker_id`` to the worker's data after it
            has been passed through ``filter_fields`` with ``WorkerState``.
        """
        request = gcs_service_pb2.GetAllWorkerInfoRequest()
        reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
            request, timeout=DEFAULT_RPC_TIMEOUT
        )
        result = {}
        for message in reply.worker_table_data:
            data = self._message_to_dict(
                message=message, fields_to_decode=["worker_id"]
            )
            # The id is nested under "worker_address" in the converted
            # message; copy it to the top level so it can be used as the key
            # (and, presumably, so it survives filter_fields).
            data["worker_id"] = data["worker_address"]["worker_id"]
            data = filter_fields(data, WorkerState)
            result[data["worker_id"]] = data
        return result

    def _message_to_dict(self, *, message, fields_to_decode) -> dict:
        """Convert a protobuf ``message`` to a dict via the dashboard helper.

        Args:
            message: The protobuf message to convert.
            fields_to_decode: Field names forwarded to
                ``dashboard_utils.message_to_dict`` as the keys it should
                decode.

        Returns:
            The dict produced by ``dashboard_utils.message_to_dict``, with
            default-valued fields included and the original proto field names
            preserved (per the two keyword flags passed below).
        """
        return dashboard_utils.message_to_dict(
            message,
            fields_to_decode,
            including_default_value_fields=True,
            preserving_proto_field_name=True,
        )