[ParallelIterator] Fix for_each concurrent test cases/bugs (#8964)

* Everything works

* Update python/ray/util/iter.py

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>

* .

* .

* removed print statements

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Alex Wu 2020-06-22 18:26:45 -07:00 committed by GitHub
parent b88059326d
commit 40c15b1ba0
3 changed files with 56 additions and 25 deletions

@@ -248,7 +248,7 @@ class Semaphore:
         self._sema = asyncio.Semaphore(value=value)
 
     async def acquire(self):
-        self._sema.acquire()
+        await self._sema.acquire()
 
     async def release(self):
         self._sema.release()
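The bug fixed here is that `asyncio.Semaphore.acquire()` is a coroutine: calling it without `await` only creates a coroutine object, so the semaphore's counter is never decremented and the actor method returns immediately. Below is a minimal sketch in plain asyncio (outside of Ray) illustrating the difference:

import asyncio

async def main():
    sema = asyncio.Semaphore(value=1)

    # Without `await`, acquire() only creates a coroutine object; the
    # counter is never decremented and Python warns
    # "coroutine ... was never awaited" when the object is discarded.
    # sema.acquire()

    # Correct: await the coroutine so the permit is actually taken.
    await sema.acquire()
    assert sema.locked()  # value was 1, so the semaphore is now exhausted
    sema.release()

asyncio.run(main())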

@@ -159,7 +159,7 @@ def test_for_each(ray_start_regular_shared):
     assert list(it.gather_sync()) == [0, 4, 2, 6]
 
 
-def test_for_each_concur(ray_start_regular_shared):
+def test_for_each_concur_async(ray_start_regular_shared):
     main_wait = Semaphore.remote(value=0)
     test_wait = Semaphore.remote(value=0)
 
@@ -169,15 +169,18 @@ def test_for_each_concur(ray_start_regular_shared):
         ray.get(test_wait.acquire.remote())
         return i + 10
 
-    @ray.remote(num_cpus=0.1)
+    @ray.remote(num_cpus=0.01)
    def to_list(it):
         return list(it)
 
     it = from_items(
         [(i, main_wait, test_wait) for i in range(8)], num_shards=2)
-    it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1})
+    it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})
+
+    list_promise = to_list.remote(it.gather_async())
 
     for i in range(4):
+        assert i in [0, 1, 2, 3]
         ray.get(main_wait.acquire.remote())
 
     # There should be exactly 4 tasks executing at this point.
@@ -189,12 +192,49 @@ def test_for_each_concur(ray_start_regular_shared):
     assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
 
     # Finish everything and make sure the output matches a regular iterator.
-    for i in range(3):
+    for i in range(7):
         ray.get(test_wait.release.remote())
 
     assert repr(
         it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
-    assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18))
+    result_list = ray.get(list_promise)
+    assert set(result_list) == set(range(10, 18))
+
+
+def test_for_each_concur_sync(ray_start_regular_shared):
+    main_wait = Semaphore.remote(value=0)
+    test_wait = Semaphore.remote(value=0)
+
+    def task(x):
+        i, main_wait, test_wait = x
+        ray.get(main_wait.release.remote())
+        ray.get(test_wait.acquire.remote())
+        return i + 10
+
+    @ray.remote(num_cpus=0.01)
+    def to_list(it):
+        return list(it)
+
+    it = from_items(
+        [(i, main_wait, test_wait) for i in range(8)], num_shards=2)
+    it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})
+
+    list_promise = to_list.remote(it.gather_sync())
+
+    for i in range(4):
+        assert i in [0, 1, 2, 3]
+        ray.get(main_wait.acquire.remote())
+
+    # There should be exactly 4 tasks executing at this point.
+    assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
+
+    for i in range(8):
+        ray.get(test_wait.release.remote())
+
+    assert repr(
+        it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
+    result_list = ray.get(list_promise)
+    assert set(result_list) == set(range(10, 18))
 
 
 def test_combine(ray_start_regular_shared):
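For readers unfamiliar with the two-semaphore handshake these tests rely on: each task releases `main_wait` as soon as it starts and then blocks on `test_wait`, so the driver can count exactly how many tasks are in flight before letting them finish. The sketch below reproduces the same pattern with `threading` instead of Ray actors and tasks; the names (`task`, `main_wait`, `test_wait`) simply mirror the test above and are not part of any Ray API.

import threading

# Two-semaphore handshake: each worker signals `main_wait` when it starts,
# then blocks on `test_wait` until the driver allows it to finish.
main_wait = threading.Semaphore(value=0)
test_wait = threading.Semaphore(value=0)
results = []

def task(i):
    main_wait.release()   # "I have started"
    test_wait.acquire()   # wait for the driver's go-ahead
    results.append(i + 10)

workers = [threading.Thread(target=task, args=(i,)) for i in range(4)]
for w in workers:
    w.start()

# Exactly 4 workers should have started: 4 acquires succeed, a 5th would block.
for _ in range(4):
    main_wait.acquire()
assert not main_wait.acquire(blocking=False), "Too much parallelism"

# Let everything finish and compare as a set, since completion order varies.
for _ in range(4):
    test_wait.release()
for w in workers:
    w.join()
assert set(results) == {10, 11, 12, 13}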

@@ -203,9 +203,9 @@ class ParallelIterator(Generic[T]):
 
         `max_concurrency` should be used to achieve a high degree of
         parallelism without the overhead of increasing the number of shards
-        (which are actor based). This provides the semantic guarantee that
-        `fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not
-        necessarily finish first)
+        (which are actor based). If `max_concurrency` is not 1, this function
+        provides no semantic guarantees on the output order.
+        Results will be returned as soon as they are ready.
 
         A performance note: When executing concurrently, this function
         maintains its own internal buffer. If `num_async` is `n` and
@@ -234,6 +234,7 @@ class ParallelIterator(Generic[T]):
             ... [0, 2, 4, 8]
         """
 
+        assert max_concurrency >= 0, "max_concurrency must be non-negative."
         return self._with_transform(
             lambda local_it: local_it.for_each(fn, max_concurrency, resources),
             ".for_each()")
@@ -765,23 +766,13 @@ class LocalIterator(Generic[T]):
                     if isinstance(item, _NextValueNotReady):
                         yield item
                     else:
-                        finished, remaining = ray.wait(cur, timeout=0)
-                        if max_concurrency and len(
-                                remaining) >= max_concurrency:
-                            ray.wait(cur, num_returns=(len(finished) + 1))
+                        if max_concurrency and len(cur) >= max_concurrency:
+                            finished, cur = ray.wait(cur)
+                            yield from ray.get(finished)
                         cur.append(remote_fn(item))
-                        while len(cur) > 0:
-                            to_yield = cur[0]
-                            finished, remaining = ray.wait(
-                                [to_yield], timeout=0)
-                            if finished:
-                                cur.pop(0)
-                                yield ray.get(to_yield)
-                            else:
-                                break
-                yield from ray.get(cur)
+                while cur:
+                    finished, cur = ray.wait(cur)
+                    yield from ray.get(finished)
 
             if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
                 unwrapped = apply_foreach
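The rewritten `apply_foreach` implements a simple back-pressure loop: submit tasks until `max_concurrency` are in flight, block on `ray.wait` for at least one to finish and yield its result, then drain whatever remains once the input is exhausted. Below is a standalone sketch of the same pattern outside the iterator machinery; `bounded_map` and `square` are illustrative names, not part of the Ray API.

import ray

ray.init(num_cpus=2)

@ray.remote
def square(x):
    return x * x

def bounded_map(items, remote_fn, max_concurrency):
    """Yield remote results while keeping at most `max_concurrency`
    tasks in flight, in completion order (not submission order)."""
    cur = []
    for item in items:
        if max_concurrency and len(cur) >= max_concurrency:
            # Block until at least one in-flight task finishes, then
            # yield its result before submitting more work.
            finished, cur = ray.wait(cur)
            yield from ray.get(finished)
        cur.append(remote_fn.remote(item))
    # Drain whatever is still in flight.
    while cur:
        finished, cur = ray.wait(cur)
        yield from ray.get(finished)

assert set(bounded_map(range(8), square, max_concurrency=2)) == {
    x * x for x in range(8)
}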