ray/rllib/execution/tests/test_async_requests_manager.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

228 lines
8.3 KiB
Python
Raw Normal View History

import random
import pytest
import unittest
import ray
import time
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
@ray.remote
class RemoteRLlibActor:
def __init__(self, sleep_time):
self.sleep_time = sleep_time
def apply(self, func, *_args, **_kwargs):
return func(self, *_args, **_kwargs)
def task(self):
time.sleep(self.sleep_time)
return "done"
def task2(self, a, b):
time.sleep(self.sleep_time)
return a + b
class TestAsyncRequestsManager(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
random.seed(0)
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
@classmethod
def shutdown_method(cls):
ray.shutdown()
def test_async_requests_manager_num_returns(self):
"""Tests that an async manager can properly handle actors with tasks that
vary in the amount of time that they take to run"""
workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
workers += [RemoteRLlibActor.remote(sleep_time=5) for _ in range(2)]
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=1
)
for _ in range(4):
manager.call(lambda w: w.task())
time.sleep(3)
if not len(manager.get_ready()) == 2:
raise Exception(
"We should return the 2 ready requests in this case from the actors"
" that have shorter tasks"
)
time.sleep(7)
if not len(manager.get_ready()) == 2:
raise Exception(
"We should return the 2 ready requests in this case from the actors"
" that have longer tasks"
)
def test_round_robin_scheduling(self):
"""Test that the async manager schedules actors in a round robin fashion"""
workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=2
)
for i in range(4):
scheduled_actor = workers[i % len(workers)]
manager.call(lambda w: w.task())
if i < 2:
assert len(manager._remote_requests_in_flight[scheduled_actor]) == 1, (
"We should have 1 request in flight for the actor that we just "
"scheduled on"
)
else:
assert len(manager._remote_requests_in_flight[scheduled_actor]) == 2, (
"We should have 2 request in flight for the actor that we just "
"scheduled on"
)
def test_test_async_requests_task_doesnt_buffering(self):
"""Tests that the async manager drops"""
workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=2
)
for i in range(8):
scheduled = manager.call(lambda w: w.task())
if i < 4:
assert scheduled, "We should have scheduled the task"
else:
assert not scheduled, (
"We should not have scheduled the task because"
" all workers are busy."
)
assert len(manager._pending_remotes) == 4, "We should have 4 pending requests"
time.sleep(3)
ready_requests = manager.get_ready()
for worker in workers:
if not len(ready_requests[worker]) == 2:
raise Exception(
"We should return the 2 ready requests in this case from each "
"actors."
)
for _ in range(4):
manager.call(lambda w: w.task())
# new tasks scheduled from the buffer
time.sleep(3)
ready_requests = manager.get_ready()
for worker in workers:
if not len(ready_requests[worker]) == 2:
raise Exception(
"We should return the 2 ready requests in this case from each "
"actors"
)
def test_args_kwargs(self):
"""Tests that the async manager can properly handle actors with tasks that
vary in the amount of time that they take to run"""
workers = [RemoteRLlibActor.remote(sleep_time=0.1)]
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=2
)
for _ in range(2):
manager.call(lambda w, a, b: w.task2(a, b), fn_args=[1, 2])
time.sleep(3)
if not len(manager.get_ready()[workers[0]]) == 2:
raise Exception(
"We should return the 2 ready requests in this case from the actors"
" that have shorter tasks"
)
for _ in range(2):
manager.call(lambda w, a, b: w.task2(a, b), fn_kwargs=dict(a=1, b=2))
time.sleep(3)
if not len(manager.get_ready()[workers[0]]) == 2:
raise Exception(
"We should return the 2 ready requests in this case from the actors"
" that have longer tasks"
)
def test_add_remove_actors(self):
"""Tests that the async manager can properly add and remove actors"""
workers = []
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=2
)
if not (
(
len(manager._all_workers)
== len(manager._remote_requests_in_flight)
== len(manager._pending_to_actor)
== len(manager._pending_remotes)
== 0
)
):
raise ValueError("We should have no workers in this case.")
assert not manager.call(lambda w: w.task()), (
"Task shouldn't have been "
"launched since there are no "
"workers in the manager."
)
worker = RemoteRLlibActor.remote(sleep_time=0.1)
manager.add_workers(worker)
manager.call(lambda w: w.task())
if not (
len(manager._remote_requests_in_flight[worker])
== len(manager._pending_to_actor)
== len(manager._all_workers)
== len(manager._pending_remotes)
== 1
):
raise ValueError("We should have 1 worker and 1 pending request")
time.sleep(3)
manager.get_ready()
# test worker removal
for i in range(2):
manager.call(lambda w: w.task())
assert len(manager._pending_remotes) == i + 1
manager.remove_workers(worker)
if not ((len(manager._all_workers) == 0)):
raise ValueError("We should have no workers that we can schedule tasks to")
if not (
(len(manager._pending_remotes) == 2 and len(manager._pending_to_actor) == 2)
):
raise ValueError(
"We should still have 2 pending requests in flight from the worker"
)
time.sleep(3)
result = manager.get_ready()
if not (
len(result) == 1
and len(result[worker]) == 2
and len(manager._pending_remotes) == 0
and len(manager._pending_to_actor) == 0
):
raise ValueError(
"We should have 2 ready results from the worker and no pending requests"
)
def test_call_to_actor(self):
workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
worker_not_in_manager = RemoteRLlibActor.remote(sleep_time=0.1)
manager = AsyncRequestsManager(
workers, max_remote_requests_in_flight_per_worker=2
)
manager.call(lambda w: w.task(), actor=workers[0])
time.sleep(3)
results = manager.get_ready()
if not len(results) == 1 and workers[0] not in results:
raise Exception(
"We should return the 1 ready requests in this case from the worker we "
"called to"
)
with pytest.raises(ValueError, match=".*has not been added to the manager.*"):
manager.call(lambda w: w.task(), actor=worker_not_in_manager)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))