import asyncio
import sys
import pytest
from ray.dashboard.modules.state.state_head import RateLimitedModule


class FailedCallError(Exception):
    """Sentinel error raised by ``A.fn1`` when a test asks the call to fail."""


class A(RateLimitedModule):
    """Small ``RateLimitedModule`` subclass used as the fixture for these tests.

    ``fn1`` and ``fn2`` are both wrapped with
    ``RateLimitedModule.enforce_max_concurrent_calls``.
    NOTE(review): the tests in this file rely on calls made past
    ``max_num_call`` resolving to the value of ``limit_handler_`` -- confirm
    against the ``RateLimitedModule`` implementation.
    """

    def __init__(self, max_num_call: int):
        """Set up the module with ``max_num_call`` and a module-named logger."""
        from logging import getLogger

        super().__init__(max_num_call, getLogger(__name__))

    @RateLimitedModule.enforce_max_concurrent_calls
    async def fn1(self, err: bool = False):
        """Raise ``FailedCallError`` when ``err`` is set, else sleep 3s and return True."""
        if not err:
            # Stay in flight for a while so concurrent calls overlap.
            await asyncio.sleep(3)
            return True

        raise FailedCallError

    @RateLimitedModule.enforce_max_concurrent_calls
    async def fn2(self):
        """Sleep for 3 seconds, then return True."""
        await asyncio.sleep(3)
        return True

    async def limit_handler_(self):
        """Return False, the falsy marker the tests count as a failed call."""
        return False


@pytest.mark.asyncio
@pytest.mark.parametrize("extra_req_num", [-5, -3, -1, 0, 1, 3, 5])
async def test_max_concurrent_in_progress_functions(extra_req_num):
    """Check how many concurrent calls fail around the ``max_num_call`` limit.

    Launches ``max_req + extra_req_num`` calls at once, alternating between
    ``fn1`` and ``fn2``. Exactly ``max(0, extra_req_num)`` of them are expected
    to come back falsy, and ``num_call_`` must be 0 once all have finished.
    """
    max_req = 10
    module = A(max_num_call=max_req)

    # Even indices call fn1, odd indices call fn2, so both decorated methods
    # are exercised against the same instance.
    pending = []
    for idx in range(max_req + extra_req_num):
        pending.append(module.fn1() if idx % 2 == 0 else module.fn2())
    outcomes = await asyncio.gather(*pending)

    # A falsy result counts as a failed call.
    fail_cnt = sum(1 for ok in outcomes if not ok)
    expected_fail_cnt = max(0, extra_req_num)

    assert fail_cnt == expected_fail_cnt, (
        f"{expected_fail_cnt} out of {max_req + extra_req_num} "
        f"concurrent runs should fail with max={max_req} but {fail_cnt}."
    )

    assert module.num_call_ == 0, "All requests should be done"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "failures",
    [
        [True, True, True, True, True],
        [False, False, False, False, False],
        [False, True, False, True, False],
        [False, False, False, True, True],
        [True, True, False, False, False],
    ],
)
async def test_max_concurrent_with_exceptions(failures):
    """Check that calls which raise still leave ``num_call_`` at 0.

    Each entry in ``failures`` tells one ``fn1`` call whether to raise
    ``FailedCallError``. Every raised error must show up in the gathered
    results, and ``num_call_`` must be 0 afterwards.
    """
    max_req = 10
    module = A(max_num_call=max_req)

    # Five calls against a limit of ten; return_exceptions=True makes gather
    # hand back each FailedCallError as a result instead of propagating it.
    outcomes = await asyncio.gather(
        *(module.fn1(err=flag) for flag in failures),
        return_exceptions=True,
    )

    expected_num_failure = sum(failures)
    actual_num_failure = sum(
        1 for outcome in outcomes if isinstance(outcome, FailedCallError)
    )

    assert expected_num_failure == actual_num_failure, "All failures should be captured"
    assert module.num_call_ == 0, "Failure should decrement the counter correctly"


if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's return code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)