[Data][Split] stable version of split with hints (#26778)

Why are these changes needed?
Introduce a stable version of split with locality hints, backed by a stable equalizing algorithm:

1. Use the existing greedy algorithm to generate the initial, unbalanced splits.
2. For each split, first shave it so that its number of rows is at or below the target size.
3. Based on how many rows each shaved split still needs, do a one-time split_at_indices over the leftover blocks.
4. Merge the shaved splits with the leftover splits.

This algorithm guarantees that at most O(number of splits) blocks need to be split; the idea is sketched below.
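The idea is easiest to see with plain Python lists standing in for blocks. The sketch below is only an illustration of the algorithm (the `equalize_sketch` helper is made up for this description and is not part of the change); it walks through the phases and reproduces one of the unit-test cases added in this PR:

from typing import List


def equalize_sketch(splits: List[List[List[int]]]) -> List[List[List[int]]]:
    # Each inner list stands in for a block; rows are just ints.
    if not splits:
        return splits
    total_rows = sum(len(block) for split in splits for block in split)
    target = total_rows // len(splits)

    shaved, needed, leftovers = [], [], []
    # Phase 1: shave each split down to at most `target` rows; blocks that
    # would push a split over the target become leftovers.
    for split in splits:
        kept, rows = [], 0
        for block in split:
            if rows + len(block) <= target:
                kept.append(block)
                rows += len(block)
            else:
                leftovers.append(block)
        shaved.append(kept)
        needed.append(target - rows)

    # Phases 2 and 3: carve the leftover rows at exactly the offsets each
    # shaved split still needs, then merge them back in. Only the leftover
    # blocks that straddle an offset would actually be split in the real
    # implementation.
    flat = [row for block in leftovers for row in block]
    offset = 0
    for i, num_rows_needed in enumerate(needed):
        if num_rows_needed:
            shaved[i].append(flat[offset : offset + num_rows_needed])
            offset += num_rows_needed
    return shaved


print(equalize_sketch([[[1], [2, 3]], [[4]]]))  # [[[1], [2]], [[4], [3]]]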
Chen Shen 2022-07-23 22:13:11 -07:00 committed by GitHub
parent 37f4692aa8
commit aaab4abad5
3 changed files with 322 additions and 207 deletions

python/ray/data/_internal/equalize.py (new file)

@@ -0,0 +1,158 @@
from typing import Tuple, List
from ray.data._internal.block_list import BlockList
from ray.data._internal.split import _split_at_indices, _calculate_blocks_rows
from ray.data.block import (
Block,
BlockPartition,
BlockMetadata,
)
from ray.types import ObjectRef
def _equalize(
per_split_block_lists: List[BlockList],
owned_by_consumer: bool,
) -> List[BlockList]:
"""Equalize split block lists into equal number of rows.
Args:
per_split_block_lists: block lists to equalize.
Returns:
the equalized block lists.
"""
if len(per_split_block_lists) == 0:
return per_split_block_lists
per_split_blocks_with_metadata = [
block_list.get_blocks_with_metadata() for block_list in per_split_block_lists
]
per_split_num_rows: List[List[int]] = [
_calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
]
total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
target_split_size = total_rows // len(per_split_blocks_with_metadata)
# phase 1: shave the current splits by dropping blocks (into leftovers)
# and calculate the num of rows needed to meet the target.
shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
per_split_blocks_with_metadata, per_split_num_rows, target_split_size
)
# validate invariants
for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
assert num_shaved_rows <= target_split_size
assert num_shaved_rows + split_needed_row == target_split_size
# phase 2: based on the num of rows needed for each shaved split, split the
# leftovers into shapes that exactly match the rows needed.
leftover_refs = []
leftover_meta = []
for (ref, meta) in leftovers:
leftover_refs.append(ref)
leftover_meta.append(meta)
leftover_splits = _split_leftovers(
BlockList(leftover_refs, leftover_meta, owned_by_consumer=owned_by_consumer),
per_split_needed_rows,
)
# phase 3: merge the shaved splits with the leftover splits and return.
for i, leftover_split in enumerate(leftover_splits):
shaved_splits[i].extend(leftover_split)
# validate invariants.
num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
assert num_shaved_rows == target_split_size
# Compose the result back to blocklists
equalized_block_lists: List[BlockList] = []
for split in shaved_splits:
block_refs: List[ObjectRef[Block]] = []
meta: List[BlockMetadata] = []
for (block_ref, m) in split:
block_refs.append(block_ref)
meta.append(m)
equalized_block_lists.append(
BlockList(block_refs, meta, owned_by_consumer=owned_by_consumer)
)
return equalized_block_lists
def _shave_one_split(
split: BlockPartition, num_rows_per_block: List[int], target_size: int
) -> Tuple[BlockPartition, int, BlockPartition]:
"""Shave a block list to the target size.
Args:
split: the block list to shave.
num_rows_per_block: num rows for each block in the list.
target_size: the upper bound target size of the shaved list.
Returns:
A tuple of:
- shaved block list.
- num of rows needed for the block list to meet the target size.
- leftover blocks.
"""
# Iterate through the blocks in the input split: keep blocks while the
# running row count stays within the target size; the rest become leftovers.
shaved = []
leftovers = []
shaved_rows = 0
for block_with_meta, block_rows in zip(split, num_rows_per_block):
if block_rows + shaved_rows <= target_size:
shaved.append(block_with_meta)
shaved_rows += block_rows
else:
leftovers.append(block_with_meta)
num_rows_needed = target_size - shaved_rows
return shaved, num_rows_needed, leftovers
def _shave_all_splits(
input_splits: List[BlockPartition],
per_split_num_rows: List[List[int]],
target_size: int,
) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
"""Shave all block list to the target size.
Args:
input_splits: all block list to shave.
input_splits: num rows (per block) for each block list.
target_size: the upper bound target size of the shaved lists.
Returns:
A tuple of:
- all shaved block list.
- num of rows needed for the block list to meet the target size.
- leftover blocks.
"""
shaved_splits = []
per_split_needed_rows = []
leftovers = []
for split, num_rows_per_block in zip(input_splits, per_split_num_rows):
shaved, num_rows_needed, _leftovers = _shave_one_split(
split, num_rows_per_block, target_size
)
shaved_splits.append(shaved)
per_split_needed_rows.append(num_rows_needed)
leftovers.extend(_leftovers)
return shaved_splits, per_split_needed_rows, leftovers
def _split_leftovers(
leftovers: BlockList, per_split_needed_rows: List[int]
) -> List[BlockPartition]:
"""Split leftover blocks by the num of rows needed."""
num_splits = len(per_split_needed_rows)
split_indices = []
prev = 0
for i, num_rows_needed in enumerate(per_split_needed_rows):
split_indices.append(prev + num_rows_needed)
prev = split_indices[i]
split_result: Tuple[
List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
] = _split_at_indices(leftovers, split_indices)
return [list(zip(block_refs, meta)) for block_refs, meta in zip(*split_result)][
:num_splits
]
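The module above is internal, and the new tests at the bottom of this diff exercise it directly. A condensed version of that usage (the `make_block_list` helper below is just a shorthand for the `_create_blocklist` test helper, and a running Ray instance is assumed) looks roughly like this:

import ray
from ray.data.block import BlockAccessor
from ray.data._internal.block_list import BlockList
from ray.data._internal.equalize import _equalize


def make_block_list(blocks):
    # Put each in-memory block into the object store and record its metadata.
    refs, meta = [], []
    for block in blocks:
        refs.append(ray.put(block))
        meta.append(BlockAccessor.for_block(block).get_metadata(None, None))
    return BlockList(refs, meta, owned_by_consumer=True)


# Two uneven splits: 3 rows vs. 1 row.
splits = [make_block_list([[1], [2, 3]]), make_block_list([[4]])]
equalized = _equalize(splits, owned_by_consumer=True)
for block_list in equalized:
    blocks = [ray.get(ref) for ref, _ in block_list.get_blocks_with_metadata()]
    print(blocks)  # [[1], [2]] then [[4], [3]] -- two rows per split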

python/ray/data/dataset.py

@@ -32,6 +32,7 @@ from ray.data._internal.compute import (
TaskPoolStrategy,
)
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.equalize import _equalize
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.util import _estimate_available_parallelism
@@ -909,186 +910,29 @@ class Dataset(Generic[T]):
)
blocks = self._plan.execute()
stats = self._plan.stats()
def _partition_splits(
splits: List[Dataset[T]], part_size: int, counts_cache: Dict[str, int]
):
"""Partition splits into two sets: splits that are smaller than the
target size and splits that are larger than the target size.
"""
splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()])
idx = next(
i
for i, split in enumerate(splits)
if counts_cache[split._get_uuid()] >= part_size
)
return splits[:idx], splits[idx:]
def _equalize_larger_splits(
splits: List[Dataset[T]],
target_size: int,
counts_cache: Dict[str, int],
num_splits_required: int,
):
"""Split each split into one or more subsplits that are each the
target size, with at most one leftover split that's smaller
than the target size.
This assume that the given splits are sorted in ascending order.
"""
if target_size == 0:
return splits, []
new_splits = []
leftovers = []
for split in splits:
size = counts_cache[split._get_uuid()]
if size == target_size:
new_splits.append(split)
continue
split_indices = list(range(target_size, size, target_size))
split_splits = split.split_at_indices(split_indices)
last_split_size = split_splits[-1].count()
if last_split_size < target_size:
# Last split is smaller than the target size, save it for
# our unioning of small splits.
leftover = split_splits.pop()
leftovers.append(leftover)
counts_cache[leftover._get_uuid()] = leftover.count()
if len(new_splits) + len(split_splits) >= num_splits_required:
# Short-circuit if the new splits will make us reach the
# desired number of splits.
new_splits.extend(
split_splits[: num_splits_required - len(new_splits)]
)
break
new_splits.extend(split_splits)
return new_splits, leftovers
def _equalize_smaller_splits(
splits: List[Dataset[T]],
target_size: int,
counts_cache: Dict[str, int],
num_splits_required: int,
):
"""Union small splits up to the target split size.
This assume that the given splits are sorted in ascending order.
"""
new_splits = []
union_buffer = []
union_buffer_size = 0
low = 0
high = len(splits) - 1
while low <= high:
# Union small splits up to the target split size.
low_split = splits[low]
low_count = counts_cache[low_split._get_uuid()]
high_split = splits[high]
high_count = counts_cache[high_split._get_uuid()]
if union_buffer_size + high_count <= target_size:
# Try to add the larger split to the union buffer first.
union_buffer.append(high_split)
union_buffer_size += high_count
high -= 1
elif union_buffer_size + low_count <= target_size:
union_buffer.append(low_split)
union_buffer_size += low_count
low += 1
else:
# Neither the larger nor smaller split fit in the union
# buffer, so we split the smaller split into a subsplit
# that will fit into the union buffer and a leftover
# subsplit that we add back into the candidate split list.
diff = target_size - union_buffer_size
diff_split, new_low_split = low_split.split_at_indices([diff])
union_buffer.append(diff_split)
union_buffer_size += diff
# We overwrite the old low split and don't advance the low
# pointer since (1) the old low split can be discarded,
# (2) the leftover subsplit is guaranteed to be smaller
# than the old low split, and (3) the low split should be
# the smallest split in the candidate split list, which is
# this subsplit.
splits[low] = new_low_split
counts_cache[new_low_split._get_uuid()] = low_count - diff
if union_buffer_size == target_size:
# Once the union buffer is full, we union together the
# splits.
assert len(union_buffer) > 1, union_buffer
first_ds = union_buffer[0]
new_split = first_ds.union(*union_buffer[1:])
new_splits.append(new_split)
# Clear the union buffer.
union_buffer = []
union_buffer_size = 0
if len(new_splits) == num_splits_required:
# Short-circuit if we've reached the desired number of
# splits.
break
return new_splits
def equalize(splits: List[Dataset[T]], num_splits: int) -> List[Dataset[T]]:
if not equal:
return splits
counts = {s._get_uuid(): s.count() for s in splits}
total_rows = sum(counts.values())
# Number of rows for each split.
target_size = total_rows // num_splits
# Partition splits.
smaller_splits, larger_splits = _partition_splits(
splits, target_size, counts
)
if len(smaller_splits) == 0 and num_splits < len(splits):
# All splits are already equal.
return splits
# Split larger splits.
new_splits, leftovers = _equalize_larger_splits(
larger_splits, target_size, counts, num_splits
)
# Short-circuit if we've already reached the desired number of
# splits.
if len(new_splits) == num_splits:
return new_splits
# Add leftovers to small splits and re-sort.
smaller_splits += leftovers
smaller_splits = sorted(smaller_splits, key=lambda s: counts[s._get_uuid()])
# Union smaller splits.
new_splits_small = _equalize_smaller_splits(
smaller_splits, target_size, counts, num_splits - len(new_splits)
)
new_splits.extend(new_splits_small)
return new_splits
-block_refs, metadata = zip(*blocks.get_blocks_with_metadata())
-metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}
+owned_by_consumer = blocks._owned_by_consumer
+stats = self._plan.stats()
+block_refs, metadata = zip(*blocks.get_blocks_with_metadata())
if locality_hints is None:
-ds = equalize(
-[
-Dataset(
-ExecutionPlan(
-BlockList(
-list(blocks),
-[metadata_mapping[b] for b in blocks],
-owned_by_consumer=owned_by_consumer,
-),
-stats,
-run_by_consumer=owned_by_consumer,
-),
-self._epoch,
-self._lazy,
-)
-for blocks in np.array_split(block_refs, n)
-],
-n,
-)
-assert len(ds) == n, (ds, n)
-return ds
+blocks = np.array_split(block_refs, n)
+meta = np.array_split(metadata, n)
+return [
+Dataset(
+ExecutionPlan(
+BlockList(
+b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer
+),
+stats,
+run_by_consumer=owned_by_consumer,
+),
+self._epoch,
+self._lazy,
+)
+for b, m in zip(blocks, meta)
+]
+metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}
# If the locality_hints is set, we use a two-round greedy algorithm
# to co-locate the blocks with the actors based on block
@@ -1182,25 +1026,31 @@ class Dataset(Generic[T]):
assert len(remaining_block_refs) == 0, len(remaining_block_refs)
return equalize(
[
Dataset(
ExecutionPlan(
BlockList(
allocation_per_actor[actor],
[metadata_mapping[b] for b in allocation_per_actor[actor]],
owned_by_consumer=owned_by_consumer,
),
stats,
run_by_consumer=owned_by_consumer,
),
self._epoch,
self._lazy,
)
for actor in locality_hints
],
n,
)
per_split_block_lists = [
BlockList(
allocation_per_actor[actor],
[metadata_mapping[b] for b in allocation_per_actor[actor]],
owned_by_consumer=owned_by_consumer,
)
for actor in locality_hints
]
if equal:
# equalize the splits
per_split_block_lists = _equalize(per_split_block_lists, owned_by_consumer)
return [
Dataset(
ExecutionPlan(
block_split,
stats,
run_by_consumer=owned_by_consumer,
),
self._epoch,
self._lazy,
)
for block_split in per_split_block_lists
]
def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
"""Split the dataset at the given indices (like np.split).

python/ray/data/tests/test_split.py

@@ -10,6 +10,9 @@ from ray.data.block import BlockMetadata
import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.equalize import (
_equalize,
)
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats
from ray.data._internal.split import (
@@ -95,9 +98,9 @@ def _test_equal_split_balanced(block_sizes, num_splits):
blocks.append(ray.put(block))
metadata.append(BlockAccessor.for_block(block).get_metadata(None, None))
total_rows += block_size
-block_list = BlockList(blocks, metadata)
+block_list = BlockList(blocks, metadata, owned_by_consumer=True)
ds = Dataset(
-ExecutionPlan(block_list, DatasetStats.TODO()),
+ExecutionPlan(block_list, DatasetStats.TODO(), run_by_consumer=True),
0,
False,
)
@@ -507,6 +510,16 @@ def _create_block(data):
return (ray.put(data), _create_meta(len(data)))
def _create_blocklist(blocks):
block_refs = []
meta = []
for block in blocks:
block_ref, block_meta = _create_block(block)
block_refs.append(block_ref)
meta.append(block_meta)
return BlockList(block_refs, meta, owned_by_consumer=True)
def test_split_single_block(ray_start_regular_shared):
block = [1, 2, 3]
meta = _create_meta(3)
@@ -586,29 +599,123 @@ def test_generate_global_split_results(ray_start_regular_shared):
def test_private_split_at_indices(ray_start_regular_shared):
-inputs = []
-splits = list(zip(*_split_at_indices(iter(inputs), [0])))
+inputs = _create_blocklist([])
+splits = list(zip(*_split_at_indices(inputs, [0])))
verify_splits(splits, [[], []])
-splits = list(zip(*_split_at_indices(iter(inputs), [])))
+splits = list(zip(*_split_at_indices(inputs, [])))
verify_splits(splits, [[]])
-inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]
+inputs = _create_blocklist([[1], [2, 3], [4]])
-splits = list(zip(*_split_at_indices(iter(inputs), [1])))
+splits = list(zip(*_split_at_indices(inputs, [1])))
verify_splits(splits, [[[1]], [[2, 3], [4]]])
-splits = list(zip(*_split_at_indices(iter(inputs), [2])))
+inputs = _create_blocklist([[1], [2, 3], [4]])
+splits = list(zip(*_split_at_indices(inputs, [2])))
verify_splits(splits, [[[1], [2]], [[3], [4]]])
-splits = list(zip(*_split_at_indices(iter(inputs), [1])))
+inputs = _create_blocklist([[1], [2, 3], [4]])
+splits = list(zip(*_split_at_indices(inputs, [1])))
verify_splits(splits, [[[1]], [[2, 3], [4]]])
-splits = list(zip(*_split_at_indices(iter(inputs), [2, 2])))
+inputs = _create_blocklist([[1], [2, 3], [4]])
+splits = list(zip(*_split_at_indices(inputs, [2, 2])))
verify_splits(splits, [[[1], [2]], [], [[3], [4]]])
-splits = list(zip(*_split_at_indices(iter(inputs), [])))
+inputs = _create_blocklist([[1], [2, 3], [4]])
+splits = list(zip(*_split_at_indices(inputs, [])))
verify_splits(splits, [[[1], [2, 3], [4]]])
-splits = list(zip(*_split_at_indices(iter(inputs), [0, 4])))
+inputs = _create_blocklist([[1], [2, 3], [4]])
+splits = list(zip(*_split_at_indices(inputs, [0, 4])))
verify_splits(splits, [[], [[1], [2, 3], [4]], []])
def equalize_helper(input_block_lists):
result = _equalize(
[_create_blocklist(block_list) for block_list in input_block_lists],
owned_by_consumer=True,
)
result_block_lists = []
for blocklist in result:
block_list = []
for block_ref, _ in blocklist.get_blocks_with_metadata():
block = ray.get(block_ref)
block_accessor = BlockAccessor.for_block(block)
block_list.append(block_accessor.to_native())
result_block_lists.append(block_list)
return result_block_lists
def verify_equalize_result(input_block_lists, expected_block_lists):
result_block_lists = equalize_helper(input_block_lists)
assert result_block_lists == expected_block_lists
def test_equalize(ray_start_regular_shared):
verify_equalize_result([], [])
verify_equalize_result([[]], [[]])
verify_equalize_result([[[1]], []], [[], []])
verify_equalize_result([[[1], [2, 3]], [[4]]], [[[1], [2]], [[4], [3]]])
verify_equalize_result([[[1], [2, 3]], []], [[[1]], [[2]]])
verify_equalize_result(
[[[1], [2, 3], [4, 5]], [[6]], []], [[[1], [2]], [[6], [3]], [[4, 5]]]
)
verify_equalize_result(
[[[1, 2, 3], [4, 5]], [[6]], []], [[[4, 5]], [[6], [1]], [[2, 3]]]
)
def test_equalize_randomized(ray_start_regular_shared):
# verify that the entries in the splits are unique, fall in the range
# [0, num_rows), and that the total count equals num_rows if exact_num == True.
def assert_unique_and_inrange(splits, num_rows, exact_num=False):
unique_set = set([])
for split in splits:
for block in split:
for entry in block:
assert entry not in unique_set
assert entry >= 0 and entry < num_rows
unique_set.add(entry)
if exact_num:
assert len(unique_set) == num_rows
# verify that splits are equalized.
def assert_equal_split(splits, num_rows, num_split):
split_size = num_rows // num_split
for split in splits:
assert len((list(itertools.chain.from_iterable(split)))) == split_size
# create randomized splits containing entries from 0 (inclusive) to num_rows (exclusive).
def random_split(num_rows, num_split):
split_point = [int(random.random() * num_rows) for _ in range(num_split - 1)]
split_index_helper = [0] + sorted(split_point) + [num_rows]
splits = []
for i in range(1, len(split_index_helper)):
split_start = split_index_helper[i - 1]
split_end = split_index_helper[i]
num_entries = split_end - split_start
split = []
num_block_split = int(random.random() * num_entries)
block_split_point = [
split_start + int(random.random() * num_entries)
for _ in range(num_block_split)
]
block_index_helper = [split_start] + sorted(block_split_point) + [split_end]
for j in range(1, len(block_index_helper)):
split.append(
list(range(block_index_helper[j - 1], block_index_helper[j]))
)
splits.append(split)
assert_unique_and_inrange(splits, num_rows, exact_num=True)
return splits
for i in range(100):
num_rows = int(random.random() * 100)
num_split = int(random.random() * 10) + 1
input_splits = random_split(num_rows, num_split)
print(input_splits)
equalized_splits = equalize_helper(input_splits)
assert_unique_and_inrange(equalized_splits, num_rows)
assert_equal_split(equalized_splits, num_rows, num_split)
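
End to end, the locality-hints path changed here backs the public Dataset.split API. A minimal usage sketch, assuming a running Ray cluster and a hypothetical `Worker` actor class that only serves to supply the locality hints:

import ray

ray.init()


@ray.remote
class Worker:
    # Placeholder actor; split() only uses the actors' locations as hints.
    def ready(self) -> bool:
        return True


workers = [Worker.remote() for _ in range(3)]
ds = ray.data.range(10)

# One shard per actor, co-located with the actors where possible; with
# equal=True the shards are then equalized (10 rows -> 3 rows per shard,
# the remainder is dropped).
shards = ds.split(3, equal=True, locality_hints=workers)
print([s.count() for s in shards])  # [3, 3, 3]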