mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
139 lines
5 KiB
Python
139 lines
5 KiB
Python
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import ray
|
|
from ray.rllib.utils.actors import TaskPool
|
|
|
|
|
|
def createMockWorkerAndObjectRef(obj_ref):
    """Build a ``(worker, obj_ref)`` pair that can be splatted into ``TaskPool.add``.

    The worker is stood in for by a one-entry dict keyed by ``obj_ref``; the
    object ref itself is returned unchanged as the second element.
    """
    mock_worker = {obj_ref: 1}
    return mock_worker, obj_ref
|
|
|
|
|
|
class TaskPoolTest(unittest.TestCase):
    """Tests for ``TaskPool.completed_prefetch`` and ``TaskPool.reset_workers``.

    ``ray.wait`` is patched in every test, so the ``(ready, pending)`` result
    the pool sees is scripted through the mock's ``return_value``. Plain
    integers stand in for object refs.
    """

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllComplete(self, rayWaitMock):
        task1 = createMockWorkerAndObjectRef(1)
        task2 = createMockWorkerAndObjectRef(2)
        # Return the second task as complete and the first as pending
        rayWaitMock.return_value = ([2], [1])

        pool = TaskPool()
        pool.add(*task1)
        pool.add(*task2)

        # Only the completed task should come back, as a (worker, ref) pair
        fetched = list(pool.completed_prefetch())
        self.assertListEqual(fetched, [task2])

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllCompleteUpToDefaultLimit(self, rayWaitMock):
        # Load the pool with 1000 tasks, mock them all as complete and then
        # check that the first call to completed_prefetch only yields 999
        # items and the second call yields the final one
        pool = TaskPool()
        for i in range(1000):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(1000)), [])

        # For this test, we're only checking the object refs
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(999)))

        # Finally, check the next iteration returns the final task
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, [999])

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllCompleteUpToSpecifiedLimit(self, rayWaitMock):
        # Load the pool with 1000 tasks, mock them all as complete and then
        # check that a call to completed_prefetch with max_yield=500 only
        # yields the first 500 items and the next call yields the other 500
        pool = TaskPool()
        for i in range(1000):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(1000)), [])

        # Verify that only the first 500 tasks are returned, this should leave
        # some tasks in the _fetching deque for later
        fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=500)]
        self.assertListEqual(fetched, list(range(500)))

        # Finally, check the next iteration returns the remaining tasks
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(500, 1000)))

    @patch("ray.wait")
    def test_completed_prefetch_yieldsRemainingIfIterationStops(self, rayWaitMock):
        # Test for issue #7106
        # In versions of Ray up to 0.8.1, if the pre-fetch generator failed to
        # run to completion, then the TaskPool would fail to clear up already
        # fetched tasks resulting in stale object refs being returned
        pool = TaskPool()
        for i in range(10):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(10)), [])

        # Consume only the first item: the loop body raises on its first
        # pass, so the generator is abandoned before it runs to completion
        try:
            for _ in pool.completed_prefetch():
                # Simulate a worker failure returned by ray.get()
                raise ray.exceptions.RayError
        except ray.exceptions.RayError:
            pass

        # This fetch should return the remaining pre-fetched tasks
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(1, 10)))

    @patch("ray.wait")
    def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(self, rayWaitMock):
        pool = TaskPool()
        # We need to hold onto the tasks for this test so that we can fail a
        # specific worker
        tasks = []

        for i in range(10):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)
            tasks.append(task)

        # Simulate only some of the work being complete and fetch a couple of
        # tasks in order to fill the fetching queue
        rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
        fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]
        # Check the setup: exactly the first two completed tasks were handed
        # out, which is what the final expectation below relies on
        self.assertListEqual(fetched, [0, 1])

        # As we still have some pending tasks, we need to update the
        # completion states to remove the completed tasks
        rayWaitMock.return_value = ([], [6, 7, 8, 9])

        pool.reset_workers(
            [
                tasks[0][0],
                tasks[1][0],
                tasks[2][0],
                tasks[3][0],
                # OH NO! WORKER 4 HAS CRASHED!
                tasks[5][0],
                tasks[6][0],
                tasks[7][0],
                tasks[8][0],
                tasks[9][0],
            ]
        )

        # Fetch the remaining tasks which should already be in the _fetching
        # queue; the task belonging to the crashed worker 4 must be gone
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, [2, 3, 5])
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and hand pytest's status back to the
    # shell as the process exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|