[data] Refactor all to all op implementations into a separate file (#26585)

Eric Liang 2022-07-15 18:17:48 -07:00 committed by GitHub
parent fea94dc976
commit cf980c3020
4 changed files with 218 additions and 157 deletions
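In short, the commit moves the stage classes backing `Dataset.repartition()`, `Dataset.randomize_block_order()`, `Dataset.random_shuffle()`, `Dataset.zip()`, and `Dataset.sort()` out of `plan.py` and `dataset.py` into a new `stage_impl.py` module. A toy model of the pattern (a sketch only, not Ray's actual classes):

# Toy model of the refactor: stage logic moves out of the Dataset method
# body and into a named, separately importable AllToAllStage subclass.
from typing import Callable, List, Optional


class AllToAllStage:
    """A named stage wrapping a function applied to all blocks at once."""

    def __init__(self, name: str, num_blocks: Optional[int], fn: Callable):
        self.name = name
        self.num_blocks = num_blocks
        self.fn = fn


class SortStage(AllToAllStage):
    """Before: Dataset.sort() built an anonymous closure and passed it to
    AllToAllStage inline. After: the closure lives in a named subclass."""

    def __init__(self, key: Optional[Callable] = None, descending: bool = False):
        def do_sort(blocks: List[list]) -> List[list]:
            rows = [row for block in blocks for row in block]
            return [sorted(rows, key=key, reverse=descending)]

        super().__init__("sort", None, do_sort)


stage = SortStage(descending=True)
print(stage.name, stage.fn([[3, 1], [2]]))  # sort [[3, 2, 1]]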

python/ray/data/_internal/plan.py

@@ -194,6 +194,8 @@ class ExecutionPlan:
        Returns:
            The schema of the output dataset.
        """
        from ray.data._internal.stage_impl import RandomizeBlocksStage

        if self._stages_after_snapshot:
            if fetch_if_missing:
                if isinstance(self._stages_after_snapshot[-1], RandomizeBlocksStage):
@@ -396,6 +398,8 @@
    def is_read_stage_equivalent(self) -> bool:
        """Return whether this plan can be executed as only a read stage."""
        from ray.data._internal.stage_impl import RandomizeBlocksStage

        context = DatasetContext.get_current()
        remaining_stages = self._stages_after_snapshot
        if (
@@ -712,20 +716,6 @@ class AllToAllStage(Stage):
        return blocks, stage_info


class RandomizeBlocksStage(AllToAllStage):
    def __init__(self, seed: Optional[int]):
        self._seed = seed
        super().__init__("randomize_block_order", None, self.do_randomize)

    def do_randomize(self, block_list, *_):
        num_blocks = block_list.initial_num_blocks()
        if num_blocks == 0:
            return block_list, {}
        randomized_block_list = block_list.randomize_block_order(self._seed)
        return randomized_block_list, {}


def _rewrite_read_stages(
    blocks: BlockList,
    stats: DatasetStats,
@@ -758,6 +748,8 @@ def _rewrite_read_stage(
        Non-lazy block list containing read tasks for not-yet-read block partitions,
        new stats for the block list, and the new list of stages.
    """
    from ray.data._internal.stage_impl import RandomizeBlocksStage

    # Generate the "GetReadTasks" stage blocks.
    remote_args = in_blocks._remote_args
    blocks, metadata = [], []
@@ -798,6 +790,7 @@ def _reorder_stages(stages: List[Stage]) -> List[Stage]:
    Returns:
        Reordered stages.
    """
    from ray.data._internal.stage_impl import RandomizeBlocksStage
    output: List[Stage] = []
    reorder_buf: List[RandomizeBlocksStage] = []

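Note the function-local imports of `RandomizeBlocksStage` added above: `stage_impl.py` imports `AllToAllStage` from `plan.py` at module load, so importing `stage_impl` at the top of `plan.py` would create a circular dependency. Deferring the import to call time breaks the cycle. A minimal two-file sketch of the pattern (hypothetical module and function names; runs if saved as two modules):

# plan.py (sketch)
class AllToAllStage:
    """Base class that stage_impl.py subclasses."""


def drop_randomize_stages(stages):
    # Deferred import: stage_impl does `from plan import AllToAllStage` at
    # module level, so a top-level import here would form an import cycle.
    from stage_impl import RandomizeBlocksStage

    return [s for s in stages if not isinstance(s, RandomizeBlocksStage)]


# stage_impl.py (sketch)
from plan import AllToAllStage


class RandomizeBlocksStage(AllToAllStage):
    pass
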
python/ray/data/_internal/stage_impl.py (new file)

@@ -0,0 +1,194 @@
from typing import Optional, TYPE_CHECKING

import ray
from ray.data._internal.fast_repartition import fast_repartition
from ray.data._internal.plan import AllToAllStage
from ray.data._internal.shuffle_and_partition import (
    PushBasedShufflePartitionOp,
    SimpleShufflePartitionOp,
)
from ray.data._internal.block_list import BlockList
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.sort import sort_impl
from ray.data.context import DatasetContext
from ray.data.block import (
    _validate_key_fn,
    Block,
    KeyFn,
    BlockMetadata,
    BlockAccessor,
    BlockExecStats,
)

if TYPE_CHECKING:
    from ray.data import Dataset


class RepartitionStage(AllToAllStage):
    """Implementation of `Dataset.repartition()`."""

    def __init__(self, num_blocks: int, shuffle: bool):
        if shuffle:

            def do_shuffle(
                block_list, clear_input_blocks: bool, block_udf, remote_args
            ):
                if clear_input_blocks:
                    blocks = block_list.copy()
                    block_list.clear()
                else:
                    blocks = block_list
                context = DatasetContext.get_current()
                if context.use_push_based_shuffle:
                    shuffle_op_cls = PushBasedShufflePartitionOp
                else:
                    shuffle_op_cls = SimpleShufflePartitionOp
                shuffle_op = shuffle_op_cls(block_udf, random_shuffle=False)
                return shuffle_op.execute(
                    blocks,
                    num_blocks,
                    clear_input_blocks,
                    map_ray_remote_args=remote_args,
                    reduce_ray_remote_args=remote_args,
                )

            super().__init__(
                "repartition", num_blocks, do_shuffle, supports_block_udf=True
            )

        else:

            def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
                if clear_input_blocks:
                    blocks = block_list.copy()
                    block_list.clear()
                else:
                    blocks = block_list
                return fast_repartition(blocks, num_blocks)

            super().__init__("repartition", num_blocks, do_fast_repartition)


class RandomizeBlocksStage(AllToAllStage):
    """Implementation of `Dataset.randomize_block_order()`."""

    def __init__(self, seed: Optional[int]):
        self._seed = seed
        super().__init__("randomize_block_order", None, self.do_randomize)

    def do_randomize(self, block_list, *_):
        num_blocks = block_list.initial_num_blocks()
        if num_blocks == 0:
            return block_list, {}
        randomized_block_list = block_list.randomize_block_order(self._seed)
        return randomized_block_list, {}


class RandomShuffleStage(AllToAllStage):
    """Implementation of `Dataset.random_shuffle()`."""

    def __init__(self, seed: Optional[int], output_num_blocks: Optional[int]):
        def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
            num_blocks = block_list.executed_num_blocks()  # Blocking.
            if num_blocks == 0:
                return block_list, {}
            if clear_input_blocks:
                blocks = block_list.copy()
                block_list.clear()
            else:
                blocks = block_list
            context = DatasetContext.get_current()
            if context.use_push_based_shuffle:
                if output_num_blocks is not None:
                    raise NotImplementedError(
                        "Push-based shuffle doesn't support setting num_blocks yet."
                    )
                shuffle_op_cls = PushBasedShufflePartitionOp
            else:
                shuffle_op_cls = SimpleShufflePartitionOp
            random_shuffle_op = shuffle_op_cls(
                block_udf, random_shuffle=True, random_seed=seed
            )
            return random_shuffle_op.execute(
                blocks,
                output_num_blocks or num_blocks,
                clear_input_blocks,
                map_ray_remote_args=remote_args,
                reduce_ray_remote_args=remote_args,
            )

        super().__init__(
            "random_shuffle", output_num_blocks, do_shuffle, supports_block_udf=True
        )


class ZipStage(AllToAllStage):
    """Implementation of `Dataset.zip()`."""

    def __init__(self, other: "Dataset"):
        def do_zip_all(block_list, clear_input_blocks: bool, *_):
            blocks1 = block_list.get_blocks()
            blocks2 = other.get_internal_block_refs()
            if clear_input_blocks:
                block_list.clear()
            if len(blocks1) != len(blocks2):
                # TODO(ekl) consider supporting if num_rows are equal.
                raise ValueError(
                    "Cannot zip dataset of different num blocks: {} vs {}".format(
                        len(blocks1), len(blocks2)
                    )
                )

            def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
                stats = BlockExecStats.builder()
                b1 = BlockAccessor.for_block(block1)
                result = b1.zip(block2)
                br = BlockAccessor.for_block(result)
                return result, br.get_metadata(input_files=[], exec_stats=stats.build())

            do_zip_fn = cached_remote_fn(do_zip, num_returns=2)

            blocks = []
            metadata = []
            for b1, b2 in zip(blocks1, blocks2):
                res, meta = do_zip_fn.remote(b1, b2)
                blocks.append(res)
                metadata.append(meta)
            # Early release memory.
            del blocks1, blocks2
            # TODO(ekl) it might be nice to have a progress bar here.
            metadata = ray.get(metadata)
            blocks = BlockList(blocks, metadata)
            return blocks, {}

        super().__init__("zip", None, do_zip_all)


class SortStage(AllToAllStage):
    """Implementation of `Dataset.sort()`."""

    def __init__(self, ds: "Dataset", key: Optional[KeyFn], descending: bool):
        def do_sort(block_list, clear_input_blocks: bool, *_):
            # Handle empty dataset.
            if block_list.initial_num_blocks() == 0:
                return block_list, {}
            if clear_input_blocks:
                blocks = block_list.copy()
                block_list.clear()
            else:
                blocks = block_list
            if isinstance(key, list):
                if not key:
                    raise ValueError("`key` must be a list of non-zero length")
                for subkey in key:
                    _validate_key_fn(ds, subkey)
            else:
                _validate_key_fn(ds, key)
            return sort_impl(blocks, clear_input_blocks, key, descending)

        super().__init__("sort", None, do_sort)

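The new stage classes are internal, but each backs a public `Dataset` operation, as the `dataset.py` changes below show. A quick exercise of those entry points (assumes a working `ray` installation of this era; printed outputs are illustrative):

import ray

ds = ray.data.range(100, parallelism=10)
print(ds.repartition(5).num_blocks())        # 5 -- RepartitionStage
print(ds.randomize_block_order(seed=0))      # RandomizeBlocksStage
print(ds.random_shuffle(seed=0).take(3))     # RandomShuffleStage
print(ds.sort(descending=True).take(3))      # [99, 98, 97] -- SortStage
print(ds.zip(ray.data.range(100, parallelism=10)).take(2))  # [(0, 0), (1, 1)] -- ZipStage
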
python/ray/data/dataset.py

@@ -32,22 +32,21 @@ from ray.data._internal.compute import (
    TaskPoolStrategy,
)
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.fast_repartition import fast_repartition
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.plan import (
    AllToAllStage,
    ExecutionPlan,
    OneToOneStage,
)
from ray.data._internal.stage_impl import (
    RandomizeBlocksStage,
    RepartitionStage,
    RandomShuffleStage,
    ZipStage,
    SortStage,
)
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.shuffle_and_partition import (
    PushBasedShufflePartitionOp,
    SimpleShufflePartitionOp,
)
from ray.data._internal.sort import sort_impl
from ray.data._internal.stats import DatasetStats
from ray.data._internal.table_block import VALUE_COL_NAME
from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
@@ -723,50 +722,7 @@ class Dataset(Generic[T]):
            The repartitioned dataset.
        """
        if shuffle:

            def do_shuffle(
                block_list, clear_input_blocks: bool, block_udf, remote_args
            ):
                if clear_input_blocks:
                    blocks = block_list.copy()
                    block_list.clear()
                else:
                    blocks = block_list
                context = DatasetContext.get_current()
                if context.use_push_based_shuffle:
                    shuffle_op_cls = PushBasedShufflePartitionOp
                else:
                    shuffle_op_cls = SimpleShufflePartitionOp
                shuffle_op = shuffle_op_cls(block_udf, random_shuffle=False)
                return shuffle_op.execute(
                    blocks,
                    num_blocks,
                    clear_input_blocks,
                    map_ray_remote_args=remote_args,
                    reduce_ray_remote_args=remote_args,
                )

            plan = self._plan.with_stage(
                AllToAllStage(
                    "repartition", num_blocks, do_shuffle, supports_block_udf=True
                )
            )
        else:

            def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
                if clear_input_blocks:
                    blocks = block_list.copy()
                    block_list.clear()
                else:
                    blocks = block_list
                return fast_repartition(blocks, num_blocks)

            plan = self._plan.with_stage(
                AllToAllStage("repartition", num_blocks, do_fast_repartition)
            )
        plan = self._plan.with_stage(RepartitionStage(num_blocks, shuffle))
        return Dataset(plan, self._epoch, self._lazy)

    def random_shuffle(
@@ -799,36 +755,7 @@ class Dataset(Generic[T]):
            The shuffled dataset.
        """

        def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
            num_blocks = block_list.executed_num_blocks()  # Blocking.
            if num_blocks == 0:
                return block_list, {}
            if clear_input_blocks:
                blocks = block_list.copy()
                block_list.clear()
            else:
                blocks = block_list
            context = DatasetContext.get_current()
            if context.use_push_based_shuffle:
                shuffle_op_cls = PushBasedShufflePartitionOp
            else:
                shuffle_op_cls = SimpleShufflePartitionOp
            random_shuffle_op = shuffle_op_cls(
                block_udf, random_shuffle=True, random_seed=seed
            )
            return random_shuffle_op.execute(
                blocks,
                num_blocks,
                clear_input_blocks,
                map_ray_remote_args=remote_args,
                reduce_ray_remote_args=remote_args,
            )

        plan = self._plan.with_stage(
            AllToAllStage(
                "random_shuffle", num_blocks, do_shuffle, supports_block_udf=True
            )
        )
        plan = self._plan.with_stage(RandomShuffleStage(seed, num_blocks))
        return Dataset(plan, self._epoch, self._lazy)

    def randomize_block_order(
@@ -1834,25 +1761,7 @@ class Dataset(Generic[T]):
            A new, sorted dataset.
        """

        def do_sort(block_list, clear_input_blocks: bool, *_):
            # Handle empty dataset.
            if block_list.initial_num_blocks() == 0:
                return block_list, {}
            if clear_input_blocks:
                blocks = block_list.copy()
                block_list.clear()
            else:
                blocks = block_list
            if isinstance(key, list):
                if not key:
                    raise ValueError("`key` must be a list of non-zero length")
                for subkey in key:
                    _validate_key_fn(self, subkey)
            else:
                _validate_key_fn(self, key)
            return sort_impl(blocks, clear_input_blocks, key, descending)

        plan = self._plan.with_stage(AllToAllStage("sort", None, do_sort))
        plan = self._plan.with_stage(SortStage(self, key, descending))
        return Dataset(plan, self._epoch, self._lazy)

    def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
@@ -1882,46 +1791,7 @@ class Dataset(Generic[T]):
            comes from the first dataset and v comes from the second.
        """

        def do_zip_all(block_list, clear_input_blocks: bool, *_):
            blocks1 = block_list.get_blocks()
            blocks2 = other.get_internal_block_refs()
            if clear_input_blocks:
                block_list.clear()
            if len(blocks1) != len(blocks2):
                # TODO(ekl) consider supporting if num_rows are equal.
                raise ValueError(
                    "Cannot zip dataset of different num blocks: {} vs {}".format(
                        len(blocks1), len(blocks2)
                    )
                )

            def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
                stats = BlockExecStats.builder()
                b1 = BlockAccessor.for_block(block1)
                result = b1.zip(block2)
                br = BlockAccessor.for_block(result)
                return result, br.get_metadata(input_files=[], exec_stats=stats.build())

            do_zip_fn = cached_remote_fn(do_zip, num_returns=2)

            blocks = []
            metadata = []
            for b1, b2 in zip(blocks1, blocks2):
                res, meta = do_zip_fn.remote(b1, b2)
                blocks.append(res)
                metadata.append(meta)
            # Early release memory.
            del blocks1, blocks2
            # TODO(ekl) it might be nice to have a progress bar here.
            metadata = ray.get(metadata)
            blocks = BlockList(blocks, metadata)
            return blocks, {}

        plan = self._plan.with_stage(AllToAllStage("zip", None, do_zip_all))
        plan = self._plan.with_stage(ZipStage(other))
        return Dataset(plan, self._epoch, self._lazy)

    def limit(self, limit: int) -> "Dataset[T]":

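One behavioral detail in the `random_shuffle` change above: the old inline `do_shuffle` rebound `num_blocks = block_list.executed_num_blocks()`, shadowing the `num_blocks` argument of `random_shuffle()`, so the requested output block count was effectively ignored. `RandomShuffleStage` keeps the two values distinct and passes `output_num_blocks or num_blocks`, which the test update below exercises. A toy illustration of the shadowing pitfall (my example, not Ray code):

def shuffle_old(blocks, num_blocks=None):
    num_blocks = len(blocks)  # rebinding shadows the argument; caller's value is lost
    return num_blocks

def shuffle_new(blocks, output_num_blocks=None):
    num_blocks = len(blocks)  # distinct names keep the caller's request visible
    return output_num_blocks or num_blocks

blocks = [[0], [1], [2]]
print(shuffle_old(blocks, num_blocks=1))         # 3 -- requested value ignored
print(shuffle_new(blocks, output_num_blocks=1))  # 1
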
python/ray/data/tests/test_dataset.py

@@ -3975,9 +3975,13 @@ def test_random_shuffle(shutdown_only, pipelined, use_push_based_shuffle):
    r2 = range(100, parallelism=1).random_shuffle().take(999)
    assert r1 != r2, (r1, r2)

    r1 = range(100).random_shuffle(num_blocks=1).take(999)
    r2 = range(100).random_shuffle(num_blocks=1).take(999)
    assert r1 != r2, (r1, r2)
    # TODO(swang): fix this
    if not use_push_based_shuffle:
        if not pipelined:
            assert range(100).random_shuffle(num_blocks=1).num_blocks() == 1
        r1 = range(100).random_shuffle(num_blocks=1).take(999)
        r2 = range(100).random_shuffle(num_blocks=1).take(999)
        assert r1 != r2, (r1, r2)

    r0 = range(100, parallelism=5).take(999)
    r1 = range(100, parallelism=5).random_shuffle(seed=0).take(999)
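
Both shuffle stages branch on `DatasetContext.use_push_based_shuffle`, and the TODO above tracks that push-based shuffle cannot honor `num_blocks` yet (the `NotImplementedError` guard in `RandomShuffleStage`). A sketch of flipping the flag, using only the `DatasetContext` API visible in this diff:

import ray
from ray.data.context import DatasetContext

ctx = DatasetContext.get_current()
ctx.use_push_based_shuffle = True  # route shuffles through PushBasedShufflePartitionOp

ds = ray.data.range(100).random_shuffle(seed=0)     # OK
# ray.data.range(100).random_shuffle(num_blocks=1)  # would raise NotImplementedError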