diff --git a/python/ray/data/_internal/split.py b/python/ray/data/_internal/split.py
index 5ffb2d073..160941dea 100644
--- a/python/ray/data/_internal/split.py
+++ b/python/ray/data/_internal/split.py
@@ -1,6 +1,6 @@
 import itertools
 import logging
-from typing import Iterable, Tuple, List
+from typing import Union, Iterable, Tuple, List
 
 import ray
 from ray.data._internal.block_list import BlockList
@@ -93,9 +93,23 @@ def _split_single_block(
     block: Block,
     meta: BlockMetadata,
     split_indices: List[int],
-) -> Tuple[int, BlockPartition]:
-    """Split the provided block at the given indices."""
-    split_result = []
+) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
+    """Split the provided block at the given indices.
+
+    Args:
+        block_id: the id of this block in the block list.
+        block: block to be split.
+        meta: metadata of the block, we expect meta.num_rows is valid.
+        split_indices: the indices where the block should be split.
+    Returns:
+        returns block_id, split blocks metadata, and a list of blocks
+        in the following form. We return blocks in this way
+        so that the owner of blocks could be the caller (driver)
+        instead of worker itself.
+        Tuple(block_id, split_blocks_meta), block0, block1 ...
+    """
+    split_meta = []
+    split_blocks = []
     block_accessor = BlockAccessor.for_block(block)
     prev_index = 0
     # append one more entry at the last so we don't
@@ -106,16 +120,19 @@ def _split_single_block(
         stats = BlockExecStats.builder()
         split_block = block_accessor.slice(prev_index, index, copy=True)
         accessor = BlockAccessor.for_block(split_block)
-        split_meta = BlockMetadata(
+        _meta = BlockMetadata(
             num_rows=accessor.num_rows(),
             size_bytes=accessor.size_bytes(),
             schema=meta.schema,
             input_files=meta.input_files,
             exec_stats=stats.build(),
         )
-        split_result.append((ray.put(split_block), split_meta))
+        split_meta.append(_meta)
+        split_blocks.append(split_block)
         prev_index = index
-    return (block_id, split_result)
+    results = [(block_id, split_meta)]
+    results.extend(split_blocks)
+    return tuple(results)
 
 
 def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
@@ -145,8 +162,10 @@ def _split_all_blocks(
     blocks_with_metadata = block_list.get_blocks_with_metadata()
     all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)
 
-    split_single_block_futures = []
+    per_block_split_metadata_futures = []
+    per_block_split_block_refs = []
 
+    # track the input blocks that get split, for gc.
     blocks_splitted = []
     for block_id, block_split_indices in enumerate(per_block_split_indices):
         (block_ref, meta) = blocks_with_metadata[block_id]
@@ -158,19 +177,28 @@ def _split_all_blocks(
             all_blocks_split_results[block_id] = [(block_ref, meta)]
         else:
             # otherwise call split remote function.
-            split_single_block_futures.append(
-                split_single_block.options(scheduling_strategy="SPREAD").remote(
-                    block_id,
-                    block_ref,
-                    meta,
-                    block_split_indices,
-                )
+            object_refs = split_single_block.options(
+                scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
+            ).remote(
+                block_id,
+                block_ref,
+                meta,
+                block_split_indices,
             )
+            per_block_split_metadata_futures.append(object_refs[0])
+            per_block_split_block_refs.append(object_refs[1:])
+
             blocks_splitted.append(block_ref)
-    if split_single_block_futures:
-        split_single_block_results = ray.get(split_single_block_futures)
-        for block_id, block_split_result in split_single_block_results:
-            all_blocks_split_results[block_id] = block_split_result
+
+    if per_block_split_metadata_futures:
+        # only get metadata.
+        per_block_split_metadata = ray.get(per_block_split_metadata_futures)
+        for (block_id, meta), block_refs in zip(
+            per_block_split_metadata, per_block_split_block_refs
+        ):
+            assert len(meta) == len(block_refs)
+            # materialize: a bare zip() is a one-shot iterator, not a list.
+            all_blocks_split_results[block_id] = list(zip(block_refs, meta))
 
     # We make a copy for the blocks that have been splitted, so the input blocks
     # can be cleared if they are owned by consumer (consumer-owned blocks will
diff --git a/python/ray/data/tests/test_split.py b/python/ray/data/tests/test_split.py
index 214e8277b..a3ca0dd21 100644
--- a/python/ray/data/tests/test_split.py
+++ b/python/ray/data/tests/test_split.py
@@ -522,47 +522,63 @@ def _create_blocklist(blocks):
 
 def test_split_single_block(ray_start_regular_shared):
     block = [1, 2, 3]
-    meta = _create_meta(3)
+    metadata = _create_meta(3)
 
-    block_id, splits = ray.get(
-        ray.remote(_split_single_block).remote(234, block, meta, [])
+    results = ray.get(
+        ray.remote(_split_single_block)
+        .options(num_returns=2)
+        .remote(234, block, metadata, [])
     )
+    block_id, meta = results[0]
+    blocks = results[1:]
     assert 234 == block_id
-    assert len(splits) == 1
-    assert ray.get(splits[0][0]) == [1, 2, 3]
-    assert splits[0][1].num_rows == 3
+    assert len(blocks) == 1
+    assert blocks[0] == [1, 2, 3]
+    assert meta[0].num_rows == 3
 
-    block_id, splits = ray.get(
-        ray.remote(_split_single_block).remote(234, block, meta, [1])
+    results = ray.get(
+        ray.remote(_split_single_block)
+        .options(num_returns=3)
+        .remote(234, block, metadata, [1])
     )
+    block_id, meta = results[0]
+    blocks = results[1:]
     assert 234 == block_id
-    assert len(splits) == 2
-    assert ray.get(splits[0][0]) == [1]
-    assert splits[0][1].num_rows == 1
-    assert ray.get(splits[1][0]) == [2, 3]
-    assert splits[1][1].num_rows == 2
+    assert len(blocks) == 2
+    assert blocks[0] == [1]
+    assert meta[0].num_rows == 1
+    assert blocks[1] == [2, 3]
+    assert meta[1].num_rows == 2
 
-    block_id, splits = ray.get(
-        ray.remote(_split_single_block).remote(234, block, meta, [0, 1, 1, 3])
+    results = ray.get(
+        ray.remote(_split_single_block)
+        .options(num_returns=6)
+        .remote(234, block, metadata, [0, 1, 1, 3])
     )
+    block_id, meta = results[0]
+    blocks = results[1:]
     assert 234 == block_id
-    assert len(splits) == 5
-    assert ray.get(splits[0][0]) == []
-    assert ray.get(splits[1][0]) == [1]
-    assert ray.get(splits[2][0]) == []
-    assert ray.get(splits[3][0]) == [2, 3]
-    assert ray.get(splits[4][0]) == []
+    assert len(blocks) == 5
+    assert blocks[0] == []
+    assert blocks[1] == [1]
+    assert blocks[2] == []
+    assert blocks[3] == [2, 3]
+    assert blocks[4] == []
 
     block = []
-    meta = _create_meta(0)
+    metadata = _create_meta(0)
 
-    block_id, splits = ray.get(
-        ray.remote(_split_single_block).remote(234, block, meta, [0])
+    results = ray.get(
+        ray.remote(_split_single_block)
+        .options(num_returns=3)
+        .remote(234, block, metadata, [0])
     )
+    block_id, meta = results[0]
+    blocks = results[1:]
     assert 234 == block_id
-    assert len(splits) == 2
-    assert ray.get(splits[0][0]) == []
-    assert ray.get(splits[1][0]) == []
+    assert len(blocks) == 2
+    assert blocks[0] == []
+    assert blocks[1] == []
 
 
 def test_drop_empty_block_split():