mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Data][Split] stable version of split with hints (#26778)
Why are these changes needed? Introduce a stable version of split with hints with a stable equalizing algorithm: use the greedy algorithm to generate the initial unbalanced splits. for each splits, first shave them so the number for rows are below the target_size based on how many rows needed for each split, do a one time split_at_index to the left over blocks. merge the shaved splits with the leftover splits. The guarantee of this algorithm is we at most need to split O(split) number of blocks.
This commit is contained in:
parent
37f4692aa8
commit
aaab4abad5
3 changed files with 322 additions and 207 deletions
158
python/ray/data/_internal/equalize.py
Normal file
158
python/ray/data/_internal/equalize.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
from typing import Tuple, List
|
||||
from ray.data._internal.block_list import BlockList
|
||||
|
||||
from ray.data._internal.split import _split_at_indices, _calculate_blocks_rows
|
||||
from ray.data.block import (
|
||||
Block,
|
||||
BlockPartition,
|
||||
BlockMetadata,
|
||||
)
|
||||
from ray.types import ObjectRef
|
||||
|
||||
|
||||
def _equalize(
|
||||
per_split_block_lists: List[BlockList],
|
||||
owned_by_consumer: bool,
|
||||
) -> List[BlockList]:
|
||||
"""Equalize split block lists into equal number of rows.
|
||||
|
||||
Args:
|
||||
per_split_block_lists: block lists to equalize.
|
||||
Returns:
|
||||
the equalized block lists.
|
||||
"""
|
||||
if len(per_split_block_lists) == 0:
|
||||
return per_split_block_lists
|
||||
per_split_blocks_with_metadata = [
|
||||
block_list.get_blocks_with_metadata() for block_list in per_split_block_lists
|
||||
]
|
||||
per_split_num_rows: List[List[int]] = [
|
||||
_calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
|
||||
]
|
||||
total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
|
||||
target_split_size = total_rows // len(per_split_blocks_with_metadata)
|
||||
|
||||
# phase 1: shave the current splits by dropping blocks (into leftovers)
|
||||
# and calculate num rows needed to the meet target.
|
||||
shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
|
||||
per_split_blocks_with_metadata, per_split_num_rows, target_split_size
|
||||
)
|
||||
|
||||
# validate invariants
|
||||
for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
|
||||
num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
|
||||
assert num_shaved_rows <= target_split_size
|
||||
assert num_shaved_rows + split_needed_row == target_split_size
|
||||
|
||||
# phase 2: based on the num rows needed for each shaved split, split the leftovers
|
||||
# in the shape that exactly matches the rows needed.
|
||||
leftover_refs = []
|
||||
leftover_meta = []
|
||||
for (ref, meta) in leftovers:
|
||||
leftover_refs.append(ref)
|
||||
leftover_meta.append(meta)
|
||||
leftover_splits = _split_leftovers(
|
||||
BlockList(leftover_refs, leftover_meta, owned_by_consumer=owned_by_consumer),
|
||||
per_split_needed_rows,
|
||||
)
|
||||
|
||||
# phase 3: merge the shaved_splits and leftoever splits and return.
|
||||
for i, leftover_split in enumerate(leftover_splits):
|
||||
shaved_splits[i].extend(leftover_split)
|
||||
|
||||
# validate invariants.
|
||||
num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
|
||||
assert num_shaved_rows == target_split_size
|
||||
|
||||
# Compose the result back to blocklists
|
||||
equalized_block_lists: List[BlockList] = []
|
||||
for split in shaved_splits:
|
||||
block_refs: List[ObjectRef[Block]] = []
|
||||
meta: List[BlockMetadata] = []
|
||||
for (block_ref, m) in split:
|
||||
block_refs.append(block_ref)
|
||||
meta.append(m)
|
||||
equalized_block_lists.append(
|
||||
BlockList(block_refs, meta, owned_by_consumer=owned_by_consumer)
|
||||
)
|
||||
return equalized_block_lists
|
||||
|
||||
|
||||
def _shave_one_split(
|
||||
split: BlockPartition, num_rows_per_block: List[int], target_size: int
|
||||
) -> Tuple[BlockPartition, int, BlockPartition]:
|
||||
"""Shave a block list to the target size.
|
||||
|
||||
Args:
|
||||
split: the block list to shave.
|
||||
num_rows_per_block: num rows for each block in the list.
|
||||
target_size: the upper bound target size of the shaved list.
|
||||
Returns:
|
||||
A tuple of:
|
||||
- shaved block list.
|
||||
- num of rows needed for the block list to meet the target size.
|
||||
- leftover blocks.
|
||||
|
||||
"""
|
||||
# iterates through the blocks from the input list and
|
||||
shaved = []
|
||||
leftovers = []
|
||||
shaved_rows = 0
|
||||
for block_with_meta, block_rows in zip(split, num_rows_per_block):
|
||||
if block_rows + shaved_rows <= target_size:
|
||||
shaved.append(block_with_meta)
|
||||
shaved_rows += block_rows
|
||||
else:
|
||||
leftovers.append(block_with_meta)
|
||||
num_rows_needed = target_size - shaved_rows
|
||||
return shaved, num_rows_needed, leftovers
|
||||
|
||||
|
||||
def _shave_all_splits(
|
||||
input_splits: List[BlockPartition],
|
||||
per_split_num_rows: List[List[int]],
|
||||
target_size: int,
|
||||
) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
|
||||
"""Shave all block list to the target size.
|
||||
|
||||
Args:
|
||||
input_splits: all block list to shave.
|
||||
input_splits: num rows (per block) for each block list.
|
||||
target_size: the upper bound target size of the shaved lists.
|
||||
Returns:
|
||||
A tuple of:
|
||||
- all shaved block list.
|
||||
- num of rows needed for the block list to meet the target size.
|
||||
- leftover blocks.
|
||||
"""
|
||||
shaved_splits = []
|
||||
per_split_needed_rows = []
|
||||
leftovers = []
|
||||
|
||||
for split, num_rows_per_block in zip(input_splits, per_split_num_rows):
|
||||
shaved, num_rows_needed, _leftovers = _shave_one_split(
|
||||
split, num_rows_per_block, target_size
|
||||
)
|
||||
shaved_splits.append(shaved)
|
||||
per_split_needed_rows.append(num_rows_needed)
|
||||
leftovers.extend(_leftovers)
|
||||
|
||||
return shaved_splits, per_split_needed_rows, leftovers
|
||||
|
||||
|
||||
def _split_leftovers(
|
||||
leftovers: BlockList, per_split_needed_rows: List[int]
|
||||
) -> List[BlockPartition]:
|
||||
"""Split leftover blocks by the num of rows needed."""
|
||||
num_splits = len(per_split_needed_rows)
|
||||
split_indices = []
|
||||
prev = 0
|
||||
for i, num_rows_needed in enumerate(per_split_needed_rows):
|
||||
split_indices.append(prev + num_rows_needed)
|
||||
prev = split_indices[i]
|
||||
split_result: Tuple[
|
||||
List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
|
||||
] = _split_at_indices(leftovers, split_indices)
|
||||
return [list(zip(block_refs, meta)) for block_refs, meta in zip(*split_result)][
|
||||
:num_splits
|
||||
]
|
|
@ -32,6 +32,7 @@ from ray.data._internal.compute import (
|
|||
TaskPoolStrategy,
|
||||
)
|
||||
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
|
||||
from ray.data._internal.equalize import _equalize
|
||||
from ray.data._internal.lazy_block_list import LazyBlockList
|
||||
from ray.data._internal.output_buffer import BlockOutputBuffer
|
||||
from ray.data._internal.util import _estimate_available_parallelism
|
||||
|
@ -909,186 +910,29 @@ class Dataset(Generic[T]):
|
|||
)
|
||||
|
||||
blocks = self._plan.execute()
|
||||
stats = self._plan.stats()
|
||||
|
||||
def _partition_splits(
|
||||
splits: List[Dataset[T]], part_size: int, counts_cache: Dict[str, int]
|
||||
):
|
||||
"""Partition splits into two sets: splits that are smaller than the
|
||||
target size and splits that are larger than the target size.
|
||||
"""
|
||||
splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()])
|
||||
idx = next(
|
||||
i
|
||||
for i, split in enumerate(splits)
|
||||
if counts_cache[split._get_uuid()] >= part_size
|
||||
)
|
||||
return splits[:idx], splits[idx:]
|
||||
|
||||
def _equalize_larger_splits(
|
||||
splits: List[Dataset[T]],
|
||||
target_size: int,
|
||||
counts_cache: Dict[str, int],
|
||||
num_splits_required: int,
|
||||
):
|
||||
"""Split each split into one or more subsplits that are each the
|
||||
target size, with at most one leftover split that's smaller
|
||||
than the target size.
|
||||
|
||||
This assume that the given splits are sorted in ascending order.
|
||||
"""
|
||||
if target_size == 0:
|
||||
return splits, []
|
||||
new_splits = []
|
||||
leftovers = []
|
||||
for split in splits:
|
||||
size = counts_cache[split._get_uuid()]
|
||||
if size == target_size:
|
||||
new_splits.append(split)
|
||||
continue
|
||||
split_indices = list(range(target_size, size, target_size))
|
||||
split_splits = split.split_at_indices(split_indices)
|
||||
last_split_size = split_splits[-1].count()
|
||||
if last_split_size < target_size:
|
||||
# Last split is smaller than the target size, save it for
|
||||
# our unioning of small splits.
|
||||
leftover = split_splits.pop()
|
||||
leftovers.append(leftover)
|
||||
counts_cache[leftover._get_uuid()] = leftover.count()
|
||||
if len(new_splits) + len(split_splits) >= num_splits_required:
|
||||
# Short-circuit if the new splits will make us reach the
|
||||
# desired number of splits.
|
||||
new_splits.extend(
|
||||
split_splits[: num_splits_required - len(new_splits)]
|
||||
)
|
||||
break
|
||||
new_splits.extend(split_splits)
|
||||
return new_splits, leftovers
|
||||
|
||||
def _equalize_smaller_splits(
|
||||
splits: List[Dataset[T]],
|
||||
target_size: int,
|
||||
counts_cache: Dict[str, int],
|
||||
num_splits_required: int,
|
||||
):
|
||||
"""Union small splits up to the target split size.
|
||||
|
||||
This assume that the given splits are sorted in ascending order.
|
||||
"""
|
||||
new_splits = []
|
||||
union_buffer = []
|
||||
union_buffer_size = 0
|
||||
low = 0
|
||||
high = len(splits) - 1
|
||||
while low <= high:
|
||||
# Union small splits up to the target split size.
|
||||
low_split = splits[low]
|
||||
low_count = counts_cache[low_split._get_uuid()]
|
||||
high_split = splits[high]
|
||||
high_count = counts_cache[high_split._get_uuid()]
|
||||
if union_buffer_size + high_count <= target_size:
|
||||
# Try to add the larger split to the union buffer first.
|
||||
union_buffer.append(high_split)
|
||||
union_buffer_size += high_count
|
||||
high -= 1
|
||||
elif union_buffer_size + low_count <= target_size:
|
||||
union_buffer.append(low_split)
|
||||
union_buffer_size += low_count
|
||||
low += 1
|
||||
else:
|
||||
# Neither the larger nor smaller split fit in the union
|
||||
# buffer, so we split the smaller split into a subsplit
|
||||
# that will fit into the union buffer and a leftover
|
||||
# subsplit that we add back into the candidate split list.
|
||||
diff = target_size - union_buffer_size
|
||||
diff_split, new_low_split = low_split.split_at_indices([diff])
|
||||
union_buffer.append(diff_split)
|
||||
union_buffer_size += diff
|
||||
# We overwrite the old low split and don't advance the low
|
||||
# pointer since (1) the old low split can be discarded,
|
||||
# (2) the leftover subsplit is guaranteed to be smaller
|
||||
# than the old low split, and (3) the low split should be
|
||||
# the smallest split in the candidate split list, which is
|
||||
# this subsplit.
|
||||
splits[low] = new_low_split
|
||||
counts_cache[new_low_split._get_uuid()] = low_count - diff
|
||||
if union_buffer_size == target_size:
|
||||
# Once the union buffer is full, we union together the
|
||||
# splits.
|
||||
assert len(union_buffer) > 1, union_buffer
|
||||
first_ds = union_buffer[0]
|
||||
new_split = first_ds.union(*union_buffer[1:])
|
||||
new_splits.append(new_split)
|
||||
# Clear the union buffer.
|
||||
union_buffer = []
|
||||
union_buffer_size = 0
|
||||
if len(new_splits) == num_splits_required:
|
||||
# Short-circuit if we've reached the desired number of
|
||||
# splits.
|
||||
break
|
||||
return new_splits
|
||||
|
||||
def equalize(splits: List[Dataset[T]], num_splits: int) -> List[Dataset[T]]:
|
||||
if not equal:
|
||||
return splits
|
||||
counts = {s._get_uuid(): s.count() for s in splits}
|
||||
total_rows = sum(counts.values())
|
||||
# Number of rows for each split.
|
||||
target_size = total_rows // num_splits
|
||||
|
||||
# Partition splits.
|
||||
smaller_splits, larger_splits = _partition_splits(
|
||||
splits, target_size, counts
|
||||
)
|
||||
if len(smaller_splits) == 0 and num_splits < len(splits):
|
||||
# All splits are already equal.
|
||||
return splits
|
||||
|
||||
# Split larger splits.
|
||||
new_splits, leftovers = _equalize_larger_splits(
|
||||
larger_splits, target_size, counts, num_splits
|
||||
)
|
||||
# Short-circuit if we've already reached the desired number of
|
||||
# splits.
|
||||
if len(new_splits) == num_splits:
|
||||
return new_splits
|
||||
# Add leftovers to small splits and re-sort.
|
||||
smaller_splits += leftovers
|
||||
smaller_splits = sorted(smaller_splits, key=lambda s: counts[s._get_uuid()])
|
||||
|
||||
# Union smaller splits.
|
||||
new_splits_small = _equalize_smaller_splits(
|
||||
smaller_splits, target_size, counts, num_splits - len(new_splits)
|
||||
)
|
||||
new_splits.extend(new_splits_small)
|
||||
return new_splits
|
||||
|
||||
block_refs, metadata = zip(*blocks.get_blocks_with_metadata())
|
||||
metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}
|
||||
owned_by_consumer = blocks._owned_by_consumer
|
||||
stats = self._plan.stats()
|
||||
block_refs, metadata = zip(*blocks.get_blocks_with_metadata())
|
||||
|
||||
if locality_hints is None:
|
||||
ds = equalize(
|
||||
[
|
||||
Dataset(
|
||||
ExecutionPlan(
|
||||
BlockList(
|
||||
list(blocks),
|
||||
[metadata_mapping[b] for b in blocks],
|
||||
owned_by_consumer=owned_by_consumer,
|
||||
),
|
||||
stats,
|
||||
run_by_consumer=owned_by_consumer,
|
||||
blocks = np.array_split(block_refs, n)
|
||||
meta = np.array_split(metadata, n)
|
||||
return [
|
||||
Dataset(
|
||||
ExecutionPlan(
|
||||
BlockList(
|
||||
b.tolist(), m.tolist(), owned_by_consumer=owned_by_consumer
|
||||
),
|
||||
self._epoch,
|
||||
self._lazy,
|
||||
)
|
||||
for blocks in np.array_split(block_refs, n)
|
||||
],
|
||||
n,
|
||||
)
|
||||
assert len(ds) == n, (ds, n)
|
||||
return ds
|
||||
stats,
|
||||
run_by_consumer=owned_by_consumer,
|
||||
),
|
||||
self._epoch,
|
||||
self._lazy,
|
||||
)
|
||||
for b, m in zip(blocks, meta)
|
||||
]
|
||||
|
||||
metadata_mapping = {b: m for b, m in zip(block_refs, metadata)}
|
||||
|
||||
# If the locality_hints is set, we use a two-round greedy algorithm
|
||||
# to co-locate the blocks with the actors based on block
|
||||
|
@ -1182,25 +1026,31 @@ class Dataset(Generic[T]):
|
|||
|
||||
assert len(remaining_block_refs) == 0, len(remaining_block_refs)
|
||||
|
||||
return equalize(
|
||||
[
|
||||
Dataset(
|
||||
ExecutionPlan(
|
||||
BlockList(
|
||||
allocation_per_actor[actor],
|
||||
[metadata_mapping[b] for b in allocation_per_actor[actor]],
|
||||
owned_by_consumer=owned_by_consumer,
|
||||
),
|
||||
stats,
|
||||
run_by_consumer=owned_by_consumer,
|
||||
),
|
||||
self._epoch,
|
||||
self._lazy,
|
||||
)
|
||||
for actor in locality_hints
|
||||
],
|
||||
n,
|
||||
)
|
||||
per_split_block_lists = [
|
||||
BlockList(
|
||||
allocation_per_actor[actor],
|
||||
[metadata_mapping[b] for b in allocation_per_actor[actor]],
|
||||
owned_by_consumer=owned_by_consumer,
|
||||
)
|
||||
for actor in locality_hints
|
||||
]
|
||||
|
||||
if equal:
|
||||
# equalize the splits
|
||||
per_split_block_lists = _equalize(per_split_block_lists, owned_by_consumer)
|
||||
|
||||
return [
|
||||
Dataset(
|
||||
ExecutionPlan(
|
||||
block_split,
|
||||
stats,
|
||||
run_by_consumer=owned_by_consumer,
|
||||
),
|
||||
self._epoch,
|
||||
self._lazy,
|
||||
)
|
||||
for block_split in per_split_block_lists
|
||||
]
|
||||
|
||||
def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
|
||||
"""Split the dataset at the given indices (like np.split).
|
||||
|
|
|
@ -10,6 +10,9 @@ from ray.data.block import BlockMetadata
|
|||
|
||||
import ray
|
||||
from ray.data._internal.block_list import BlockList
|
||||
from ray.data._internal.equalize import (
|
||||
_equalize,
|
||||
)
|
||||
from ray.data._internal.plan import ExecutionPlan
|
||||
from ray.data._internal.stats import DatasetStats
|
||||
from ray.data._internal.split import (
|
||||
|
@ -95,9 +98,9 @@ def _test_equal_split_balanced(block_sizes, num_splits):
|
|||
blocks.append(ray.put(block))
|
||||
metadata.append(BlockAccessor.for_block(block).get_metadata(None, None))
|
||||
total_rows += block_size
|
||||
block_list = BlockList(blocks, metadata)
|
||||
block_list = BlockList(blocks, metadata, owned_by_consumer=True)
|
||||
ds = Dataset(
|
||||
ExecutionPlan(block_list, DatasetStats.TODO()),
|
||||
ExecutionPlan(block_list, DatasetStats.TODO(), run_by_consumer=True),
|
||||
0,
|
||||
False,
|
||||
)
|
||||
|
@ -507,6 +510,16 @@ def _create_block(data):
|
|||
return (ray.put(data), _create_meta(len(data)))
|
||||
|
||||
|
||||
def _create_blocklist(blocks):
|
||||
block_refs = []
|
||||
meta = []
|
||||
for block in blocks:
|
||||
block_ref, block_meta = _create_block(block)
|
||||
block_refs.append(block_ref)
|
||||
meta.append(block_meta)
|
||||
return BlockList(block_refs, meta, owned_by_consumer=True)
|
||||
|
||||
|
||||
def test_split_single_block(ray_start_regular_shared):
|
||||
block = [1, 2, 3]
|
||||
meta = _create_meta(3)
|
||||
|
@ -586,29 +599,123 @@ def test_generate_global_split_results(ray_start_regular_shared):
|
|||
|
||||
|
||||
def test_private_split_at_indices(ray_start_regular_shared):
|
||||
inputs = []
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [0])))
|
||||
inputs = _create_blocklist([])
|
||||
splits = list(zip(*_split_at_indices(inputs, [0])))
|
||||
verify_splits(splits, [[], []])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [])))
|
||||
splits = list(zip(*_split_at_indices(inputs, [])))
|
||||
verify_splits(splits, [[]])
|
||||
|
||||
inputs = [_create_block([1]), _create_block([2, 3]), _create_block([4])]
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [1])))
|
||||
splits = list(zip(*_split_at_indices(inputs, [1])))
|
||||
verify_splits(splits, [[[1]], [[2, 3], [4]]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [2])))
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
splits = list(zip(*_split_at_indices(inputs, [2])))
|
||||
verify_splits(splits, [[[1], [2]], [[3], [4]]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [1])))
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
splits = list(zip(*_split_at_indices(inputs, [1])))
|
||||
verify_splits(splits, [[[1]], [[2, 3], [4]]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [2, 2])))
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
splits = list(zip(*_split_at_indices(inputs, [2, 2])))
|
||||
verify_splits(splits, [[[1], [2]], [], [[3], [4]]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [])))
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
splits = list(zip(*_split_at_indices(inputs, [])))
|
||||
verify_splits(splits, [[[1], [2, 3], [4]]])
|
||||
|
||||
splits = list(zip(*_split_at_indices(iter(inputs), [0, 4])))
|
||||
inputs = _create_blocklist([[1], [2, 3], [4]])
|
||||
splits = list(zip(*_split_at_indices(inputs, [0, 4])))
|
||||
verify_splits(splits, [[], [[1], [2, 3], [4]], []])
|
||||
|
||||
|
||||
def equalize_helper(input_block_lists):
|
||||
result = _equalize(
|
||||
[_create_blocklist(block_list) for block_list in input_block_lists],
|
||||
owned_by_consumer=True,
|
||||
)
|
||||
result_block_lists = []
|
||||
for blocklist in result:
|
||||
block_list = []
|
||||
for block_ref, _ in blocklist.get_blocks_with_metadata():
|
||||
block = ray.get(block_ref)
|
||||
block_accessor = BlockAccessor.for_block(block)
|
||||
block_list.append(block_accessor.to_native())
|
||||
result_block_lists.append(block_list)
|
||||
return result_block_lists
|
||||
|
||||
|
||||
def verify_equalize_result(input_block_lists, expected_block_lists):
|
||||
result_block_lists = equalize_helper(input_block_lists)
|
||||
assert result_block_lists == expected_block_lists
|
||||
|
||||
|
||||
def test_equalize(ray_start_regular_shared):
|
||||
verify_equalize_result([], [])
|
||||
verify_equalize_result([[]], [[]])
|
||||
verify_equalize_result([[[1]], []], [[], []])
|
||||
verify_equalize_result([[[1], [2, 3]], [[4]]], [[[1], [2]], [[4], [3]]])
|
||||
verify_equalize_result([[[1], [2, 3]], []], [[[1]], [[2]]])
|
||||
verify_equalize_result(
|
||||
[[[1], [2, 3], [4, 5]], [[6]], []], [[[1], [2]], [[6], [3]], [[4, 5]]]
|
||||
)
|
||||
verify_equalize_result(
|
||||
[[[1, 2, 3], [4, 5]], [[6]], []], [[[4, 5]], [[6], [1]], [[2, 3]]]
|
||||
)
|
||||
|
||||
|
||||
def test_equalize_randomized(ray_start_regular_shared):
|
||||
# verify the entries in the splits are in the range of 0 .. num_rows,
|
||||
# unique, and the total number matches num_rows if exact_num == True.
|
||||
def assert_unique_and_inrange(splits, num_rows, exact_num=False):
|
||||
unique_set = set([])
|
||||
for split in splits:
|
||||
for block in split:
|
||||
for entry in block:
|
||||
assert entry not in unique_set
|
||||
assert entry >= 0 and entry < num_rows
|
||||
unique_set.add(entry)
|
||||
if exact_num:
|
||||
assert len(unique_set) == num_rows
|
||||
|
||||
# verify that splits are equalized.
|
||||
def assert_equal_split(splits, num_rows, num_split):
|
||||
split_size = num_rows // num_split
|
||||
for split in splits:
|
||||
assert len((list(itertools.chain.from_iterable(split)))) == split_size
|
||||
|
||||
# create randomized splits contains entries from 0 ... num_rows.
|
||||
def random_split(num_rows, num_split):
|
||||
split_point = [int(random.random() * num_rows) for _ in range(num_split - 1)]
|
||||
split_index_helper = [0] + sorted(split_point) + [num_rows]
|
||||
splits = []
|
||||
for i in range(1, len(split_index_helper)):
|
||||
split_start = split_index_helper[i - 1]
|
||||
split_end = split_index_helper[i]
|
||||
num_entries = split_end - split_start
|
||||
split = []
|
||||
num_block_split = int(random.random() * num_entries)
|
||||
block_split_point = [
|
||||
split_start + int(random.random() * num_entries)
|
||||
for _ in range(num_block_split)
|
||||
]
|
||||
block_index_helper = [split_start] + sorted(block_split_point) + [split_end]
|
||||
for j in range(1, len(block_index_helper)):
|
||||
split.append(
|
||||
list(range(block_index_helper[j - 1], block_index_helper[j]))
|
||||
)
|
||||
splits.append(split)
|
||||
assert_unique_and_inrange(splits, num_rows, exact_num=True)
|
||||
return splits
|
||||
|
||||
for i in range(100):
|
||||
num_rows = int(random.random() * 100)
|
||||
num_split = int(random.random() * 10) + 1
|
||||
input_splits = random_split(num_rows, num_split)
|
||||
print(input_splits)
|
||||
equalized_splits = equalize_helper(input_splits)
|
||||
assert_unique_and_inrange(equalized_splits, num_rows)
|
||||
assert_equal_split(equalized_splits, num_rows, num_split)
|
||||
|
|
Loading…
Add table
Reference in a new issue