mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ParallelIterator] Fix for_each concurrent test cases/bugs (#8964)
* Everything works
* Update python/ray/util/iter.py
  Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
* .
* .
* removed print statements
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
parent
b88059326d
commit
40c15b1ba0
3 changed files with 56 additions and 25 deletions
|
@ -248,7 +248,7 @@ class Semaphore:
|
|||
self._sema = asyncio.Semaphore(value=value)
|
||||
|
||||
async def acquire(self):
|
||||
self._sema.acquire()
|
||||
await self._sema.acquire()
|
||||
|
||||
async def release(self):
|
||||
self._sema.release()
|
||||
|
|
|
@ -159,7 +159,7 @@ def test_for_each(ray_start_regular_shared):
|
|||
assert list(it.gather_sync()) == [0, 4, 2, 6]
|
||||
|
||||
|
||||
def test_for_each_concur(ray_start_regular_shared):
|
||||
def test_for_each_concur_async(ray_start_regular_shared):
|
||||
main_wait = Semaphore.remote(value=0)
|
||||
test_wait = Semaphore.remote(value=0)
|
||||
|
||||
|
@ -169,15 +169,18 @@ def test_for_each_concur(ray_start_regular_shared):
|
|||
ray.get(test_wait.acquire.remote())
|
||||
return i + 10
|
||||
|
||||
@ray.remote(num_cpus=0.1)
|
||||
@ray.remote(num_cpus=0.01)
|
||||
def to_list(it):
|
||||
return list(it)
|
||||
|
||||
it = from_items(
|
||||
[(i, main_wait, test_wait) for i in range(8)], num_shards=2)
|
||||
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1})
|
||||
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})
|
||||
|
||||
list_promise = to_list.remote(it.gather_async())
|
||||
|
||||
for i in range(4):
|
||||
assert i in [0, 1, 2, 3]
|
||||
ray.get(main_wait.acquire.remote())
|
||||
|
||||
# There should be exactly 4 tasks executing at this point.
|
||||
|
@ -189,12 +192,49 @@ def test_for_each_concur(ray_start_regular_shared):
|
|||
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
|
||||
|
||||
# Finish everything and make sure the output matches a regular iterator.
|
||||
for i in range(3):
|
||||
for i in range(7):
|
||||
ray.get(test_wait.release.remote())
|
||||
|
||||
assert repr(
|
||||
it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
|
||||
assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18))
|
||||
result_list = ray.get(list_promise)
|
||||
assert set(result_list) == set(range(10, 18))
|
||||
|
||||
|
||||
def test_for_each_concur_sync(ray_start_regular_shared):
    """Check that concurrent `for_each` caps parallelism under `gather_sync`.

    Eight items are split over 2 shards and mapped with
    ``max_concurrency=2``, so 4 tasks are expected to be in flight at once.
    Two semaphore actors coordinate the driver and the tasks:

    * ``main_wait`` is released by every task as soon as it starts, letting
      the driver count how many tasks have begun.
    * ``test_wait`` is acquired by every task before it returns, so tasks
      stay blocked until the driver releases them.
    """
    main_wait = Semaphore.remote(value=0)
    test_wait = Semaphore.remote(value=0)

    def task(x):
        i, main_wait, test_wait = x
        # Announce that this task has started, then park until the driver
        # hands out a permit.
        ray.get(main_wait.release.remote())
        ray.get(test_wait.acquire.remote())
        return i + 10

    @ray.remote(num_cpus=0.01)
    def to_list(it):
        return list(it)

    it = from_items(
        [(i, main_wait, test_wait) for i in range(8)], num_shards=2)
    it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.01})

    # Consume the iterator remotely so the driver stays free to drive the
    # semaphores while the tasks are blocked.
    list_promise = to_list.remote(it.gather_sync())

    # Wait for four tasks to report that they have started.
    for _ in range(4):
        ray.get(main_wait.acquire.remote())

    # There should be exactly 4 tasks executing at this point: a fifth
    # started task would have released `main_wait` again, leaving it
    # unlocked.
    assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"

    # Unblock all eight tasks so the iterator can drain.
    for _ in range(8):
        ray.get(test_wait.release.remote())

    assert repr(
        it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"

    # Output order is not guaranteed with max_concurrency != 1, so compare
    # order-insensitively; sorting (rather than set()) also catches
    # duplicated results.
    result_list = ray.get(list_promise)
    assert sorted(result_list) == list(range(10, 18))
|
||||
|
||||
|
||||
def test_combine(ray_start_regular_shared):
|
||||
|
|
|
@ -203,9 +203,9 @@ class ParallelIterator(Generic[T]):
|
|||
|
||||
`max_concurrency` should be used to achieve a high degree of
|
||||
parallelism without the overhead of increasing the number of shards
|
||||
(which are actor based). This provides the semantic guarantee that
|
||||
`fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not
|
||||
necessarily finish first)
|
||||
(which are actor based). If `max_concurrency` is not 1, this function
|
||||
provides no semantic guarantees on the output order.
|
||||
Results will be returned as soon as they are ready.
|
||||
|
||||
A performance note: When executing concurrently, this function
|
||||
maintains its own internal buffer. If `num_async` is `n` and
|
||||
|
@ -234,6 +234,7 @@ class ParallelIterator(Generic[T]):
|
|||
... [0, 2, 4, 8]
|
||||
|
||||
"""
|
||||
assert max_concurrency >= 0, "max_concurrency must be non-negative."
|
||||
return self._with_transform(
|
||||
lambda local_it: local_it.for_each(fn, max_concurrency, resources),
|
||||
".for_each()")
|
||||
|
@ -765,23 +766,13 @@ class LocalIterator(Generic[T]):
|
|||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
finished, remaining = ray.wait(cur, timeout=0)
|
||||
if max_concurrency and len(
|
||||
remaining) >= max_concurrency:
|
||||
ray.wait(cur, num_returns=(len(finished) + 1))
|
||||
if max_concurrency and len(cur) >= max_concurrency:
|
||||
finished, cur = ray.wait(cur)
|
||||
yield from ray.get(finished)
|
||||
cur.append(remote_fn(item))
|
||||
|
||||
while len(cur) > 0:
|
||||
to_yield = cur[0]
|
||||
finished, remaining = ray.wait(
|
||||
[to_yield], timeout=0)
|
||||
if finished:
|
||||
cur.pop(0)
|
||||
yield ray.get(to_yield)
|
||||
else:
|
||||
break
|
||||
|
||||
yield from ray.get(cur)
|
||||
while cur:
|
||||
finished, cur = ray.wait(cur)
|
||||
yield from ray.get(finished)
|
||||
|
||||
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
|
||||
unwrapped = apply_foreach
|
||||
|
|
Loading…
Add table
Reference in a new issue