mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00

This limits the maximum number of concurrent HTTP requests the dashboard (API server) will accept before rejecting further requests. It ensures that observability requests do not overload the downstream systems (raylet/GCS) when too many concurrent state observability requests are delegated to the cluster.
92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import asyncio
|
|
import sys
|
|
import pytest
|
|
from ray.dashboard.modules.state.state_head import RateLimitedModule
|
|
|
|
|
|
class FailedCallError(Exception):
    """Exception raised by the test handlers to simulate a call that fails."""
|
|
|
|
|
|
class A(RateLimitedModule):
    """Minimal ``RateLimitedModule`` subclass used to exercise the call limiter.

    ``fn1`` and ``fn2`` are wrapped with
    ``RateLimitedModule.enforce_max_concurrent_calls``. ``limit_handler_``
    returns ``False`` so that the tests can count falsy results as rejected
    calls.
    """

    def __init__(self, max_num_call: int):
        """Initialise the base class with ``max_num_call`` and a module logger."""
        # Kept as a function-scope import so the class does not depend on a
        # file-level ``import logging``.
        import logging

        super().__init__(max_num_call, logging.getLogger(__name__))

    @RateLimitedModule.enforce_max_concurrent_calls
    async def fn1(self, err: bool = False):
        """Return ``True`` after a 3 second sleep, or fail immediately.

        Args:
            err: When true, raise ``FailedCallError`` before sleeping.

        Raises:
            FailedCallError: If ``err`` is true.
        """
        if err:
            raise FailedCallError

        # The sleep keeps the call in progress so concurrent calls overlap.
        await asyncio.sleep(3)
        return True

    @RateLimitedModule.enforce_max_concurrent_calls
    async def fn2(self):
        """Return ``True`` after a 3 second sleep."""
        await asyncio.sleep(3)
        return True

    async def limit_handler_(self):
        """Return ``False``.

        NOTE(review): presumably the hook the base class invokes for calls
        rejected by the limiter -- confirm against ``RateLimitedModule``.
        """
        return False
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("extra_req_num", [-5, -3, -1, 0, 1, 3, 5])
async def test_max_concurrent_in_progress_functions(extra_req_num):
    """Test rate limiting for concurrent in-progress requests on StateHead.

    Launches ``max_req + extra_req_num`` calls at once, alternating between
    ``fn1`` and ``fn2``, and expects exactly ``max(0, extra_req_num)`` of them
    to come back falsy (i.e. rejected).
    """
    max_req = 10
    a = A(max_num_call=max_req)

    # Launching more concurrent calls than allowed should trigger rate
    # limiting; at or below the limit every call should succeed.
    res_arr = await asyncio.gather(
        *[a.fn1() if i % 2 == 0 else a.fn2() for i in range(max_req + extra_req_num)]
    )
    # Successful calls return True; anything falsy counts as a failure.
    fail_cnt = sum(1 for ok in res_arr if not ok)

    expected_fail_cnt = max(0, extra_req_num)
    assert fail_cnt == expected_fail_cnt, (
        f"{expected_fail_cnt} out of {max_req + extra_req_num} "
        f"concurrent runs should fail with max={max_req} but {fail_cnt}."
    )

    assert a.num_call_ == 0, "All requests should be done"
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "failures",
    [
        [True, True, True, True, True],
        [False, False, False, False, False],
        [False, True, False, True, False],
        [False, False, False, True, True],
        [True, True, False, False, False],
    ],
)
async def test_max_concurrent_with_exceptions(failures):
    """Test that calls raising an exception still release their slot.

    Each entry in ``failures`` tells the matching ``fn1`` call whether to
    raise ``FailedCallError``. Every raised error must surface in the gathered
    results, and the in-progress counter must return to zero afterwards.
    """
    max_req = 10
    a = A(max_num_call=max_req)

    # Only len(failures) calls are issued, which is below max_req, so no call
    # is expected to be rate limited here; return_exceptions=True collects the
    # raised errors as results instead of aborting the gather.
    res_arr = await asyncio.gather(
        *[a.fn1(err=should_throw_err) for should_throw_err in failures],
        return_exceptions=True,
    )

    expected_num_failure = sum(failures)
    actual_num_failure = sum(
        1 for res in res_arr if isinstance(res, FailedCallError)
    )

    assert expected_num_failure == actual_num_failure, "All failures should be captured"
    assert a.num_call_ == 0, "Failure should decrement the counter correctly"
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and use pytest's result as the exit code.
    sys.exit(pytest.main(["-v", __file__]))
|