mirror of
https://github.com/vale981/ray
synced 2025-03-10 05:16:49 -04:00

This is to limit the max number of HTTP requests the dashboard (API server) will accept before rejecting more requests. This will make sure the observability requests do not overload the downstream systems (raylet/gcs) when delegating too many concurrent state observability requests to the cluster.
423 lines · 16 KiB · Python
import functools
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict
from typing import Callable, Optional

import aiohttp.web

import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.consts import (
    RAY_STATE_SERVER_MAX_HTTP_REQUEST,
    RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED,
    RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME,
)
from ray.dashboard.datacenter import DataSource
from ray.dashboard.modules.log.log_manager import LogsManager
from ray.dashboard.optional_utils import rest_response
from ray.dashboard.state_aggregator import StateAPIManager
from ray.dashboard.utils import Change
from ray.experimental.state.common import (
    DEFAULT_LIMIT,
    DEFAULT_LOG_LIMIT,
    DEFAULT_RPC_TIMEOUT,
    GetLogOptions,
    ListApiOptions,
    SummaryApiOptions,
    SummaryApiResponse,
)
from ray.experimental.state.exception import DataSourceUnavailable
from ray.experimental.state.state_manager import StateDataSourceClient
from ray.experimental.state.util import convert_string_to_type
|
|
|
|
|
logger = logging.getLogger(__name__)

# Route table that the @routes.get(...) decorators below register
# HTTP handlers into.
routes = dashboard_optional_utils.ClassMethodRouteTable
|
|
|
|
|
|
class RateLimitedModule(ABC):
    """Simple rate limiter.

    Inheriting from this class and decorating any class methods will
    apply a simple rate limit.
    It will limit the maximal number of concurrent invocations of **all** the
    methods decorated.

    The below Example class will only allow 10 concurrent calls to A() and B()

    E.g.:

        class Example(RateLimitedModule):
            def __init__(self):
                super().__init__(max_num_call=10)

            @RateLimitedModule.enforce_max_concurrent_calls
            async def A():
                ...

            @RateLimitedModule.enforce_max_concurrent_calls
            async def B():
                ...

            async def limit_handler_(self):
                raise RuntimeError("rate limited reached!")
    """

    def __init__(self, max_num_call: int, logger: Optional[logging.Logger] = None):
        """
        Args:
            max_num_call: Maximal number of concurrent invocations of all decorated
                functions in the instance.
                Setting to -1 will disable rate limiting.

            logger: Logger
        """
        self.max_num_call_ = max_num_call
        # Number of decorated calls currently in flight.
        self.num_call_ = 0
        self.logger_ = logger

    @staticmethod
    def enforce_max_concurrent_calls(func):
        """Decorator to enforce max number of invocations of the decorated func

        NOTE: This should be used as the innermost decorator if there are multiple
        ones.

        E.g., when decorating functions already with @routes.get(...), this must be
        added below then the routes decorators:
        ```
        @routes.get('/')
        @RateLimitedModule.enforce_max_concurrent_calls
        async def fn(self):
            ...

        ```
        """

        # functools.wraps preserves the decorated coroutine's __name__,
        # __doc__ and __wrapped__, which would otherwise all read
        # "async_wrapper" in logs, tracebacks and route introspection.
        @functools.wraps(func)
        async def async_wrapper(self, *args, **kwargs):
            # A negative max_num_call_ disables limiting entirely.
            if self.max_num_call_ >= 0 and self.num_call_ >= self.max_num_call_:
                if self.logger_:
                    self.logger_.warning(
                        f"Max concurrent requests reached={self.max_num_call_}"
                    )
                return await self.limit_handler_()
            self.num_call_ += 1
            try:
                ret = await func(self, *args, **kwargs)
            finally:
                # Release the slot even when `func` raises.
                self.num_call_ -= 1
            return ret

        # Returning closure here to avoid passing 'self' to the
        # 'enforce_max_concurrent_calls' decorator.
        return async_wrapper

    @abstractmethod
    async def limit_handler_(self):
        """Handler that is invoked when max number of concurrent calls reached"""
|
|
|
|
|
|
class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
    """Module to obtain state information from the Ray cluster.

    It is responsible for state observability APIs such as
    ray.list_actors(), ray.get_actor(), ray.summary_actors().
    """

    def __init__(
        self,
        dashboard_head,
    ):
        """Initialize for handling RESTful requests from State API Client"""
        dashboard_utils.DashboardHeadModule.__init__(self, dashboard_head)
        # We don't allow users to configure too high a rate limit
        RateLimitedModule.__init__(
            self,
            min(
                RAY_STATE_SERVER_MAX_HTTP_REQUEST,
                RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED,
            ),
        )
        # All three are populated lazily in `run()` once the GCS gRPC
        # channel is available.
        self._state_api_data_source_client = None
        self._state_api = None
        self._log_api = None

        # Keep per-node raylet/agent gRPC stubs in sync with cluster membership.
        DataSource.nodes.signal.append(self._update_raylet_stubs)
        DataSource.agents.signal.append(self._update_agent_stubs)

    async def limit_handler_(self):
        """Reject the request when the max number of concurrent calls is reached."""
        return self._reply(
            success=False,
            error_message=(
                "Max number of in-progress requests="
                f"{self.max_num_call_} reached."
                "To set a higher limit, set environment variable: "
                f"export {RAY_STATE_SERVER_MAX_HTTP_REQUEST_ENV_NAME}='xxx'."
                f"Max allowed = {RAY_STATE_SERVER_MAX_HTTP_REQUEST_ALLOWED}"
            ),
            result=None,
        )

    def _options_from_req(self, req: aiohttp.web.Request) -> ListApiOptions:
        """Obtain `ListApiOptions` from the aiohttp request.

        Raises:
            ValueError: if `limit` or `timeout` is not an integer string.
            AssertionError: if the filter key/predicate/value lists have
                mismatched lengths.
        """
        limit = int(req.query.get("limit", DEFAULT_LIMIT))
        # Fall back to the default RPC timeout when the client omits
        # `timeout` (previously `int(None)` raised a TypeError here,
        # inconsistent with `_summary_options_from_req`).
        timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))
        filter_keys = req.query.getall("filter_keys", [])
        filter_predicates = req.query.getall("filter_predicates", [])
        filter_values = req.query.getall("filter_values", [])
        # All three lists must line up; `zip` would otherwise silently drop
        # trailing entries of the longer lists.
        assert len(filter_keys) == len(filter_predicates) == len(filter_values)
        filters = list(zip(filter_keys, filter_predicates, filter_values))
        detail = convert_string_to_type(req.query.get("detail", False), bool)

        return ListApiOptions(
            limit=limit, timeout=timeout, filters=filters, detail=detail
        )

    def _summary_options_from_req(self, req: aiohttp.web.Request) -> SummaryApiOptions:
        """Obtain `SummaryApiOptions` from the aiohttp request."""
        timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))
        return SummaryApiOptions(timeout=timeout)

    def _reply(self, success: bool, error_message: str, result: dict, **kwargs):
        """Reply to the client."""
        return rest_response(
            success=success,
            message=error_message,
            result=result,
            convert_google_style=False,
            **kwargs,
        )

    async def _update_raylet_stubs(self, change: Change):
        """Callback that's called when a new raylet is added to Datasource.

        Datasource is a api-server-specific module that's updated whenever
        api server adds/removes a new node.

        Args:
            change: The change object. Whenever a new node is added
                or removed, this callback is invoked.
                When new node is added: information is in `change.new`.
                When a node is removed: information is in `change.old`.
                When a node id is overwritten by a new node with the same node id:
                    `change.old` contains the old node info, and
                    `change.new` contains the new node info.
        """
        if change.old:
            # When a node is deleted from the DataSource or it is overwritten.
            node_id, node_info = change.old
            self._state_api_data_source_client.unregister_raylet_client(node_id)
        if change.new:
            # When a new node information is written to DataSource.
            node_id, node_info = change.new
            self._state_api_data_source_client.register_raylet_client(
                node_id,
                node_info["nodeManagerAddress"],
                int(node_info["nodeManagerPort"]),
            )

    async def _update_agent_stubs(self, change: Change):
        """Callback that's called when a new agent is added to Datasource."""
        if change.old:
            node_id, _ = change.old
            self._state_api_data_source_client.unregister_agent_client(node_id)
        if change.new:
            # When a new node information is written to DataSource.
            node_id, ports = change.new
            ip = DataSource.node_id_to_ip[node_id]
            self._state_api_data_source_client.register_agent_client(
                node_id,
                ip,
                # ports[1] presumably is the agent's gRPC port — confirm
                # against the DataSource.agents writer.
                int(ports[1]),
            )

    async def _handle_list_api(
        self, list_api_fn: Callable[[ListApiOptions], dict], req: aiohttp.web.Request
    ):
        """Shared driver for all list_* endpoints: parse options, call, reply."""
        try:
            result = await list_api_fn(option=self._options_from_req(req))
            return self._reply(
                success=True,
                error_message="",
                result=asdict(result),
            )
        except DataSourceUnavailable as e:
            return self._reply(success=False, error_message=str(e), result=None)

    @routes.get("/api/v0/actors")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_actors, req)

    @routes.get("/api/v0/jobs")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        try:
            # NOTE(review): unlike the other list endpoints this call is not
            # awaited — presumably `list_jobs` is synchronous; confirm against
            # StateAPIManager before changing.
            result = self._state_api.list_jobs(option=self._options_from_req(req))
            return self._reply(
                success=True,
                error_message="",
                result=asdict(result),
            )
        except DataSourceUnavailable as e:
            return self._reply(success=False, error_message=str(e), result=None)

    @routes.get("/api/v0/nodes")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_nodes(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_nodes, req)

    @routes.get("/api/v0/placement_groups")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_placement_groups(
        self, req: aiohttp.web.Request
    ) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_placement_groups, req)

    @routes.get("/api/v0/workers")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_workers(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_workers, req)

    @routes.get("/api/v0/tasks")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_tasks(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_tasks, req)

    @routes.get("/api/v0/objects")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_objects(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_objects, req)

    @routes.get("/api/v0/runtime_envs")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_runtime_envs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_list_api(self._state_api.list_runtime_envs, req)

    @routes.get("/api/v0/logs")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        """Return a list of log files on a given node id.

        Unlike other list APIs that display all existing resources in the cluster,
        this API always require to specify node id and node ip.
        """
        glob_filter = req.query.get("glob", "*")
        node_id = req.query.get("node_id", None)
        node_ip = req.query.get("node_ip", None)
        timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))

        # TODO(sang): Do input validation from the middleware instead.
        if not node_id and not node_ip:
            return self._reply(
                success=False,
                error_message=(
                    "Both node id and node ip are not provided. "
                    "Please provide at least one of them."
                ),
                result=None,
            )

        # Resolve a node ip to a node id when only the ip was given.
        node_id = node_id or self._log_api.ip_to_node_id(node_ip)
        if not node_id:
            return self._reply(
                success=False,
                error_message=(
                    f"Cannot find matching node_id for a given node ip {node_ip}"
                ),
                result=None,
            )

        try:
            result = await self._log_api.list_logs(
                node_id, timeout, glob_filter=glob_filter
            )
        except DataSourceUnavailable as e:
            return self._reply(success=False, error_message=str(e), result=None)

        return self._reply(success=True, error_message="", result=result)

    @routes.get("/api/v0/logs/{media_type}")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def get_logs(self, req: aiohttp.web.Request):
        """Stream a single log file (or actor/task/pid logs) to the client."""
        # TODO(sang): We need a better error handling for streaming
        # when we refactor the server framework.
        options = GetLogOptions(
            timeout=int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)),
            node_id=req.query.get("node_id", None),
            node_ip=req.query.get("node_ip", None),
            media_type=req.match_info.get("media_type", "file"),
            filename=req.query.get("filename", None),
            actor_id=req.query.get("actor_id", None),
            task_id=req.query.get("task_id", None),
            pid=req.query.get("pid", None),
            lines=req.query.get("lines", DEFAULT_LOG_LIMIT),
            interval=req.query.get("interval", None),
        )

        response = aiohttp.web.StreamResponse()
        response.content_type = "text/plain"
        await response.prepare(req)

        # NOTE: The first byte indicates the success / failure of individual
        # stream. If the first byte is b"1", it means the stream was successful.
        # If it is b"0", it means it is failed.
        try:
            async for logs_in_bytes in self._log_api.stream_logs(options):
                logs_to_stream = bytearray(b"1")
                logs_to_stream.extend(logs_in_bytes)
                await response.write(bytes(logs_to_stream))
            await response.write_eof()
            return response
        except Exception as e:
            logger.exception(e)
            error_msg = bytearray(b"0")
            error_msg.extend(
                f"Closing HTTP stream due to internal server error.\n{e}".encode()
            )

            await response.write(bytes(error_msg))
            await response.write_eof()
            return response

    async def _handle_summary_api(
        self,
        summary_fn: Callable[[SummaryApiOptions], SummaryApiResponse],
        req: aiohttp.web.Request,
    ):
        """Shared driver for all summarize_* endpoints."""
        result = await summary_fn(option=self._summary_options_from_req(req))
        return self._reply(
            success=True,
            error_message="",
            result=asdict(result),
        )

    @routes.get("/api/v0/tasks/summarize")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def summarize_tasks(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_summary_api(self._state_api.summarize_tasks, req)

    @routes.get("/api/v0/actors/summarize")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def summarize_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_summary_api(self._state_api.summarize_actors, req)

    @routes.get("/api/v0/objects/summarize")
    @RateLimitedModule.enforce_max_concurrent_calls
    async def summarize_objects(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
        return await self._handle_summary_api(self._state_api.summarize_objects, req)

    async def run(self, server):
        """Wire up the state/log APIs once the GCS channel is available."""
        gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
        self._state_api_data_source_client = StateDataSourceClient(gcs_channel)
        self._state_api = StateAPIManager(self._state_api_data_source_client)
        self._log_api = LogsManager(self._state_api_data_source_client)

    @staticmethod
    def is_minimal_module():
        return False
|