From 40c15b1ba05e5c9a8ea1165f3a7dffda04156a41 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Mon, 22 Jun 2020 18:26:45 -0700 Subject: [PATCH] [ParallelIterator] Fix for_each concurrent test cases/bugs (#8964) * Everything works * Update python/ray/util/iter.py Co-authored-by: Amog Kamsetty * . * . * removed print statements Co-authored-by: Amog Kamsetty --- python/ray/test_utils.py | 2 +- python/ray/tests/test_iter.py | 50 +++++++++++++++++++++++++++++++---- python/ray/util/iter.py | 29 +++++++------------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 7bc37ebcc..7d696da1d 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -248,7 +248,7 @@ class Semaphore: self._sema = asyncio.Semaphore(value=value) async def acquire(self): - self._sema.acquire() + await self._sema.acquire() async def release(self): self._sema.release() diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index f1e140bc7..8eb0d606c 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -159,7 +159,7 @@ def test_for_each(ray_start_regular_shared): assert list(it.gather_sync()) == [0, 4, 2, 6] -def test_for_each_concur(ray_start_regular_shared): +def test_for_each_concur_async(ray_start_regular_shared): main_wait = Semaphore.remote(value=0) test_wait = Semaphore.remote(value=0) @@ -169,15 +169,18 @@ def test_for_each_concur(ray_start_regular_shared): ray.get(test_wait.acquire.remote()) return i + 10 - @ray.remote(num_cpus=0.1) + @ray.remote(num_cpus=0.01) def to_list(it): return list(it) it = from_items( [(i, main_wait, test_wait) for i in range(8)], num_shards=2) - it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1}) + it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01}) + + list_promise = to_list.remote(it.gather_async()) for i in range(4): + assert i in [0, 1, 2, 3] ray.get(main_wait.acquire.remote()) # There should be exactly 4 tasks executing at this point. @@ -189,12 +192,49 @@ def test_for_each_concur(ray_start_regular_shared): assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism" # Finish everything and make sure the output matches a regular iterator. - for i in range(3): + for i in range(7): ray.get(test_wait.release.remote()) assert repr( it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]" - assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18)) + result_list = ray.get(list_promise) + assert set(result_list) == set(range(10, 18)) + + +def test_for_each_concur_sync(ray_start_regular_shared): + main_wait = Semaphore.remote(value=0) + test_wait = Semaphore.remote(value=0) + + def task(x): + i, main_wait, test_wait = x + ray.get(main_wait.release.remote()) + ray.get(test_wait.acquire.remote()) + return i + 10 + + @ray.remote(num_cpus=0.01) + def to_list(it): + return list(it) + + it = from_items( + [(i, main_wait, test_wait) for i in range(8)], num_shards=2) + it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01}) + + list_promise = to_list.remote(it.gather_sync()) + + for i in range(4): + assert i in [0, 1, 2, 3] + ray.get(main_wait.acquire.remote()) + + # There should be exactly 4 tasks executing at this point. + assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism" + + for i in range(8): + ray.get(test_wait.release.remote()) + + assert repr( + it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]" + result_list = ray.get(list_promise) + assert set(result_list) == set(range(10, 18)) def test_combine(ray_start_regular_shared): diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 542a53aec..c277e4128 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -203,9 +203,9 @@ class ParallelIterator(Generic[T]): `max_concurrency` should be used to achieve a high degree of parallelism without the overhead of increasing the number of shards - (which are actor based). This provides the semantic guarantee that - `fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not - necessarily finish first) + (which are actor based). If `max_concurrency` is not 1, this function + provides no semantic guarantees on the output order. + Results will be returned as soon as they are ready. A performance note: When executing concurrently, this function maintains its own internal buffer. If `num_async` is `n` and @@ -234,6 +234,7 @@ class ParallelIterator(Generic[T]): ... [0, 2, 4, 8] """ + assert max_concurrency >= 0, "max_concurrency must be non-negative." return self._with_transform( lambda local_it: local_it.for_each(fn, max_concurrency, resources), ".for_each()") @@ -765,23 +766,13 @@ class LocalIterator(Generic[T]): if isinstance(item, _NextValueNotReady): yield item else: - finished, remaining = ray.wait(cur, timeout=0) - if max_concurrency and len( - remaining) >= max_concurrency: - ray.wait(cur, num_returns=(len(finished) + 1)) + if max_concurrency and len(cur) >= max_concurrency: + finished, cur = ray.wait(cur) + yield from ray.get(finished) cur.append(remote_fn(item)) - - while len(cur) > 0: - to_yield = cur[0] - finished, remaining = ray.wait( - [to_yield], timeout=0) - if finished: - cur.pop(0) - yield ray.get(to_yield) - else: - break - - yield from ray.get(cur) + while cur: + finished, cur = ray.wait(cur) + yield from ray.get(finished) if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME): unwrapped = apply_foreach