mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Parallel Iterators] Repartition functionality (#7163)
* repartition and tests
* blacklist lib/ files from import checks
* addressing comments and splitting up tests
* code readability
* adding explicit ref for parent iterator
* formatting
This commit is contained in:
parent
c6f50ecc51
commit
1737a113be
3 changed files with 140 additions and 5 deletions
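
For orientation, a minimal usage sketch of the feature this commit adds, based on the `repartition` docstring example further down in the diff. The import path is an assumption (parallel iterators have moved between modules across Ray versions), not something this diff shows:

```python
import ray
# NOTE: assumed import path; use wherever from_range lives in your checkout.
from ray.util.iter import from_range

ray.init()

# 8 items spread across 2 source shards, then repartitioned round-robin into 3.
it = from_range(8, num_shards=2).repartition(3)
assert it.num_shards() == 3
print(list(it.get_shard(0)))  # [0, 4, 3, 7] per the docstring example below
print(list(it.get_shard(1)))  # [1, 5]
print(list(it.get_shard(2)))  # [2, 6]
```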
@@ -144,8 +144,7 @@ fi
 # Ensure import ordering
 # Make sure that for every import psutil; import setpproctitle
 # There's a import ray above it.
-python ci/travis/check_import_order.py . -s ci -s python/ray/pyarrow_files -s python/ray/thirdparty_files -s python/build
+python ci/travis/check_import_order.py . -s ci -s python/ray/pyarrow_files -s python/ray/thirdparty_files -s python/build -s lib

 if ! git diff --quiet &>/dev/null; then
     echo 'Reformatted changed files. Please review and stage the changes.'

@@ -100,6 +100,40 @@ def test_local_shuffle(ray_start_regular_shared):
     assert value / len(freq_counter) > 0.2


+def test_repartition_less(ray_start_regular_shared):
+    it = from_range(9, num_shards=3)
+    it1 = it.repartition(2)
+    assert repr(it1) == ("ParallelIterator[from_range[9, " +
+                         "shards=3].repartition[num_partitions=2]]")
+
+    assert it1.num_shards() == 2
+    shard_0_set = set(it1.get_shard(0))
+    shard_1_set = set(it1.get_shard(1))
+    assert shard_0_set == {0, 2, 3, 5, 6, 8}
+    assert shard_1_set == {1, 4, 7}
+
+
+def test_repartition_more(ray_start_regular_shared):
+    it = from_range(100, 2).repartition(3)
+    assert it.num_shards() == 3
+    assert set(it.get_shard(0)) == set(range(0, 50, 3)) | set(
+        (range(50, 100, 3)))
+    assert set(
+        it.get_shard(1)) == set(range(1, 50, 3)) | set(range(51, 100, 3))
+    assert set(
+        it.get_shard(2)) == set(range(2, 50, 3)) | set(range(52, 100, 3))
+
+
+def test_repartition_consistent(ray_start_regular_shared):
+    # repartition should be deterministic
+    it1 = from_range(9, num_shards=1).repartition(2)
+    it2 = from_range(9, num_shards=1).repartition(2)
+    assert it1.num_shards() == 2
+    assert it2.num_shards() == 2
+    assert set(it1.get_shard(0)) == set(it2.get_shard(0))
+    assert set(it1.get_shard(1)) == set(it2.get_shard(1))
+
+
 def test_batch(ray_start_regular_shared):
     it = from_range(4, 1).batch(2)
     assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"

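A quick aside on why `test_repartition_less` expects exactly those sets: `from_range(9, num_shards=3)` gives each source shard a contiguous range, and `repartition(2)` pulls a strided slice out of every source shard. A Ray-free sketch of that arithmetic (the `slice_shard` helper is mine, for illustration only):

```python
# Source shards of from_range(9, num_shards=3) hold contiguous ranges.
source_shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
num_partitions = 2

def slice_shard(shard, step, start):
    # Mirrors what repeated par_iter_slice(step, start) calls yield per shard.
    return shard[start::step]

shard_0 = {x for s in source_shards for x in slice_shard(s, num_partitions, 0)}
shard_1 = {x for s in source_shards for x in slice_shard(s, num_partitions, 1)}
assert shard_0 == {0, 2, 3, 5, 6, 8}
assert shard_1 == {1, 4, 7}
```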
@@ -1,5 +1,6 @@
-from typing import TypeVar, Generic, Iterable, List, Callable, Any
+import collections
 import random
+from typing import TypeVar, Generic, Iterable, List, Callable, Any

 import ray

@@ -253,8 +254,8 @@ class ParallelIterator(Generic[T]):
                 randomness. Default value is None.

         Returns:
-            Returns a ParallelIterator with a local shuffle applied on the
-            base iterator
+            A ParallelIterator with a local shuffle applied on the base
+            iterator

         Examples:
             >>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)

@@ -279,6 +280,83 @@ class ParallelIterator(Generic[T]):
                 shuffle_buffer_size,
                 str(seed) if seed is not None else "None"))

+    def repartition(self, num_partitions: int) -> "ParallelIterator[T]":
+        """Returns a new ParallelIterator instance with num_partitions shards.
+
+        The new iterator contains the same data in this instance except with
+        num_partitions shards. The data is split in round-robin fashion for
+        the new ParallelIterator.
+
+        Args:
+            num_partitions (int): The number of shards to use for the new
+                ParallelIterator
+
+        Returns:
+            A ParallelIterator with num_partitions number of shards and the
+            data of this ParallelIterator split round-robin among the new
+            number of shards.
+
+        Examples:
+            >>> it = from_range(8, 2)
+            >>> it = it.repartition(3)
+            >>> list(it.get_shard(0))
+            [0, 4, 3, 7]
+            >>> list(it.get_shard(1))
+            [1, 5]
+            >>> list(it.get_shard(2))
+            [2, 6]
+        """
+
+        # initialize the local iterators for all the actors
+        all_actors = []
+        for actor_set in self.actor_sets:
+            actor_set.init_actors()
+            all_actors.extend(actor_set.actors)
+
+        def base_iterator(num_partitions, partition_index, timeout=None):
+            futures = {}
+            for a in all_actors:
+                futures[a.par_iter_slice.remote(
+                    step=num_partitions, start=partition_index)] = a
+            while futures:
+                pending = list(futures)
+                if timeout is None:
+                    # First try to do a batch wait for efficiency.
+                    ready, _ = ray.wait(
+                        pending, num_returns=len(pending), timeout=0)
+                    # Fall back to a blocking wait.
+                    if not ready:
+                        ready, _ = ray.wait(pending, num_returns=1)
+                else:
+                    ready, _ = ray.wait(
+                        pending, num_returns=len(pending), timeout=timeout)
+                for obj_id in ready:
+                    actor = futures.pop(obj_id)
+                    try:
+                        yield ray.get(obj_id)
+                        futures[actor.par_iter_slice.remote(
+                            step=num_partitions,
+                            start=partition_index)] = actor
+                    except StopIteration:
+                        pass
+                # Always yield after each round of wait with timeout.
+                if timeout is not None:
+                    yield _NextValueNotReady()
+
+        def make_gen_i(i):
+            return lambda: base_iterator(num_partitions, i)
+
+        name = self.name + ".repartition[num_partitions={}]".format(
+            num_partitions)
+
+        generators = [make_gen_i(s) for s in range(num_partitions)]
+        worker_cls = ray.remote(ParallelIteratorWorker)
+        actors = [worker_cls.remote(g, repeat=False) for g in generators]
+        x = ParallelIterator([_ActorSet(actors, [])], name)
+        # need explicit reference to self so actors in this instance do not die
+        x.parent_iterator = self
+        return x
+
     def gather_sync(self) -> "LocalIterator[T]":
         """Returns a local iterable for synchronous iteration.

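One detail worth noting in the hunk above: `make_gen_i` exists so each generator closure captures its partition index by value. A standalone sketch of the late-binding pitfall it avoids (plain Python, no Ray):

```python
# Lambdas created in a loop all close over the same variable and see its
# final value once the loop finishes (late binding):
broken = [lambda: i for i in range(3)]
assert [f() for f in broken] == [2, 2, 2]

# Wrapping the lambda in a factory, as make_gen_i does for
# base_iterator(num_partitions, i), freezes the index per closure:
def make_gen(i):
    return lambda: i

fixed = [make_gen(i) for i in range(3)]
assert [f() for f in fixed] == [0, 1, 2]
```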
@@ -710,6 +788,7 @@ class ParallelIteratorWorker(object):

         self.transforms = []
         self.local_it = None
+        self.next_ith_buffer = None

     def par_iter_init(self, transforms):
         """Implements ParallelIterator worker init."""

@@ -724,6 +803,29 @@ class ParallelIteratorWorker(object):
         assert self.local_it is not None, "must call par_iter_init()"
         return next(self.local_it)

+    def par_iter_slice(self, step: int, start: int):
+        """Iterates in increments of step starting from start."""
+        assert self.local_it is not None, "must call par_iter_init()"
+
+        if self.next_ith_buffer is None:
+            self.next_ith_buffer = collections.defaultdict(list)
+
+        index_buffer = self.next_ith_buffer[start]
+        if len(index_buffer) > 0:
+            return index_buffer.pop(0)
+        else:
+            for j in range(step):
+                try:
+                    val = next(self.local_it)
+                    self.next_ith_buffer[j].append(val)
+                except StopIteration:
+                    pass
+
+            if not self.next_ith_buffer[start]:
+                raise StopIteration
+
+            return self.next_ith_buffer[start].pop(0)
+

 class _NextValueNotReady(Exception):
     """Indicates that a local iterator has no value currently available.
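
To see how `par_iter_slice` lets several repartitioned shards share one local iterator, here is a Ray-free sketch of the same buffering logic (the `SliceBuffer` class and the driver lines are mine, for illustration only):

```python
import collections

class SliceBuffer:
    """Buffers items by offset so callers can read strided slices."""

    def __init__(self, it):
        self.local_it = iter(it)
        self.next_ith_buffer = collections.defaultdict(list)

    def par_iter_slice(self, step, start):
        buf = self.next_ith_buffer[start]
        if buf:
            return buf.pop(0)
        # Pull one round of `step` items and bucket them by offset, so callers
        # asking for different `start` values each see a strided slice of the
        # same underlying iterator.
        for j in range(step):
            try:
                self.next_ith_buffer[j].append(next(self.local_it))
            except StopIteration:
                pass
        if not self.next_ith_buffer[start]:
            raise StopIteration
        return self.next_ith_buffer[start].pop(0)

w = SliceBuffer(range(5))
evens = [w.par_iter_slice(step=2, start=0) for _ in range(3)]
odds = [w.par_iter_slice(step=2, start=1) for _ in range(2)]
assert evens == [0, 2, 4]
assert odds == [1, 3]
```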