[ParallelIterator] Fix for_each concurrent test cases/bugs (#8964)

* Everything works

* Update python/ray/util/iter.py

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>

* .

* .

* removed print statements

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Alex Wu 2020-06-22 18:26:45 -07:00 committed by GitHub
parent b88059326d
commit 40c15b1ba0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 25 deletions

View file

@@ -248,7 +248,7 @@ class Semaphore:
self._sema = asyncio.Semaphore(value=value) self._sema = asyncio.Semaphore(value=value)
async def acquire(self): async def acquire(self):
self._sema.acquire() await self._sema.acquire()
async def release(self): async def release(self):
self._sema.release() self._sema.release()

View file

@@ -159,7 +159,7 @@ def test_for_each(ray_start_regular_shared):
assert list(it.gather_sync()) == [0, 4, 2, 6] assert list(it.gather_sync()) == [0, 4, 2, 6]
def test_for_each_concur(ray_start_regular_shared): def test_for_each_concur_async(ray_start_regular_shared):
main_wait = Semaphore.remote(value=0) main_wait = Semaphore.remote(value=0)
test_wait = Semaphore.remote(value=0) test_wait = Semaphore.remote(value=0)
@@ -169,15 +169,18 @@ def test_for_each_concur(ray_start_regular_shared):
ray.get(test_wait.acquire.remote()) ray.get(test_wait.acquire.remote())
return i + 10 return i + 10
@ray.remote(num_cpus=0.1) @ray.remote(num_cpus=0.01)
def to_list(it): def to_list(it):
return list(it) return list(it)
it = from_items( it = from_items(
[(i, main_wait, test_wait) for i in range(8)], num_shards=2) [(i, main_wait, test_wait) for i in range(8)], num_shards=2)
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1}) it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})
list_promise = to_list.remote(it.gather_async())
for i in range(4): for i in range(4):
assert i in [0, 1, 2, 3]
ray.get(main_wait.acquire.remote()) ray.get(main_wait.acquire.remote())
# There should be exactly 4 tasks executing at this point. # There should be exactly 4 tasks executing at this point.
@@ -189,12 +192,49 @@ def test_for_each_concur(ray_start_regular_shared):
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism" assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
# Finish everything and make sure the output matches a regular iterator. # Finish everything and make sure the output matches a regular iterator.
for i in range(3): for i in range(7):
ray.get(test_wait.release.remote()) ray.get(test_wait.release.remote())
assert repr( assert repr(
it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]" it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18)) result_list = ray.get(list_promise)
assert set(result_list) == set(range(10, 18))
def test_for_each_concur_sync(ray_start_regular_shared):
main_wait = Semaphore.remote(value=0)
test_wait = Semaphore.remote(value=0)
def task(x):
i, main_wait, test_wait = x
ray.get(main_wait.release.remote())
ray.get(test_wait.acquire.remote())
return i + 10
@ray.remote(num_cpus=0.01)
def to_list(it):
return list(it)
it = from_items(
[(i, main_wait, test_wait) for i in range(8)], num_shards=2)
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})
list_promise = to_list.remote(it.gather_sync())
for i in range(4):
assert i in [0, 1, 2, 3]
ray.get(main_wait.acquire.remote())
# There should be exactly 4 tasks executing at this point.
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
for i in range(8):
ray.get(test_wait.release.remote())
assert repr(
it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
result_list = ray.get(list_promise)
assert set(result_list) == set(range(10, 18))
def test_combine(ray_start_regular_shared): def test_combine(ray_start_regular_shared):

View file

@@ -203,9 +203,9 @@ class ParallelIterator(Generic[T]):
`max_concurrency` should be used to achieve a high degree of `max_concurrency` should be used to achieve a high degree of
parallelism without the overhead of increasing the number of shards parallelism without the overhead of increasing the number of shards
(which are actor based). This provides the semantic guarantee that (which are actor based). If `max_concurrency` is not 1, this function
`fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not provides no semantic guarantees on the output order.
necessarily finish first) Results will be returned as soon as they are ready.
A performance note: When executing concurrently, this function A performance note: When executing concurrently, this function
maintains its own internal buffer. If `num_async` is `n` and maintains its own internal buffer. If `num_async` is `n` and
@@ -234,6 +234,7 @@ class ParallelIterator(Generic[T]):
... [0, 2, 4, 8] ... [0, 2, 4, 8]
""" """
assert max_concurrency >= 0, "max_concurrency must be non-negative."
return self._with_transform( return self._with_transform(
lambda local_it: local_it.for_each(fn, max_concurrency, resources), lambda local_it: local_it.for_each(fn, max_concurrency, resources),
".for_each()") ".for_each()")
@@ -765,23 +766,13 @@ class LocalIterator(Generic[T]):
if isinstance(item, _NextValueNotReady): if isinstance(item, _NextValueNotReady):
yield item yield item
else: else:
finished, remaining = ray.wait(cur, timeout=0) if max_concurrency and len(cur) >= max_concurrency:
if max_concurrency and len( finished, cur = ray.wait(cur)
remaining) >= max_concurrency: yield from ray.get(finished)
ray.wait(cur, num_returns=(len(finished) + 1))
cur.append(remote_fn(item)) cur.append(remote_fn(item))
while cur:
while len(cur) > 0: finished, cur = ray.wait(cur)
to_yield = cur[0] yield from ray.get(finished)
finished, remaining = ray.wait(
[to_yield], timeout=0)
if finished:
cur.pop(0)
yield ray.get(to_yield)
else:
break
yield from ray.get(cur)
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME): if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
unwrapped = apply_foreach unwrapped = apply_foreach