[State Observability] Tasks and Objects API (#23912)

This PR implements the `ray list tasks` and `ray list objects` APIs.

NOTE: You can ignore the merge conflict for now; it exists because the first PR was reverted, and a fix PR is already open.
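
For reference, a minimal usage sketch of the new Python APIs (a hedged example, not part of the diff; it assumes a running Ray cluster and that, as in the tests below, the API server URL is resolved automatically when it is not passed):

import ray
from ray.experimental.state.api import list_tasks, list_objects

ray.init()

@ray.remote
def f():
    import time
    time.sleep(30)

# Keep the refs alive so the tasks and objects stay visible to the state APIs.
refs = [f.remote() for _ in range(2)]

# Each call returns {id -> entry_dict}; entry fields follow the TaskState /
# ObjectState dataclasses in ray.experimental.state.common.
for task_id, task in list_tasks().items():
    print(task_id, task["name"], task["scheduling_state"])

for object_id, obj in list_objects().items():
    print(object_id, obj["reference_type"], obj["object_size"])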
SangBin Cho 2022-04-22 10:45:03 +09:00 committed by GitHub
parent f500997a65
commit 30ab5458a7
26 changed files with 1116 additions and 51 deletions

View file

@ -23,7 +23,8 @@ from ray._private.gcs_pubsub import (
)
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.utils import async_loop_forever
from ray.dashboard.state_aggregator import GcsStateAggregator
from ray.dashboard.state_aggregator import StateAPIManager
from ray.experimental.state.state_manager import StateDataSourceClient
logger = logging.getLogger(__name__)
@ -87,7 +88,7 @@ class DashboardHead:
self.temp_dir = temp_dir
self.session_dir = session_dir
self.aiogrpc_gcs_channel = None
self.gcs_state_aggregator = None
self.state_aggregator = None
self.gcs_error_subscriber = None
self.gcs_log_subscriber = None
self.ip = ray.util.get_node_ip_address()
@ -176,7 +177,9 @@ class DashboardHead:
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
self.gcs_state_aggregator = GcsStateAggregator(self.aiogrpc_gcs_channel)
self.state_aggregator = StateAPIManager(
StateDataSourceClient(self.aiogrpc_gcs_channel)
)
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)

View file

@ -230,7 +230,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/v0/actors")
async def get_actors(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.gcs_state_aggregator.get_actors()
data = await self._dashboard_head.state_aggregator.get_actors()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)

View file

@ -322,7 +322,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/v0/nodes")
async def get_nodes(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.gcs_state_aggregator.get_nodes()
data = await self._dashboard_head.state_aggregator.get_nodes()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)

View file

View file

@ -0,0 +1,35 @@
import logging
import aiohttp.web
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.ClassMethodRouteTable
class ObjectHead(dashboard_utils.DashboardHeadModule):
"""Module to obtain object information of the ray cluster."""
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
@routes.get("/api/v0/objects")
async def get_objects(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.state_aggregator.get_objects()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)
async def run(self, server):
# The run method is required for subclasses of DashboardHeadModule.
# Since the object module only includes the state API, we don't need
# to do anything here.
pass
@staticmethod
def is_minimal_module():
return False

View file

@ -17,7 +17,7 @@ class PlacementGroupHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/v0/placement_groups")
async def get_placement_groups(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.gcs_state_aggregator.get_placement_groups()
data = await self._dashboard_head.state_aggregator.get_placement_groups()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)

View file

View file

@ -0,0 +1,33 @@
import logging
import aiohttp.web
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
logger = logging.getLogger(__name__)
routes = dashboard_optional_utils.ClassMethodRouteTable
class TaskHead(dashboard_utils.DashboardHeadModule):
"""Module to obtain task information of the ray cluster."""
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
@routes.get("/api/v0/tasks")
async def get_tasks(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.state_aggregator.get_tasks()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)
async def run(self, server):
# The run method is required for subclasses of DashboardHeadModule.
# Since the task module only includes the state API, we don't need
# to do anything here.
pass
@staticmethod
def is_minimal_module():
return False
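
As a rough illustration of how these modules are consumed, the new endpoints can also be queried over plain HTTP. This is a hedged sketch that assumes the dashboard is reachable at its default address (http://127.0.0.1:8265) and that the third-party requests package is available:

import requests

DASHBOARD_URL = "http://127.0.0.1:8265"  # assumed default dashboard address

for resource in ("tasks", "objects"):
    resp = requests.get(f"{DASHBOARD_URL}/api/v0/{resource}")
    resp.raise_for_status()
    # The exact envelope comes from rest_response(); print it to inspect
    # the per-id entries produced by the state aggregator.
    print(resource, resp.json())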

View file

@ -17,7 +17,7 @@ class WorkerHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/v0/workers")
async def get_workers(self, req) -> aiohttp.web.Response:
data = await self._dashboard_head.gcs_state_aggregator.get_workers()
data = await self._dashboard_head.state_aggregator.get_workers()
return rest_response(
success=True, message="", result=data, convert_google_style=False
)

View file

@ -1,42 +1,83 @@
import asyncio
import logging
from collections import defaultdict
from typing import List
import ray.dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
import ray.dashboard.memory_utils as memory_utils
from ray.dashboard.datacenter import DataSource
from ray.dashboard.utils import Change
from ray.experimental.state.common import (
filter_fields,
ActorState,
PlacementGroupState,
NodeState,
WorkerState,
TaskState,
ObjectState,
)
from ray.experimental.state.state_manager import StateDataSourceClient
logger = logging.getLogger(__name__)
DEFAULT_RPC_TIMEOUT = 30
# TODO(sang): Add error handling.
class GcsStateAggregator:
def __init__(self, gcs_channel):
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)
self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
gcs_channel
)
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
gcs_channel
)
self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
gcs_channel
)
# TODO(sang): Move the class to state/state_manager.py.
# TODO(sang): Remove *State and replaces with Pydantic or protobuf
# (depending on API interface standardization).
class StateAPIManager:
"""A class to query states from data source, caches, and post-processes
the entries.
"""
def __init__(self, state_data_source_client: StateDataSourceClient):
self._client = state_data_source_client
DataSource.nodes.signal.append(self._update_raylet_stubs)
async def _update_raylet_stubs(self, change: Change):
"""Callback that's called when a new raylet is added to Datasource.
Datasource is a api-server-specific module that's updated whenever
api server adds/removes a new node.
Args:
change: The change object. Whenever a new node is added
or removed, this callback is invoked.
When a new node is added: information is in `change.new`.
When a node is removed: information is in `change.old`.
When a node id is overwritten by a new node with the same node id:
`change.old` contains the old node info, and
`change.new` contains the new node info.
"""
# TODO(sang): Move this function out of this class.
if change.old:
# When a node is deleted from the DataSource or it is overwritten.
node_id, node_info = change.old
self._client.unregister_raylet_client(node_id)
if change.new:
# When new node information is written to the DataSource.
node_id, node_info = change.new
self._client.register_raylet_client(
node_id,
node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]),
)
@property
def data_source_client(self):
return self._client
async def get_actors(self) -> dict:
request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=DEFAULT_RPC_TIMEOUT
)
"""List all actor information from the cluster.
Returns:
{actor_id -> actor_data_in_dict}
actor_data_in_dict's schema is in ActorState
"""
reply = await self._client.get_all_actor_info(timeout=DEFAULT_RPC_TIMEOUT)
result = {}
for message in reply.actor_table_data:
data = self._message_to_dict(message=message, fields_to_decode=["actor_id"])
@ -45,12 +86,16 @@ class GcsStateAggregator:
return result
async def get_placement_groups(self) -> dict:
request = gcs_service_pb2.GetAllPlacementGroupRequest()
reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
request, timeout=DEFAULT_RPC_TIMEOUT
"""List all placement group information from the cluster.
Returns:
{pg_id -> pg_data_in_dict}
pg_data_in_dict's schema is in PlacementGroupState
"""
reply = await self._client.get_all_placement_group_info(
timeout=DEFAULT_RPC_TIMEOUT
)
result = {}
for message in reply.placement_group_table_data:
data = self._message_to_dict(
message=message,
@ -61,10 +106,13 @@ class GcsStateAggregator:
return result
async def get_nodes(self) -> dict:
request = gcs_service_pb2.GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
request, timeout=DEFAULT_RPC_TIMEOUT
)
"""List all node information from the cluster.
Returns:
{node_id -> node_data_in_dict}
node_data_in_dict's schema is in NodeState
"""
reply = await self._client.get_all_node_info(timeout=DEFAULT_RPC_TIMEOUT)
result = {}
for message in reply.node_info_list:
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
@ -73,10 +121,13 @@ class GcsStateAggregator:
return result
async def get_workers(self) -> dict:
request = gcs_service_pb2.GetAllWorkerInfoRequest()
reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
request, timeout=DEFAULT_RPC_TIMEOUT
)
"""List all worker information from the cluster.
Returns:
{worker_id -> worker_data_in_dict}
worker_data_in_dict's schema is in WorkerState
"""
reply = await self._client.get_all_worker_info(timeout=DEFAULT_RPC_TIMEOUT)
result = {}
for message in reply.worker_table_data:
data = self._message_to_dict(
@ -87,10 +138,83 @@ class GcsStateAggregator:
result[data["worker_id"]] = data
return result
def _message_to_dict(self, *, message, fields_to_decode) -> dict:
async def get_tasks(self) -> dict:
"""List all task information from the cluster.
Returns:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
replies = await asyncio.gather(
*[
self._client.get_task_info(node_id, timeout=DEFAULT_RPC_TIMEOUT)
for node_id in self._client.get_all_registered_raylet_ids()
]
)
result = defaultdict(dict)
for reply in replies:
tasks = reply.task_info_entries
for task in tasks:
data = self._message_to_dict(
message=task,
fields_to_decode=["task_id"],
)
data = filter_fields(data, TaskState)
result[data["task_id"]] = data
return result
async def get_objects(self) -> dict:
"""List all object information from the cluster.
Returns:
{object_id -> object_data_in_dict}
object_data_in_dict's schema is in ObjectState
"""
replies = await asyncio.gather(
*[
self._client.get_object_info(node_id, timeout=DEFAULT_RPC_TIMEOUT)
for node_id in self._client.get_all_registered_raylet_ids()
]
)
worker_stats = []
for reply in replies:
for core_worker_stat in reply.core_workers_stats:
# NOTE: Set preserving_proto_field_name=False here because
# `construct_memory_table` requires a dictionary whose keys use the
# modified protobuf field names
# (e.g., workerId instead of worker_id).
worker_stats.append(
self._message_to_dict(
message=core_worker_stat,
fields_to_decode=["object_id"],
preserving_proto_field_name=False,
)
)
result = {}
memory_table = memory_utils.construct_memory_table(worker_stats)
for entry in memory_table.table:
data = entry.as_dict()
# `construct_memory_table` returns an object_ref field which is really
# the object_id. We do the transformation here.
# TODO(sang): Refactor `construct_memory_table`.
data["object_id"] = data["object_ref"]
del data["object_ref"]
data = filter_fields(data, ObjectState)
result[data["object_id"]] = data
return result
def _message_to_dict(
self,
*,
message,
fields_to_decode: List[str],
preserving_proto_field_name: bool = True,
) -> dict:
return dashboard_utils.message_to_dict(
message,
fields_to_decode,
including_default_value_fields=True,
preserving_proto_field_name=True,
preserving_proto_field_name=preserving_proto_field_name,
)
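
filter_fields is imported from ray.experimental.state.common but its body is not part of this diff; judging from how it is called above (filter_fields(data, TaskState)), it presumably keeps only the keys declared as fields on the given *State dataclass. A minimal sketch of that assumed behavior:

from dataclasses import fields

def filter_fields(data: dict, state_dataclass) -> dict:
    # Keep only keys that correspond to fields of the given dataclass;
    # drop every other key coming from the raw protobuf dict.
    allowed = {field.name for field in fields(state_dataclass)}
    return {k: v for k, v in data.items() if k in allowed}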

View file

@ -83,3 +83,19 @@ def list_workers(api_server_url: str = None, limit: int = 1000, timeout: int = 3
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
)
def list_tasks(api_server_url: str = None, limit: int = 1000, timeout: int = 30):
return _list(
"tasks",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
)
def list_objects(api_server_url: str = None, limit: int = 1000, timeout: int = 30):
return _list(
"objects",
ListApiOptions(limit=limit, timeout=timeout),
api_server_url=api_server_url,
)

View file

@ -49,3 +49,26 @@ class WorkerState:
worker_id: str
is_alive: str
worker_type: str
@dataclass(init=True)
class TaskState:
task_id: str
name: str
scheduling_state: str
@dataclass(init=True)
class ObjectState:
object_id: str
pid: int
node_ip_address: str
object_size: int
reference_type: str
call_site: str
task_status: str
local_ref_count: int
pinned_in_memory: int
submitted_task_ref_count: int
contained_in_owned: int
type: str
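
To make the schemas above concrete, here are hypothetical entries as they would appear in the dicts returned by list_tasks() and list_objects(); the field names come from the dataclasses above, while every value below is a made-up placeholder:

# Hypothetical list_tasks() entry (values are illustrative only).
task_entry = {
    "task_id": "c2668a65bda616c1ffffffffffffffffffffffff01000000",
    "name": "f",
    "scheduling_state": "SCHEDULED",
}

# Hypothetical list_objects() entry (values are illustrative only).
object_entry = {
    "object_id": "ffffffffffffffffffffffffffffffffffffffff0100000001000000",
    "pid": 12345,
    "node_ip_address": "127.0.0.1",
    "object_size": 52428800,
    "reference_type": "LOCAL_REFERENCE",
    "call_site": "(put object)",
    "task_status": "-",
    "local_ref_count": 1,
    "pinned_in_memory": True,
    "submitted_task_ref_count": 0,
    "contained_in_owned": 0,
    "type": "Driver",
}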

View file

@ -14,6 +14,8 @@ from ray.experimental.state.api import (
list_jobs,
list_placement_groups,
list_workers,
list_tasks,
list_objects,
)
@ -76,3 +78,17 @@ def jobs(ctx):
def workers(ctx):
url = ctx.obj["api_server_url"]
pprint(list_workers(url))
@list_state_cli_group.command()
@click.pass_context
def tasks(ctx):
url = ctx.obj["api_server_url"]
pprint(list_tasks(url))
@list_state_cli_group.command()
@click.pass_context
def objects(ctx):
url = ctx.obj["api_server_url"]
pprint(list_objects(url))

View file

@ -0,0 +1,190 @@
import inspect
from functools import wraps
import grpc
import ray
from typing import Dict, List
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoRequest,
GetAllActorInfoReply,
GetAllPlacementGroupRequest,
GetAllPlacementGroupReply,
GetAllNodeInfoRequest,
GetAllNodeInfoReply,
GetAllWorkerInfoRequest,
GetAllWorkerInfoReply,
)
from ray.core.generated.node_manager_pb2 import (
GetTasksInfoRequest,
GetTasksInfoReply,
GetNodeStatsRequest,
GetNodeStatsReply,
)
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.node_manager_pb2_grpc import NodeManagerServiceStub
from ray.core.generated.reporter_pb2_grpc import ReporterServiceStub
from ray.dashboard.modules.job.common import JobInfoStorageClient, JobInfo
class StateSourceNetworkException(Exception):
"""Exceptions raised when there's a network error from data source query."""
pass
def handle_network_errors(func):
"""Apply the network error handling logic to each APIs,
such as retry or exception policies.
It is a helper method for `StateDataSourceClient`.
The method can only be used for async methods.
"""
assert inspect.iscoroutinefunction(func)
@wraps(func)
async def api_with_network_error_handler(*args, **kwargs):
# TODO(sang): Add a retry policy.
try:
return await func(*args, **kwargs)
except (
# https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc-exceptions
grpc.aio.AioRpcError,
grpc.aio.InternalError,
grpc.aio.AbortError,
grpc.aio.BaseError,
grpc.aio.UsageError,
) as e:
raise StateSourceNetworkException(
f"Failed to query the data source, {func}"
) from e
return api_with_network_error_handler
class StateDataSourceClient:
"""The client to query states from various data sources such as Raylet, GCS, Agents.
Note that it doesn't directly query core workers. They are proxied through raylets.
The module is not in charge of service discovery. The caller is responsible for
finding services and registering stubs through the `register*` APIs.
Non-`register*` APIs
- raise a ValueError if they cannot find the source.
- raise a `StateSourceNetworkException` on any network error.
"""
def __init__(self, gcs_channel: grpc.aio.Channel):
self.register_gcs_client(gcs_channel)
self._raylet_stubs = {}
self._agent_stubs = {}
self._job_client = JobInfoStorageClient()
def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)
self._gcs_pg_info_stub = gcs_service_pb2_grpc.PlacementGroupInfoGcsServiceStub(
gcs_channel
)
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
gcs_channel
)
self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
gcs_channel
)
def register_raylet_client(self, node_id: str, address: str, port: int):
full_addr = f"{address}:{port}"
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
full_addr, options, asynchronous=True
)
self._raylet_stubs[node_id] = NodeManagerServiceStub(channel)
def unregister_raylet_client(self, node_id: str):
self._raylet_stubs.pop(node_id)
def register_agent_client(self, node_id, address: str, port: int):
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
f"{address}:{port}", options=options, asynchronous=True
)
self._agent_stubs[node_id] = ReporterServiceStub(channel)
def unregister_agent_client(self, node_id: str):
self._agent_stubs.pop(node_id)
def get_all_registered_raylet_ids(self) -> List[str]:
return self._raylet_stubs.keys()
def get_all_registered_agent_ids(self) -> List[str]:
return self._agent_stubs.keys()
@handle_network_errors
async def get_all_actor_info(self, timeout: int = None) -> GetAllActorInfoReply:
request = GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=timeout
)
return reply
@handle_network_errors
async def get_all_placement_group_info(
self, timeout: int = None
) -> GetAllPlacementGroupReply:
request = GetAllPlacementGroupRequest()
reply = await self._gcs_pg_info_stub.GetAllPlacementGroup(
request, timeout=timeout
)
return reply
@handle_network_errors
async def get_all_node_info(self, timeout: int = None) -> GetAllNodeInfoReply:
request = GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=timeout)
return reply
@handle_network_errors
async def get_all_worker_info(self, timeout: int = None) -> GetAllWorkerInfoReply:
request = GetAllWorkerInfoRequest()
reply = await self._gcs_worker_info_stub.GetAllWorkerInfo(
request, timeout=timeout
)
return reply
def get_job_info(self) -> Dict[str, JobInfo]:
# Cannot use @handle_network_errors because this method is not async yet.
# TODO(sang): Support timeout & make it async
try:
return self._job_client.get_all_jobs()
except Exception as e:
raise StateSourceNetworkException("Failed to query the job info.") from e
@handle_network_errors
async def get_task_info(
self, node_id: str, timeout: int = None
) -> GetTasksInfoReply:
stub = self._raylet_stubs.get(node_id)
if not stub:
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
reply = await stub.GetTasksInfo(GetTasksInfoRequest(), timeout=timeout)
return reply
@handle_network_errors
async def get_object_info(
self, node_id: str, timeout: int = None
) -> GetNodeStatsReply:
stub = self._raylet_stubs.get(node_id)
if not stub:
raise ValueError(f"Raylet for a node id, {node_id} doesn't exist.")
reply = await stub.GetNodeStats(
GetNodeStatsRequest(include_memory_info=True),
timeout=timeout,
)
return reply
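
To show how the client and the aggregator compose outside the dashboard process, here is a condensed, hedged sketch based on the head.py change above and the integration test below; it assumes a running cluster, a GCS address known out of band, and node manager addresses/ports taken from ray.nodes():

import asyncio

import ray
from ray.dashboard.state_aggregator import StateAPIManager
from ray.experimental.state.state_manager import StateDataSourceClient


async def dump_states(gcs_address: str) -> None:
    # Same channel options the integration test uses for the GCS channel.
    options = (("grpc.enable_http_proxy", 0),)
    gcs_channel = ray._private.utils.init_grpc_channel(
        gcs_address, options, asynchronous=True
    )
    client = StateDataSourceClient(gcs_channel)

    # The dashboard registers raylet stubs via the DataSource.nodes signal;
    # here we register them manually from ray.nodes().
    for node in ray.nodes():
        client.register_raylet_client(
            node["NodeID"],
            node["NodeManagerAddress"],
            int(node["NodeManagerPort"]),
        )

    manager = StateAPIManager(client)
    print(await manager.get_tasks())
    print(await manager.get_objects())


ray.init()
# "127.0.0.1:6379" is a placeholder; use the real GCS address of the cluster.
asyncio.run(dump_states("127.0.0.1:6379"))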

View file

@ -2,21 +2,372 @@ import sys
import pytest
from typing import List
from dataclasses import fields
from unittest.mock import MagicMock
from asyncmock import AsyncMock
import ray
import ray.ray_constants as ray_constants
from click.testing import CliRunner
from ray.cluster_utils import cluster_not_supported
from ray.core.generated.common_pb2 import (
Address,
WorkerType,
TaskStatus,
TaskInfoEntry,
CoreWorkerStats,
ObjectRefInfo,
)
from ray.core.generated.node_manager_pb2 import GetTasksInfoReply, GetNodeStatsReply
from ray.core.generated.gcs_pb2 import (
ActorTableData,
PlacementGroupTableData,
GcsNodeInfo,
WorkerTableData,
)
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply,
GetAllPlacementGroupReply,
GetAllNodeInfoReply,
GetAllWorkerInfoReply,
)
from ray.dashboard.state_aggregator import StateAPIManager, DEFAULT_RPC_TIMEOUT
from ray.experimental.state.api import (
list_actors,
list_placement_groups,
list_nodes,
list_jobs,
list_workers,
list_tasks,
list_objects,
)
from ray.experimental.state.common import (
ActorState,
PlacementGroupState,
NodeState,
WorkerState,
TaskState,
ObjectState,
)
from ray.experimental.state.state_manager import (
StateDataSourceClient,
StateSourceNetworkException,
)
from ray.experimental.state.state_cli import list_state_cli_group
from ray._private.test_utils import wait_for_condition
from ray.job_submission import JobSubmissionClient
"""
Unit tests
"""
@pytest.fixture
def state_api_manager():
data_source_client = AsyncMock(StateDataSourceClient)
manager = StateAPIManager(data_source_client)
yield manager
def verify_schema(state, result_dict: dict):
state_fields_columns = set()
for field in fields(state):
state_fields_columns.add(field.name)
for k in result_dict.keys():
assert k in state_fields_columns
@pytest.mark.asyncio
async def test_api_manager_list_actors(state_api_manager):
data_source_client = state_api_manager.data_source_client
actor_id = b"1234"
data_source_client.get_all_actor_info.return_value = GetAllActorInfoReply(
actor_table_data=[
ActorTableData(
actor_id=actor_id,
state=ActorTableData.ActorState.ALIVE,
name="abc",
pid=1234,
class_name="class",
)
]
)
result = await state_api_manager.get_actors()
actor_data = list(result.values())[0]
verify_schema(ActorState, actor_data)
@pytest.mark.asyncio
async def test_api_manager_list_pgs(state_api_manager):
data_source_client = state_api_manager.data_source_client
id = b"1234"
data_source_client.get_all_placement_group_info.return_value = (
GetAllPlacementGroupReply(
placement_group_table_data=[
PlacementGroupTableData(
placement_group_id=id,
state=PlacementGroupTableData.PlacementGroupState.CREATED,
name="abc",
creator_job_dead=True,
creator_actor_dead=False,
)
]
)
)
result = await state_api_manager.get_placement_groups()
data = list(result.values())[0]
verify_schema(PlacementGroupState, data)
@pytest.mark.asyncio
async def test_api_manager_list_nodes(state_api_manager):
data_source_client = state_api_manager.data_source_client
id = b"1234"
data_source_client.get_all_node_info.return_value = GetAllNodeInfoReply(
node_info_list=[
GcsNodeInfo(
node_id=id,
state=GcsNodeInfo.GcsNodeState.ALIVE,
node_manager_address="127.0.0.1",
raylet_socket_name="abcd",
object_store_socket_name="False",
)
]
)
result = await state_api_manager.get_nodes()
data = list(result.values())[0]
verify_schema(NodeState, data)
@pytest.mark.asyncio
async def test_api_manager_list_workers(state_api_manager):
data_source_client = state_api_manager.data_source_client
id = b"1234"
data_source_client.get_all_worker_info.return_value = GetAllWorkerInfoReply(
worker_table_data=[
WorkerTableData(
worker_address=Address(
raylet_id=id, ip_address="127.0.0.1", port=124, worker_id=id
),
is_alive=True,
timestamp=1234,
worker_type=WorkerType.WORKER,
)
]
)
result = await state_api_manager.get_workers()
data = list(result.values())[0]
verify_schema(WorkerState, data)
@pytest.mark.skip(
reason=("Not passing in CI although it works locally. Will handle it later.")
)
@pytest.mark.asyncio
async def test_api_manager_list_tasks(state_api_manager):
data_source_client = state_api_manager.data_source_client
data_source_client.get_all_registered_raylet_ids = MagicMock()
data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]
def generate_task_info(id, name):
return GetTasksInfoReply(
task_info_entries=[
TaskInfoEntry(
task_id=id,
name=name,
func_or_class_name="class",
scheduling_state=TaskStatus.SCHEDULED,
)
]
)
first_task_name = "1"
second_task_name = "2"
data_source_client.get_task_info.side_effect = [
generate_task_info(b"1234", first_task_name),
generate_task_info(b"2345", second_task_name),
]
result = await state_api_manager.get_tasks()
data_source_client.get_task_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
data_source_client.get_task_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
result = list(result.values())
assert len(result) == 2
verify_schema(TaskState, result[0])
verify_schema(TaskState, result[1])
@pytest.mark.skip(
reason=("Not passing in CI although it works locally. Will handle it later.")
)
@pytest.mark.asyncio
async def test_api_manager_list_objects(state_api_manager):
data_source_client = state_api_manager.data_source_client
obj_1_id = b"1" * 28
obj_2_id = b"2" * 28
data_source_client.get_all_registered_raylet_ids = MagicMock()
data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]
def generate_node_stats_reply(obj_id):
return GetNodeStatsReply(
core_workers_stats=[
CoreWorkerStats(
pid=1234,
worker_type=WorkerType.DRIVER,
ip_address="1234",
object_refs=[
ObjectRefInfo(
object_id=obj_id,
call_site="",
object_size=1,
local_ref_count=1,
submitted_task_ref_count=1,
contained_in_owned=[],
pinned_in_memory=True,
task_status=TaskStatus.SCHEDULED,
attempt_number=1,
)
],
)
]
)
data_source_client.get_object_info.side_effect = [
generate_node_stats_reply(obj_1_id),
generate_node_stats_reply(obj_2_id),
]
result = await state_api_manager.get_objects()
data_source_client.get_object_info.assert_any_call("1", timeout=DEFAULT_RPC_TIMEOUT)
data_source_client.get_object_info.assert_any_call("2", timeout=DEFAULT_RPC_TIMEOUT)
result = list(result.values())
assert len(result) == 2
verify_schema(ObjectState, result[0])
verify_schema(ObjectState, result[1])
"""
Integration tests
"""
@pytest.mark.asyncio
async def test_state_data_source_client(ray_start_cluster):
cluster = ray_start_cluster
# head
cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
# worker
worker = cluster.add_node(num_cpus=2)
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
client = StateDataSourceClient(gcs_channel)
"""
Test actor
"""
result = await client.get_all_actor_info()
assert isinstance(result, GetAllActorInfoReply)
"""
Test placement group
"""
result = await client.get_all_placement_group_info()
assert isinstance(result, GetAllPlacementGroupReply)
"""
Test node
"""
result = await client.get_all_node_info()
assert isinstance(result, GetAllNodeInfoReply)
"""
Test worker info
"""
result = await client.get_all_worker_info()
assert isinstance(result, GetAllWorkerInfoReply)
"""
Test job
"""
job_client = JobSubmissionClient(
f"http://{ray.worker.global_worker.node.address_info['webui_url']}"
)
job_id = job_client.submit_job( # noqa
# Entrypoint shell command to execute
entrypoint="ls",
)
result = client.get_job_info()
assert list(result.keys())[0] == job_id
assert isinstance(result, dict)
"""
Test tasks
"""
with pytest.raises(ValueError):
# Since we didn't register this node id, it should raise an exception.
result = await client.get_task_info("1234")
wait_for_condition(lambda: len(ray.nodes()) == 2)
for node in ray.nodes():
node_id = node["NodeID"]
ip = node["NodeManagerAddress"]
port = int(node["NodeManagerPort"])
client.register_raylet_client(node_id, ip, port)
result = await client.get_task_info(node_id)
assert isinstance(result, GetTasksInfoReply)
assert len(client.get_all_registered_raylet_ids()) == 2
"""
Test objects
"""
with pytest.raises(ValueError):
# Since we didn't register this node id, it should raise an exception.
result = await client.get_object_info("1234")
wait_for_condition(lambda: len(ray.nodes()) == 2)
for node in ray.nodes():
node_id = node["NodeID"]
ip = node["NodeManagerAddress"]
port = int(node["NodeManagerPort"])
client.register_raylet_client(node_id, ip, port)
result = await client.get_object_info(node_id)
assert isinstance(result, GetNodeStatsReply)
"""
Test the exception is raised when the RPC error occurs.
"""
cluster.remove_node(worker)
# Wait until the dead node information is propagated.
wait_for_condition(
lambda: len(list(filter(lambda node: node["Alive"], ray.nodes()))) == 1
)
for node in ray.nodes():
node_id = node["NodeID"]
if node["Alive"]:
continue
# Querying to the dead node raises gRPC error, which should be
# translated into `StateSourceNetworkException`
with pytest.raises(StateSourceNetworkException):
result = await client.get_object_info(node_id)
# Make sure unregister API works as expected.
client.unregister_raylet_client(node_id)
assert len(client.get_all_registered_raylet_ids()) == 1
# Since the node_id is unregistered, the API should raise ValueError.
with pytest.raises(ValueError):
result = await client.get_object_info(node_id)
def is_hex(val):
@ -31,9 +382,10 @@ def is_hex(val):
def test_cli_apis_sanity_check(ray_start_cluster):
"""Test all of CLI APIs work as expected."""
cluster = ray_start_cluster
for _ in range(4):
cluster.add_node(num_cpus=2)
cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
for _ in range(3):
cluster.add_node(num_cpus=2)
runner = CliRunner()
client = JobSubmissionClient(
@ -68,11 +420,15 @@ def test_cli_apis_sanity_check(ray_start_cluster):
print(result.output)
return exit_code_correct and substring_matched
assert verify_output("actors", ["actor_id"])
assert verify_output("workers", ["worker_id"])
assert verify_output("nodes", ["node_id"])
assert verify_output("placement-groups", ["placement_group_id"])
assert verify_output("jobs", ["raysubmit"])
wait_for_condition(lambda: verify_output("actors", ["actor_id"]))
wait_for_condition(lambda: verify_output("workers", ["worker_id"]))
wait_for_condition(lambda: verify_output("nodes", ["node_id"]))
wait_for_condition(
lambda: verify_output("placement-groups", ["placement_group_id"])
)
wait_for_condition(lambda: verify_output("jobs", ["raysubmit"]))
wait_for_condition(lambda: verify_output("tasks", ["task_id"]))
wait_for_condition(lambda: verify_output("objects", ["object_id"]))
@pytest.mark.skipif(
@ -180,6 +536,67 @@ def test_list_workers(shutdown_only):
print(list_workers())
def test_list_tasks(shutdown_only):
ray.init(num_cpus=2)
@ray.remote
def f():
import time
time.sleep(30)
@ray.remote
def g(dep):
import time
time.sleep(30)
out = [f.remote() for _ in range(2)] # noqa
g_out = g.remote(f.remote()) # noqa
def verify():
tasks = list(list_tasks().values())
correct_num_tasks = len(tasks) == 4
scheduled = len(
list(filter(lambda task: task["scheduling_state"] == "SCHEDULED", tasks))
)
waiting_for_dep = len(
list(
filter(
lambda task: task["scheduling_state"] == "WAITING_FOR_DEPENDENCIES",
tasks,
)
)
)
return correct_num_tasks and scheduled == 3 and waiting_for_dep == 1
wait_for_condition(verify)
print(list_tasks())
def test_list_objects(shutdown_only):
ray.init()
import numpy as np
data = np.ones(50 * 1024 * 1024, dtype=np.uint8)
plasma_obj = ray.put(data)
@ray.remote
def f(obj):
print(obj)
ray.get(f.remote(plasma_obj))
def verify():
obj = list(list_objects().values())[0]
# Detailed output is covered by `test_memstat.py`.
return obj["object_id"] == plasma_obj.hex()
wait_for_condition(verify)
print(list_objects())
if __name__ == "__main__":
import sys

View file

@ -64,6 +64,7 @@ moto
mypy
networkx
numba
asyncmock
# higher version of llvmlite breaks windows
llvmlite==0.34.0
openpyxl

View file

@ -3152,6 +3152,10 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest &
task_manager_->AddTaskStatusInfo(stats);
}
if (request.include_task_info()) {
task_manager_->FillTaskInfo(reply);
}
send_reply_callback(Status::OK(), nullptr, nullptr);
}

View file

@ -640,5 +640,36 @@ void TaskManager::MarkDependenciesResolved(const TaskID &task_id) {
}
}
void TaskManager::FillTaskInfo(rpc::GetCoreWorkerStatsReply *reply) const {
absl::MutexLock lock(&mu_);
for (const auto &task_it : submissible_tasks_) {
const auto &task_entry = task_it.second;
auto entry = reply->add_task_info_entries();
const auto &task_spec = task_entry.spec;
const auto &task_state = task_entry.status;
rpc::TaskType type;
if (task_spec.IsNormalTask()) {
type = rpc::TaskType::NORMAL_TASK;
} else if (task_spec.IsActorCreationTask()) {
type = rpc::TaskType::ACTOR_CREATION_TASK;
} else {
RAY_CHECK(task_spec.IsActorTask());
type = rpc::TaskType::ACTOR_TASK;
}
entry->set_type(type);
entry->set_name(task_spec.GetName());
entry->set_language(task_spec.GetLanguage());
entry->set_func_or_class_name(task_spec.FunctionDescriptor()->CallString());
entry->set_scheduling_state(task_state);
entry->set_job_id(task_spec.JobId().Binary());
entry->set_task_id(task_spec.TaskId().Binary());
entry->set_parent_task_id(task_spec.ParentTaskId().Binary());
const auto &resources_map = task_spec.GetRequiredResources().GetResourceMap();
entry->mutable_required_resources()->insert(resources_map.begin(),
resources_map.end());
entry->mutable_runtime_env_info()->CopyFrom(task_spec.RuntimeEnvInfo());
}
}
} // namespace core
} // namespace ray

View file

@ -256,6 +256,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
/// any.
void AddTaskStatusInfo(rpc::CoreWorkerStats *stats) const;
/// Fill information about every task submitted by the current worker into GetCoreWorkerStatsReply.
void FillTaskInfo(rpc::GetCoreWorkerStatsReply *reply) const;
private:
struct TaskEntry {
TaskEntry(const TaskSpecification &spec_arg,

View file

@ -320,6 +320,28 @@ message TaskSpec {
uint64 attempt_number = 28;
}
message TaskInfoEntry {
// Type of this task.
TaskType type = 1;
// Name of this task.
string name = 2;
// Language of this task.
Language language = 3;
// Function descriptor of this task; uniquely describes the function to execute.
string func_or_class_name = 4;
TaskStatus scheduling_state = 5;
// ID of the job that this task belongs to.
bytes job_id = 6;
// Task ID of the task.
bytes task_id = 7;
// Task ID of the parent task.
bytes parent_task_id = 8;
// Quantities of the different resources required by this task.
map<string, double> required_resources = 13;
// Runtime environment for this task.
RuntimeEnvInfo runtime_env_info = 23;
}
message Bundle {
message BundleIdentifier {
bytes placement_group_id = 1;

View file

@ -267,11 +267,16 @@ message GetCoreWorkerStatsRequest {
// Whether to include memory stats. This could be large since it includes
// metadata for all live object references.
bool include_memory_info = 2;
// Whether to include task information. This could be large since it
// includes metadata for all live tasks.
bool include_task_info = 3;
}
message GetCoreWorkerStatsReply {
// Debug information returned from the core worker.
CoreWorkerStats core_worker_stats = 1;
// A list of task information of the current worker.
repeated TaskInfoEntry task_info_entries = 2;
}
message LocalGCRequest {

View file

@ -298,6 +298,18 @@ message GetGcsServerAddressReply {
int32 port = 2;
}
message GetTasksInfoRequest {}
message GetTasksInfoReply {
repeated TaskInfoEntry task_info_entries = 1;
}
message GetObjectsInfoRequest {}
message GetObjectsInfoReply {
repeated CoreWorkerStats core_workers_stats = 1;
}
// Service for inter-node-manager communication.
service NodeManagerService {
// Update the node's view of the cluster resource usage
@ -355,4 +367,8 @@ service NodeManagerService {
rpc GetSystemConfig(GetSystemConfigRequest) returns (GetSystemConfigReply);
// Get gcs server address.
rpc GetGcsServerAddress(GetGcsServerAddressRequest) returns (GetGcsServerAddressReply);
// [State API] Get all the task information of the node.
rpc GetTasksInfo(GetTasksInfoRequest) returns (GetTasksInfoReply);
// [State API] Get all the object information of the node.
rpc GetObjectsInfo(GetObjectsInfoRequest) returns (GetObjectsInfoReply);
}

View file

@ -732,6 +732,86 @@ void NodeManager::HandleReleaseUnusedBundles(
send_reply_callback(Status::OK(), nullptr, nullptr);
}
void NodeManager::HandleGetTasksInfo(const rpc::GetTasksInfoRequest &request,
rpc::GetTasksInfoReply *reply,
rpc::SendReplyCallback send_reply_callback) {
QueryAllWorkerStates(
/*on_replied*/
[reply](const ray::Status &status, const rpc::GetCoreWorkerStatsReply &r) {
if (status.ok()) {
for (const auto &task_info : r.task_info_entries()) {
reply->add_task_info_entries()->CopyFrom(task_info);
}
} else {
RAY_LOG(INFO) << "Failed to query task information from a worker.";
}
},
send_reply_callback,
/*include_memory_info*/ false,
/*include_task_info*/ true);
}
void NodeManager::HandleGetObjectsInfo(const rpc::GetObjectsInfoRequest &request,
rpc::GetObjectsInfoReply *reply,
rpc::SendReplyCallback send_reply_callback) {
QueryAllWorkerStates(
/*on_replied*/
[reply](const ray::Status &status, const rpc::GetCoreWorkerStatsReply &r) {
if (status.ok()) {
reply->add_core_workers_stats()->MergeFrom(r.core_worker_stats());
} else {
RAY_LOG(INFO) << "Failed to query object information from a worker.";
}
},
send_reply_callback,
/*include_memory_info*/ true,
/*include_task_info*/ false);
}
void NodeManager::QueryAllWorkerStates(
const std::function<void(const ray::Status &, const rpc::GetCoreWorkerStatsReply &)>
&on_replied,
rpc::SendReplyCallback &send_reply_callback,
bool include_memory_info,
bool include_task_info) {
auto all_workers = worker_pool_.GetAllRegisteredWorkers(/* filter_dead_worker */ true);
for (auto driver :
worker_pool_.GetAllRegisteredDrivers(/* filter_dead_driver */ true)) {
all_workers.push_back(driver);
}
if (all_workers.empty()) {
send_reply_callback(Status::OK(), nullptr, nullptr);
return;
}
auto rpc_replied = std::make_shared<size_t>(0);
auto num_workers = all_workers.size();
for (const auto &worker : all_workers) {
if (worker->IsDead()) {
continue;
}
rpc::GetCoreWorkerStatsRequest request;
request.set_intended_worker_id(worker->WorkerId().Binary());
request.set_include_memory_info(include_memory_info);
request.set_include_task_info(include_task_info);
// TODO(sang): Add timeout to the RPC call.
worker->rpc_client()->GetCoreWorkerStats(
request,
[num_workers,
rpc_replied,
send_reply_callback,
on_replied = std::move(on_replied)](const ray::Status &status,
const rpc::GetCoreWorkerStatsReply &r) {
*rpc_replied += 1;
on_replied(status, r);
if (*rpc_replied == num_workers) {
send_reply_callback(Status::OK(), nullptr, nullptr);
}
});
}
}
// This warns users that there could be the resource deadlock. It works this way;
// - If there's no available workers for scheduling
// - But if there are still pending tasks waiting for resource acquisition

View file

@ -208,6 +208,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// Stop this node manager.
void Stop();
/// Query the states of all local core workers.
///
/// \param on_replied A callback that's called when each query RPC is replied.
/// \param send_reply_callback A reply callback that will be called when all
/// RPCs are replied.
/// \param include_memory_info If true, object ref information is requested
/// from all workers.
/// \param include_task_info If true, task metadata is requested
/// from all workers.
void QueryAllWorkerStates(
const std::function<void(const ray::Status &status,
const rpc::GetCoreWorkerStatsReply &r)> &on_replied,
rpc::SendReplyCallback &send_reply_callback,
bool include_memory_info,
bool include_task_info);
private:
/// Methods for handling nodes.
@ -562,6 +578,16 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
rpc::GetGcsServerAddressReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Handle a `GetTasksInfo` request.
void HandleGetTasksInfo(const rpc::GetTasksInfoRequest &request,
rpc::GetTasksInfoReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Handle a `GetObjectsInfo` request.
void HandleGetObjectsInfo(const rpc::GetObjectsInfoRequest &request,
rpc::GetObjectsInfoReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Trigger local GC on each worker of this raylet.
void DoLocalGC(bool triggered_by_global_gc = false);

View file

@ -177,6 +177,18 @@ class NodeManagerWorkerClient
grpc_client_,
/*method_timeout_ms*/ -1, )
/// Get all the task information from the node.
VOID_RPC_CLIENT_METHOD(NodeManagerService,
GetTasksInfo,
grpc_client_,
/*method_timeout_ms*/ -1, )
/// Get all the object information from the node.
VOID_RPC_CLIENT_METHOD(NodeManagerService,
GetObjectsInfo,
grpc_client_,
/*method_timeout_ms*/ -1, )
private:
/// Constructor.
///

View file

@ -43,7 +43,9 @@ namespace rpc {
RPC_SERVICE_HANDLER(NodeManagerService, ReleaseUnusedBundles, -1) \
RPC_SERVICE_HANDLER(NodeManagerService, GetSystemConfig, -1) \
RPC_SERVICE_HANDLER(NodeManagerService, GetGcsServerAddress, -1) \
RPC_SERVICE_HANDLER(NodeManagerService, ShutdownRaylet, -1)
RPC_SERVICE_HANDLER(NodeManagerService, ShutdownRaylet, -1) \
RPC_SERVICE_HANDLER(NodeManagerService, GetTasksInfo, -1) \
RPC_SERVICE_HANDLER(NodeManagerService, GetObjectsInfo, -1)
/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
class NodeManagerServiceHandler {
@ -138,6 +140,12 @@ class NodeManagerServiceHandler {
virtual void HandleGetGcsServerAddress(const GetGcsServerAddressRequest &request,
GetGcsServerAddressReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleGetTasksInfo(const GetTasksInfoRequest &request,
GetTasksInfoReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleGetObjectsInfo(const GetObjectsInfoRequest &request,
GetObjectsInfoReply *reply,
SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcService` for `NodeManagerService`.