"""Tests for RLlib's ``AsyncRequestsManager`` running against remote Ray actors."""

import random
import time
import unittest

import pytest
import ray
from ray.rllib.execution.parallel_requests import AsyncRequestsManager


@ray.remote
class RemoteRLlibActor:
    """Minimal remote actor whose tasks sleep for a configurable duration."""

    def __init__(self, sleep_time):
        # Seconds that every task on this actor sleeps before returning.
        self.sleep_time = sleep_time

    def apply(self, func, *_args, **_kwargs):
        """Call ``func`` with this actor as first argument, forwarding the rest."""
        return func(self, *_args, **_kwargs)

    def task(self):
        """Sleep for ``sleep_time`` seconds, then return ``"done"``."""
        time.sleep(self.sleep_time)
        return "done"

    def task2(self, a, b):
        """Sleep for ``sleep_time`` seconds, then return ``a + b``."""
        time.sleep(self.sleep_time)
        return a + b


class TestAsyncRequestsManager(unittest.TestCase):
    """Scheduling, collection and worker-management tests for the manager."""

    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4)
        random.seed(0)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    @classmethod
    def shutdown_method(cls):
        # NOTE(review): duplicates tearDownClass and has no caller in the
        # visible source; kept so that any external reference keeps working.
        ray.shutdown()

    def test_async_requests_manager_num_returns(self):
        """Tests that an async manager can properly handle actors with tasks that
        vary in the amount of time that they take to run"""
        workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
        workers += [RemoteRLlibActor.remote(sleep_time=5) for _ in range(2)]
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=1
        )
        for _ in range(4):
            manager.call(lambda w: w.task())

        # After 3s only the two 0.1s actors can have finished.
        time.sleep(3)
        if len(manager.get_ready()) != 2:
            raise Exception(
                "We should return the 2 ready requests in this case from the actors"
                " that have shorter tasks"
            )

        # A further 7s puts the total wait beyond the slow actors' 5s sleep.
        time.sleep(7)
        if len(manager.get_ready()) != 2:
            raise Exception(
                "We should return the 2 ready requests in this case from the actors"
                " that have longer tasks"
            )

    def test_round_robin_scheduling(self):
        """Test that the async manager schedules actors in a round robin fashion"""
        workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )
        for i in range(4):
            # Under round-robin, call ``i`` is expected on worker ``i % 2``.
            scheduled_actor = workers[i % len(workers)]
            manager.call(lambda w: w.task())
            if i < 2:
                assert len(manager._remote_requests_in_flight[scheduled_actor]) == 1, (
                    "We should have 1 request in flight for the actor that we just "
                    "scheduled on"
                )
            else:
                assert len(manager._remote_requests_in_flight[scheduled_actor]) == 2, (
                    "We should have 2 request in flight for the actor that we just "
                    "scheduled on"
                )

    def test_test_async_requests_task_doesnt_buffering(self):
        """Tests that the async manager drops (rather than buffers) calls that are
        made while every worker is already at its in-flight limit"""
        workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )
        # Capacity is 2 workers * 2 requests, so only the first 4 calls fit.
        for i in range(8):
            scheduled = manager.call(lambda w: w.task())
            if i < 4:
                assert scheduled, "We should have scheduled the task"
            else:
                assert not scheduled, (
                    "We should not have scheduled the task because"
                    " all workers are busy."
                )
        assert len(manager._pending_remotes) == 4, "We should have 4 pending requests"

        time.sleep(3)
        ready_requests = manager.get_ready()
        for worker in workers:
            if len(ready_requests[worker]) != 2:
                raise Exception(
                    "We should return the 2 ready requests in this case from each "
                    "actors."
                )

        # The four rejected calls above were not retained, so four fresh calls
        # are issued here now that the earlier results have been collected.
        for _ in range(4):
            manager.call(lambda w: w.task())

        time.sleep(3)
        ready_requests = manager.get_ready()
        for worker in workers:
            if len(ready_requests[worker]) != 2:
                raise Exception(
                    "We should return the 2 ready requests in this case from each "
                    "actors"
                )

    def test_args_kwargs(self):
        """Tests that the async manager forwards ``fn_args`` and ``fn_kwargs`` to
        the remotely executed function"""
        workers = [RemoteRLlibActor.remote(sleep_time=0.1)]
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )

        # Positional arguments.
        for _ in range(2):
            manager.call(lambda w, a, b: w.task2(a, b), fn_args=[1, 2])
        time.sleep(3)
        if len(manager.get_ready()[workers[0]]) != 2:
            raise Exception(
                "We should return the 2 ready requests in this case from the actors"
                " that have shorter tasks"
            )

        # Keyword arguments.
        for _ in range(2):
            manager.call(lambda w, a, b: w.task2(a, b), fn_kwargs=dict(a=1, b=2))
        time.sleep(3)
        if len(manager.get_ready()[workers[0]]) != 2:
            raise Exception(
                "We should return the 2 ready requests in this case from the actors"
                " that have longer tasks"
            )

    def test_add_remove_actors(self):
        """Tests that the async manager can properly add and remove actors"""
        workers = []
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )
        if not (
            len(manager._all_workers)
            == len(manager._remote_requests_in_flight)
            == len(manager._pending_to_actor)
            == len(manager._pending_remotes)
            == 0
        ):
            raise ValueError("We should have no workers in this case.")

        assert not manager.call(lambda w: w.task()), (
            "Task shouldn't have been "
            "launched since there are no "
            "workers in the manager."
        )

        worker = RemoteRLlibActor.remote(sleep_time=0.1)
        manager.add_workers(worker)
        manager.call(lambda w: w.task())
        if not (
            len(manager._remote_requests_in_flight[worker])
            == len(manager._pending_to_actor)
            == len(manager._all_workers)
            == len(manager._pending_remotes)
            == 1
        ):
            raise ValueError("We should have 1 worker and 1 pending request")

        # Drain the single outstanding request before exercising removal.
        time.sleep(3)
        manager.get_ready()

        # test worker removal
        for i in range(2):
            manager.call(lambda w: w.task())
            assert len(manager._pending_remotes) == i + 1
        manager.remove_workers(worker)
        if len(manager._all_workers) != 0:
            raise ValueError("We should have no workers that we can schedule tasks to")
        if not (
            len(manager._pending_remotes) == 2 and len(manager._pending_to_actor) == 2
        ):
            raise ValueError(
                "We should still have 2 pending requests in flight from the worker"
            )

        # Requests launched before removal must still be collectable afterwards.
        time.sleep(3)
        result = manager.get_ready()
        if not (
            len(result) == 1
            and len(result[worker]) == 2
            and len(manager._pending_remotes) == 0
            and len(manager._pending_to_actor) == 0
        ):
            raise ValueError(
                "We should have 2 ready results from the worker and no pending requests"
            )

    def test_call_to_actor(self):
        """Tests that ``call`` can target a specific actor and that targeting an
        actor unknown to the manager raises ``ValueError``"""
        workers = [RemoteRLlibActor.remote(sleep_time=0.1) for _ in range(2)]
        worker_not_in_manager = RemoteRLlibActor.remote(sleep_time=0.1)
        manager = AsyncRequestsManager(
            workers, max_remote_requests_in_flight_per_worker=2
        )
        manager.call(lambda w: w.task(), actor=workers[0])
        time.sleep(3)
        results = manager.get_ready()
        # Fail when EITHER expectation is violated. The previous form,
        # ``not len(results) == 1 and workers[0] not in results``, only failed
        # when both were violated because ``not`` binds tighter than ``and``.
        if len(results) != 1 or workers[0] not in results:
            raise Exception(
                "We should return the 1 ready requests in this case from the worker we "
                "called to"
            )
        with pytest.raises(ValueError, match=".*has not been added to the manager.*"):
            manager.call(lambda w: w.task(), actor=worker_not_in_manager)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))