[Data][Split] Fix split ownership (#27149)

fb54679 introduced a bug by calling ray.put inside the remote _split_single_block task. This changed the ownership of the split blocks from the driver to the worker that runs _split_single_block, which breaks the dataset's lineage requirement and caused the chaos test to fail.

To fix the issue, we need to ensure the split block refs are created by the driver, which we can achieve by creating the block refs as part of the remote function's return values.
This commit is contained in:
Chen Shen 2022-07-28 10:39:14 -07:00 committed by GitHub
parent df124d0ad5
commit 0d49901651
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 46 deletions

View file

@ -1,6 +1,6 @@
import itertools
import logging
from typing import Iterable, Tuple, List
from typing import Union, Iterable, Tuple, List
import ray
from ray.data._internal.block_list import BlockList
@ -93,9 +93,23 @@ def _split_single_block(
block: Block,
meta: BlockMetadata,
split_indices: List[int],
) -> Tuple[int, BlockPartition]:
"""Split the provided block at the given indices."""
split_result = []
) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
"""Split the provided block at the given indices.
Args:
block_id: the id of this block in the block list.
block: block to be split.
meta: metadata of the block, we expect meta.num is valid.
split_indices: the indices where the block should be split.
Returns:
returns block_id, split blocks metadata, and a list of blocks
in the following form. We return blocks in this way
so that the owner of blocks could be the caller(driver)
instead of worker itself.
Tuple(block_id, split_blocks_meta), block0, block1 ...
"""
split_meta = []
split_blocks = []
block_accessor = BlockAccessor.for_block(block)
prev_index = 0
# append one more entry at the last so we don't
@ -106,16 +120,19 @@ def _split_single_block(
stats = BlockExecStats.builder()
split_block = block_accessor.slice(prev_index, index, copy=True)
accessor = BlockAccessor.for_block(split_block)
split_meta = BlockMetadata(
_meta = BlockMetadata(
num_rows=accessor.num_rows(),
size_bytes=accessor.size_bytes(),
schema=meta.schema,
input_files=meta.input_files,
exec_stats=stats.build(),
)
split_result.append((ray.put(split_block), split_meta))
split_meta.append(_meta)
split_blocks.append(split_block)
prev_index = index
return (block_id, split_result)
results = [(block_id, split_meta)]
results.extend(split_blocks)
return tuple(results)
def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
@ -145,8 +162,10 @@ def _split_all_blocks(
blocks_with_metadata = block_list.get_blocks_with_metadata()
all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)
split_single_block_futures = []
per_block_split_metadata_futures = []
per_block_split_block_refs = []
# tracking splitted blocks for gc.
blocks_splitted = []
for block_id, block_split_indices in enumerate(per_block_split_indices):
(block_ref, meta) = blocks_with_metadata[block_id]
@ -158,19 +177,27 @@ def _split_all_blocks(
all_blocks_split_results[block_id] = [(block_ref, meta)]
else:
# otherwise call split remote function.
split_single_block_futures.append(
split_single_block.options(scheduling_strategy="SPREAD").remote(
block_id,
block_ref,
meta,
block_split_indices,
)
object_refs = split_single_block.options(
scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
).remote(
block_id,
block_ref,
meta,
block_split_indices,
)
per_block_split_metadata_futures.append(object_refs[0])
per_block_split_block_refs.append(object_refs[1:])
blocks_splitted.append(block_ref)
if split_single_block_futures:
split_single_block_results = ray.get(split_single_block_futures)
for block_id, block_split_result in split_single_block_results:
all_blocks_split_results[block_id] = block_split_result
if per_block_split_metadata_futures:
# only get metadata.
per_block_split_metadata = ray.get(per_block_split_metadata_futures)
for (block_id, meta), block_refs in zip(
per_block_split_metadata, per_block_split_block_refs
):
assert len(meta) == len(block_refs)
all_blocks_split_results[block_id] = zip(block_refs, meta)
# We make a copy for the blocks that have been splitted, so the input blocks
# can be cleared if they are owned by consumer (consumer-owned blocks will

View file

@ -522,47 +522,63 @@ def _create_blocklist(blocks):
def test_split_single_block(ray_start_regular_shared):
block = [1, 2, 3]
meta = _create_meta(3)
metadata = _create_meta(3)
block_id, splits = ray.get(
ray.remote(_split_single_block).remote(234, block, meta, [])
results = ray.get(
ray.remote(_split_single_block)
.options(num_returns=2)
.remote(234, block, metadata, [])
)
block_id, meta = results[0]
blocks = results[1:]
assert 234 == block_id
assert len(splits) == 1
assert ray.get(splits[0][0]) == [1, 2, 3]
assert splits[0][1].num_rows == 3
assert len(blocks) == 1
assert blocks[0] == [1, 2, 3]
assert meta[0].num_rows == 3
block_id, splits = ray.get(
ray.remote(_split_single_block).remote(234, block, meta, [1])
results = ray.get(
ray.remote(_split_single_block)
.options(num_returns=3)
.remote(234, block, metadata, [1])
)
block_id, meta = results[0]
blocks = results[1:]
assert 234 == block_id
assert len(splits) == 2
assert ray.get(splits[0][0]) == [1]
assert splits[0][1].num_rows == 1
assert ray.get(splits[1][0]) == [2, 3]
assert splits[1][1].num_rows == 2
assert len(blocks) == 2
assert blocks[0] == [1]
assert meta[0].num_rows == 1
assert blocks[1] == [2, 3]
assert meta[1].num_rows == 2
block_id, splits = ray.get(
ray.remote(_split_single_block).remote(234, block, meta, [0, 1, 1, 3])
results = ray.get(
ray.remote(_split_single_block)
.options(num_returns=6)
.remote(234, block, metadata, [0, 1, 1, 3])
)
block_id, meta = results[0]
blocks = results[1:]
assert 234 == block_id
assert len(splits) == 5
assert ray.get(splits[0][0]) == []
assert ray.get(splits[1][0]) == [1]
assert ray.get(splits[2][0]) == []
assert ray.get(splits[3][0]) == [2, 3]
assert ray.get(splits[4][0]) == []
assert len(blocks) == 5
assert blocks[0] == []
assert blocks[1] == [1]
assert blocks[2] == []
assert blocks[3] == [2, 3]
assert blocks[4] == []
block = []
meta = _create_meta(0)
metadata = _create_meta(0)
block_id, splits = ray.get(
ray.remote(_split_single_block).remote(234, block, meta, [0])
results = ray.get(
ray.remote(_split_single_block)
.options(num_returns=3)
.remote(234, block, metadata, [0])
)
block_id, meta = results[0]
blocks = results[1:]
assert 234 == block_id
assert len(splits) == 2
assert ray.get(splits[0][0]) == []
assert ray.get(splits[1][0]) == []
assert len(blocks) == 2
assert blocks[0] == []
assert blocks[1] == []
def test_drop_empty_block_split():