[Parallel Iterators] Repartition functionality (#7163)

* repartition and tests

* blacklist lib/ files from import checks

* addressing comments and splitting up tests

* code readability

* adding explicit ref for parent iterator

* formatting
This commit is contained in:
Amog Kamsetty 2020-02-21 13:20:18 -08:00 committed by GitHub
parent c6f50ecc51
commit 1737a113be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 5 deletions

View file

@ -144,8 +144,7 @@ fi
# Ensure import ordering
# Make sure that for every `import psutil` or `import setproctitle`
# there is an `import ray` above it.
python ci/travis/check_import_order.py . -s ci -s python/ray/pyarrow_files -s python/ray/thirdparty_files -s python/build
python ci/travis/check_import_order.py . -s ci -s python/ray/pyarrow_files -s python/ray/thirdparty_files -s python/build -s lib
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted changed files. Please review and stage the changes.'

View file

@ -100,6 +100,40 @@ def test_local_shuffle(ray_start_regular_shared):
assert value / len(freq_counter) > 0.2
def test_repartition_less(ray_start_regular_shared):
    """Repartition a 3-shard iterator over range(9) down to 2 shards.

    Checks the repr of the new iterator, its shard count, and the set of
    items that lands in each of the two new shards.
    """
    source = from_range(9, num_shards=3)
    shrunk = source.repartition(2)

    expected_repr = ("ParallelIterator[from_range[9, " +
                     "shards=3].repartition[num_partitions=2]]")
    assert repr(shrunk) == expected_repr
    assert shrunk.num_shards() == 2

    # Drain both shards before asserting; only membership is compared,
    # not the order in which items arrive.
    first_items, second_items = (set(shrunk.get_shard(i)) for i in range(2))
    assert first_items == {0, 2, 3, 5, 6, 8}
    assert second_items == {1, 4, 7}
def test_repartition_more(ray_start_regular_shared):
    """Repartition a 2-shard iterator over range(100) up to 3 shards.

    New shard ``k`` is expected to hold every third item starting at
    offset ``k`` from each half of the range (0-49 and 50-99).
    """
    grown = from_range(100, 2).repartition(3)
    assert grown.num_shards() == 3

    for shard_index in range(3):
        actual = set(grown.get_shard(shard_index))
        from_first_half = set(range(shard_index, 50, 3))
        from_second_half = set(range(50 + shard_index, 100, 3))
        assert actual == from_first_half | from_second_half
def test_repartition_consistent(ray_start_regular_shared):
    """Two identical repartition calls must produce identical shards."""
    # Build the same pipeline twice; repartition should be deterministic.
    first, second = (
        from_range(9, num_shards=1).repartition(2) for _ in range(2))

    for candidate in (first, second):
        assert candidate.num_shards() == 2

    for shard_index in range(2):
        assert set(first.get_shard(shard_index)) == set(
            second.get_shard(shard_index))
def test_batch(ray_start_regular_shared):
it = from_range(4, 1).batch(2)
assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"

View file

@ -1,5 +1,6 @@
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import collections
import random
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import ray
@ -253,8 +254,8 @@ class ParallelIterator(Generic[T]):
randomness. Default value is None.
Returns:
Returns a ParallelIterator with a local shuffle applied on the
base iterator
A ParallelIterator with a local shuffle applied on the base
iterator
Examples:
>>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)
@ -279,6 +280,83 @@ class ParallelIterator(Generic[T]):
shuffle_buffer_size,
str(seed) if seed is not None else "None"))
def repartition(self, num_partitions: int) -> "ParallelIterator[T]":
    """Build a ParallelIterator over the same data with a new shard count.

    Every actor of this iterator is asked for slices of its local
    iterator via ``par_iter_slice(step=num_partitions, start=i)``, so
    the items are dealt out round-robin to the ``num_partitions`` new
    shards.

    Args:
        num_partitions (int): The number of shards of the returned
            ParallelIterator.

    Returns:
        A ParallelIterator with num_partitions shards holding the data
        of this ParallelIterator, split round-robin among the shards.

    Examples:
        >>> it = from_range(8, 2)
        >>> it = it.repartition(3)
        >>> list(it.get_shard(0))
        [0, 4, 3, 7]
        >>> list(it.get_shard(1))
        [1, 5]
        >>> list(it.get_shard(2))
        [2, 6]
    """
    # Initialize the local iterators of every actor and collect the
    # actors of all actor sets into one flat list.
    source_actors = []
    for actor_set in self.actor_sets:
        actor_set.init_actors()
        source_actors.extend(actor_set.actors)

    def base_iterator(num_partitions, partition_index, timeout=None):
        def request_next(actor):
            # Ask one source actor for its next item of this partition.
            return actor.par_iter_slice.remote(
                step=num_partitions, start=partition_index)

        # Maps each outstanding request to the actor it was sent to.
        in_flight = {request_next(actor): actor for actor in source_actors}
        while in_flight:
            waiting_on = list(in_flight)
            if timeout is not None:
                done, _ = ray.wait(
                    waiting_on,
                    num_returns=len(waiting_on),
                    timeout=timeout)
            else:
                # Try a non-blocking batch wait first for efficiency.
                done, _ = ray.wait(
                    waiting_on, num_returns=len(waiting_on), timeout=0)
                # Nothing ready yet: block until one request completes.
                if not done:
                    done, _ = ray.wait(waiting_on, num_returns=1)
            for object_id in done:
                owner = in_flight.pop(object_id)
                try:
                    yield ray.get(object_id)
                    in_flight[request_next(owner)] = owner
                except StopIteration:
                    # This actor has nothing left for the partition, so
                    # no follow-up request is queued for it.
                    pass
            # Always yield after each round of wait with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    def bind_partition(index):
        # Factory function so each lambda captures its own index.
        return lambda: base_iterator(num_partitions, index)

    new_name = self.name + ".repartition[num_partitions={}]".format(
        num_partitions)

    shard_generators = [bind_partition(s) for s in range(num_partitions)]
    remote_worker = ray.remote(ParallelIteratorWorker)
    new_actors = [
        remote_worker.remote(gen, repeat=False) for gen in shard_generators
    ]
    repartitioned = ParallelIterator([_ActorSet(new_actors, [])], new_name)
    # Keep an explicit reference to self so that the actors of this
    # instance do not die while the new iterator is in use.
    repartitioned.parent_iterator = self
    return repartitioned
def gather_sync(self) -> "LocalIterator[T]":
"""Returns a local iterable for synchronous iteration.
@ -710,6 +788,7 @@ class ParallelIteratorWorker(object):
self.transforms = []
self.local_it = None
self.next_ith_buffer = None
def par_iter_init(self, transforms):
"""Implements ParallelIterator worker init."""
@ -724,6 +803,29 @@ class ParallelIteratorWorker(object):
assert self.local_it is not None, "must call par_iter_init()"
return next(self.local_it)
def par_iter_slice(self, step: int, start: int):
    """Return the next item of one round-robin slice of the local iterator.

    Items are pulled from ``self.local_it`` in rounds of up to ``step``
    values; the j-th value of a round is queued for slice ``j`` in
    ``self.next_ith_buffer``. A call for slice ``start`` first serves
    from that queue and only pulls a new round when the queue is empty.

    Args:
        step (int): Number of items pulled per round, i.e. the number
            of slices the local iterator is split into.
        start (int): Index of the slice to read from.

    Returns:
        The oldest queued item of slice ``start``.

    Raises:
        StopIteration: If the queue for ``start`` is empty and pulling
            a new round produced no item for it.
    """
    assert self.local_it is not None, "must call par_iter_init()"

    if self.next_ith_buffer is None:
        # Queues are consumed from the front; deque pops in O(1) where
        # list.pop(0) is O(n).
        self.next_ith_buffer = collections.defaultdict(collections.deque)

    index_buffer = self.next_ith_buffer[start]
    if not index_buffer:
        for j in range(step):
            try:
                val = next(self.local_it)
            except StopIteration:
                # The iterator is exhausted; polling it again for the
                # rest of the round cannot produce more items.
                break
            self.next_ith_buffer[j].append(val)

        if not index_buffer:
            raise StopIteration

    return index_buffer.popleft()
class _NextValueNotReady(Exception):
"""Indicates that a local iterator has no value currently available.