[data] Refactor all to all op implementations into a separate file (#26585)

Eric Liang authored 2022-07-15 18:17:48 -07:00; committed by GitHub
parent fea94dc976
commit cf980c3020
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 218 additions and 157 deletions
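
In outline, this commit moves the inline AllToAllStage bodies out of Dataset.repartition(), Dataset.random_shuffle(), Dataset.sort(), and Dataset.zip() (and the RandomizeBlocksStage class out of plan.py) into a new ray.data._internal.stage_impl module, so each Dataset method reduces to pushing one stage object onto its execution plan. Public behavior is unchanged; a smoke-test sketch of the affected APIs (assuming a Ray build of this vintage):

import ray

# Each call below now routes through a dedicated stage class internally:
# RepartitionStage, RandomShuffleStage, and SortStage respectively.
ds = ray.data.range(100)
ds = ds.repartition(10)
ds = ds.random_shuffle(seed=0)
ds = ds.sort()
print(ds.take(5))  # -> [0, 1, 2, 3, 4]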

python/ray/data/_internal/plan.py

@@ -194,6 +194,8 @@ class ExecutionPlan:
         Returns:
             The schema of the output dataset.
         """
+        from ray.data._internal.stage_impl import RandomizeBlocksStage
+
         if self._stages_after_snapshot:
             if fetch_if_missing:
                 if isinstance(self._stages_after_snapshot[-1], RandomizeBlocksStage):
@@ -396,6 +398,8 @@ class ExecutionPlan:
     def is_read_stage_equivalent(self) -> bool:
         """Return whether this plan can be executed as only a read stage."""
+        from ray.data._internal.stage_impl import RandomizeBlocksStage
+
         context = DatasetContext.get_current()
         remaining_stages = self._stages_after_snapshot
         if (
@@ -712,20 +716,6 @@ class AllToAllStage(Stage):
         return blocks, stage_info
 
 
-class RandomizeBlocksStage(AllToAllStage):
-    def __init__(self, seed: Optional[int]):
-        self._seed = seed
-
-        super().__init__("randomize_block_order", None, self.do_randomize)
-
-    def do_randomize(self, block_list, *_):
-        num_blocks = block_list.initial_num_blocks()
-        if num_blocks == 0:
-            return block_list, {}
-        randomized_block_list = block_list.randomize_block_order(self._seed)
-        return randomized_block_list, {}
-
-
 def _rewrite_read_stages(
     blocks: BlockList,
     stats: DatasetStats,
@@ -758,6 +748,8 @@ def _rewrite_read_stage(
         Non-lazy block list containing read tasks for not-yet-read block partitions,
         new stats for the block list, and the new list of stages.
     """
+    from ray.data._internal.stage_impl import RandomizeBlocksStage
+
     # Generate the "GetReadTasks" stage blocks.
     remote_args = in_blocks._remote_args
     blocks, metadata = [], []
@@ -798,6 +790,7 @@ def _reorder_stages(stages: List[Stage]) -> List[Stage]:
     Returns:
         Reordered stages.
     """
+    from ray.data._internal.stage_impl import RandomizeBlocksStage
 
     output: List[Stage] = []
     reorder_buf: List[RandomizeBlocksStage] = []
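
A note on the three function-local imports added above: stage_impl imports AllToAllStage from plan, so a top-level import of RandomizeBlocksStage back into plan.py would create an import cycle; importing inside the functions that need it defers resolution to call time. A generic two-module sketch of the pattern (module and class names here are illustrative, not from the Ray codebase):

# plan_mod.py -- stands in for ray/data/_internal/plan.py
class AllToAllStage:
    pass

def reorder(stages):
    # Deferred import: stage_mod imports AllToAllStage from this module
    # at load time, so importing stage_mod here, at call time, is safe.
    from stage_mod import RandomizeStage
    return [s for s in stages if not isinstance(s, RandomizeStage)]

# stage_mod.py -- stands in for ray/data/_internal/stage_impl.py
from plan_mod import AllToAllStage

class RandomizeStage(AllToAllStage):
    pass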

python/ray/data/_internal/stage_impl.py (new file)

@@ -0,0 +1,194 @@
+from typing import Optional, TYPE_CHECKING
+
+import ray
+from ray.data._internal.fast_repartition import fast_repartition
+from ray.data._internal.plan import AllToAllStage
+from ray.data._internal.shuffle_and_partition import (
+    PushBasedShufflePartitionOp,
+    SimpleShufflePartitionOp,
+)
+from ray.data._internal.block_list import BlockList
+from ray.data._internal.remote_fn import cached_remote_fn
+from ray.data._internal.sort import sort_impl
+from ray.data.context import DatasetContext
+from ray.data.block import (
+    _validate_key_fn,
+    Block,
+    KeyFn,
+    BlockMetadata,
+    BlockAccessor,
+    BlockExecStats,
+)
+
+if TYPE_CHECKING:
+    from ray.data import Dataset
+
+
+class RepartitionStage(AllToAllStage):
+    """Implementation of `Dataset.repartition()`."""
+
+    def __init__(self, num_blocks: int, shuffle: bool):
+        if shuffle:
+
+            def do_shuffle(
+                block_list, clear_input_blocks: bool, block_udf, remote_args
+            ):
+                if clear_input_blocks:
+                    blocks = block_list.copy()
+                    block_list.clear()
+                else:
+                    blocks = block_list
+                context = DatasetContext.get_current()
+                if context.use_push_based_shuffle:
+                    shuffle_op_cls = PushBasedShufflePartitionOp
+                else:
+                    shuffle_op_cls = SimpleShufflePartitionOp
+                shuffle_op = shuffle_op_cls(block_udf, random_shuffle=False)
+                return shuffle_op.execute(
+                    blocks,
+                    num_blocks,
+                    clear_input_blocks,
+                    map_ray_remote_args=remote_args,
+                    reduce_ray_remote_args=remote_args,
+                )
+
+            super().__init__(
+                "repartition", num_blocks, do_shuffle, supports_block_udf=True
+            )
+
+        else:
+
+            def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
+                if clear_input_blocks:
+                    blocks = block_list.copy()
+                    block_list.clear()
+                else:
+                    blocks = block_list
+
+                return fast_repartition(blocks, num_blocks)
+
+            super().__init__("repartition", num_blocks, do_fast_repartition)
+
+
+class RandomizeBlocksStage(AllToAllStage):
+    """Implementation of `Dataset.randomize_block_order()`."""
+
+    def __init__(self, seed: Optional[int]):
+        self._seed = seed
+
+        super().__init__("randomize_block_order", None, self.do_randomize)
+
+    def do_randomize(self, block_list, *_):
+        num_blocks = block_list.initial_num_blocks()
+        if num_blocks == 0:
+            return block_list, {}
+        randomized_block_list = block_list.randomize_block_order(self._seed)
+        return randomized_block_list, {}
+
+
+class RandomShuffleStage(AllToAllStage):
+    """Implementation of `Dataset.random_shuffle()`."""
+
+    def __init__(self, seed: Optional[int], output_num_blocks: Optional[int]):
+        def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
+            num_blocks = block_list.executed_num_blocks()  # Blocking.
+            if num_blocks == 0:
+                return block_list, {}
+            if clear_input_blocks:
+                blocks = block_list.copy()
+                block_list.clear()
+            else:
+                blocks = block_list
+            context = DatasetContext.get_current()
+            if context.use_push_based_shuffle:
+                if output_num_blocks is not None:
+                    raise NotImplementedError(
+                        "Push-based shuffle doesn't support setting num_blocks yet."
+                    )
+                shuffle_op_cls = PushBasedShufflePartitionOp
+            else:
+                shuffle_op_cls = SimpleShufflePartitionOp
+            random_shuffle_op = shuffle_op_cls(
+                block_udf, random_shuffle=True, random_seed=seed
+            )
+            return random_shuffle_op.execute(
+                blocks,
+                output_num_blocks or num_blocks,
+                clear_input_blocks,
+                map_ray_remote_args=remote_args,
+                reduce_ray_remote_args=remote_args,
+            )
+
+        super().__init__(
+            "random_shuffle", output_num_blocks, do_shuffle, supports_block_udf=True
+        )
+
+
+class ZipStage(AllToAllStage):
+    """Implementation of `Dataset.zip()`."""
+
+    def __init__(self, other: "Dataset"):
+        def do_zip_all(block_list, clear_input_blocks: bool, *_):
+            blocks1 = block_list.get_blocks()
+            blocks2 = other.get_internal_block_refs()
+
+            if clear_input_blocks:
+                block_list.clear()
+
+            if len(blocks1) != len(blocks2):
+                # TODO(ekl) consider supporting if num_rows are equal.
+                raise ValueError(
+                    "Cannot zip dataset of different num blocks: {} vs {}".format(
+                        len(blocks1), len(blocks2)
+                    )
+                )
+
+            def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
+                stats = BlockExecStats.builder()
+                b1 = BlockAccessor.for_block(block1)
+                result = b1.zip(block2)
+                br = BlockAccessor.for_block(result)
+                return result, br.get_metadata(input_files=[], exec_stats=stats.build())
+
+            do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
+
+            blocks = []
+            metadata = []
+            for b1, b2 in zip(blocks1, blocks2):
+                res, meta = do_zip_fn.remote(b1, b2)
+                blocks.append(res)
+                metadata.append(meta)
+
+            # Early release memory.
+            del blocks1, blocks2
+
+            # TODO(ekl) it might be nice to have a progress bar here.
+            metadata = ray.get(metadata)
+            blocks = BlockList(blocks, metadata)
+            return blocks, {}
+
+        super().__init__("zip", None, do_zip_all)
+
+
+class SortStage(AllToAllStage):
+    """Implementation of `Dataset.sort()`."""
+
+    def __init__(self, ds: "Dataset", key: Optional[KeyFn], descending: bool):
+        def do_sort(block_list, clear_input_blocks: bool, *_):
+            # Handle empty dataset.
+            if block_list.initial_num_blocks() == 0:
+                return block_list, {}
+            if clear_input_blocks:
+                blocks = block_list.copy()
+                block_list.clear()
+            else:
+                blocks = block_list
+            if isinstance(key, list):
+                if not key:
+                    raise ValueError("`key` must be a list of non-zero length")
+                for subkey in key:
+                    _validate_key_fn(ds, subkey)
+            else:
+                _validate_key_fn(ds, key)
+            return sort_impl(blocks, clear_input_blocks, key, descending)
+
+        super().__init__("sort", None, do_sort)

python/ray/data/dataset.py

@@ -32,22 +32,21 @@ from ray.data._internal.compute import (
     TaskPoolStrategy,
 )
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
-from ray.data._internal.fast_repartition import fast_repartition
 from ray.data._internal.lazy_block_list import LazyBlockList
 from ray.data._internal.output_buffer import BlockOutputBuffer
 from ray.data._internal.plan import (
-    AllToAllStage,
     ExecutionPlan,
     OneToOneStage,
+)
+from ray.data._internal.stage_impl import (
     RandomizeBlocksStage,
+    RepartitionStage,
+    RandomShuffleStage,
+    ZipStage,
+    SortStage,
 )
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
-from ray.data._internal.shuffle_and_partition import (
-    PushBasedShufflePartitionOp,
-    SimpleShufflePartitionOp,
-)
-from ray.data._internal.sort import sort_impl
 from ray.data._internal.stats import DatasetStats
 from ray.data._internal.table_block import VALUE_COL_NAME
 from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
@@ -723,50 +722,7 @@ class Dataset(Generic[T]):
             The repartitioned dataset.
         """
-        if shuffle:
-
-            def do_shuffle(
-                block_list, clear_input_blocks: bool, block_udf, remote_args
-            ):
-                if clear_input_blocks:
-                    blocks = block_list.copy()
-                    block_list.clear()
-                else:
-                    blocks = block_list
-                context = DatasetContext.get_current()
-                if context.use_push_based_shuffle:
-                    shuffle_op_cls = PushBasedShufflePartitionOp
-                else:
-                    shuffle_op_cls = SimpleShufflePartitionOp
-                shuffle_op = shuffle_op_cls(block_udf, random_shuffle=False)
-                return shuffle_op.execute(
-                    blocks,
-                    num_blocks,
-                    clear_input_blocks,
-                    map_ray_remote_args=remote_args,
-                    reduce_ray_remote_args=remote_args,
-                )
-
-            plan = self._plan.with_stage(
-                AllToAllStage(
-                    "repartition", num_blocks, do_shuffle, supports_block_udf=True
-                )
-            )
-        else:
-
-            def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
-                if clear_input_blocks:
-                    blocks = block_list.copy()
-                    block_list.clear()
-                else:
-                    blocks = block_list
-
-                return fast_repartition(blocks, num_blocks)
-
-            plan = self._plan.with_stage(
-                AllToAllStage("repartition", num_blocks, do_fast_repartition)
-            )
+        plan = self._plan.with_stage(RepartitionStage(num_blocks, shuffle))
         return Dataset(plan, self._epoch, self._lazy)
 
     def random_shuffle(
@@ -799,36 +755,7 @@ class Dataset(Generic[T]):
             The shuffled dataset.
         """
-
-        def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
-            num_blocks = block_list.executed_num_blocks()  # Blocking.
-            if num_blocks == 0:
-                return block_list, {}
-            if clear_input_blocks:
-                blocks = block_list.copy()
-                block_list.clear()
-            else:
-                blocks = block_list
-            context = DatasetContext.get_current()
-            if context.use_push_based_shuffle:
-                shuffle_op_cls = PushBasedShufflePartitionOp
-            else:
-                shuffle_op_cls = SimpleShufflePartitionOp
-            random_shuffle_op = shuffle_op_cls(
-                block_udf, random_shuffle=True, random_seed=seed
-            )
-            return random_shuffle_op.execute(
-                blocks,
-                num_blocks,
-                clear_input_blocks,
-                map_ray_remote_args=remote_args,
-                reduce_ray_remote_args=remote_args,
-            )
-
-        plan = self._plan.with_stage(
-            AllToAllStage(
-                "random_shuffle", num_blocks, do_shuffle, supports_block_udf=True
-            )
-        )
+        plan = self._plan.with_stage(RandomShuffleStage(seed, num_blocks))
         return Dataset(plan, self._epoch, self._lazy)
 
     def randomize_block_order(
@@ -1834,25 +1761,7 @@ class Dataset(Generic[T]):
             A new, sorted dataset.
         """
-
-        def do_sort(block_list, clear_input_blocks: bool, *_):
-            # Handle empty dataset.
-            if block_list.initial_num_blocks() == 0:
-                return block_list, {}
-            if clear_input_blocks:
-                blocks = block_list.copy()
-                block_list.clear()
-            else:
-                blocks = block_list
-            if isinstance(key, list):
-                if not key:
-                    raise ValueError("`key` must be a list of non-zero length")
-                for subkey in key:
-                    _validate_key_fn(self, subkey)
-            else:
-                _validate_key_fn(self, key)
-            return sort_impl(blocks, clear_input_blocks, key, descending)
-
-        plan = self._plan.with_stage(AllToAllStage("sort", None, do_sort))
+        plan = self._plan.with_stage(SortStage(self, key, descending))
         return Dataset(plan, self._epoch, self._lazy)
 
     def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
@@ -1882,46 +1791,7 @@ class Dataset(Generic[T]):
             comes from the first dataset and v comes from the second.
         """
-
-        def do_zip_all(block_list, clear_input_blocks: bool, *_):
-            blocks1 = block_list.get_blocks()
-            blocks2 = other.get_internal_block_refs()
-
-            if clear_input_blocks:
-                block_list.clear()
-
-            if len(blocks1) != len(blocks2):
-                # TODO(ekl) consider supporting if num_rows are equal.
-                raise ValueError(
-                    "Cannot zip dataset of different num blocks: {} vs {}".format(
-                        len(blocks1), len(blocks2)
-                    )
-                )
-
-            def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
-                stats = BlockExecStats.builder()
-                b1 = BlockAccessor.for_block(block1)
-                result = b1.zip(block2)
-                br = BlockAccessor.for_block(result)
-                return result, br.get_metadata(input_files=[], exec_stats=stats.build())
-
-            do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
-
-            blocks = []
-            metadata = []
-            for b1, b2 in zip(blocks1, blocks2):
-                res, meta = do_zip_fn.remote(b1, b2)
-                blocks.append(res)
-                metadata.append(meta)
-
-            # Early release memory.
-            del blocks1, blocks2
-
-            # TODO(ekl) it might be nice to have a progress bar here.
-            metadata = ray.get(metadata)
-            blocks = BlockList(blocks, metadata)
-            return blocks, {}
-
-        plan = self._plan.with_stage(AllToAllStage("zip", None, do_zip_all))
+        plan = self._plan.with_stage(ZipStage(other))
         return Dataset(plan, self._epoch, self._lazy)
 
     def limit(self, limit: int) -> "Dataset[T]":
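
All five removed bodies follow the same contract, now captured by the stage classes: an AllToAllStage wraps a callable taking (block_list, clear_input_blocks, block_udf, remote_args) and returning a pair of (block list, stage-info dict). A minimal custom stage under that contract (NoOpStage is illustrative, not part of this commit):

from ray.data._internal.plan import AllToAllStage

class NoOpStage(AllToAllStage):
    """Illustrative stage that passes its input blocks through unchanged."""

    def __init__(self):
        def do_noop(block_list, clear_input_blocks: bool, *_):
            # Return the block list plus a stage-info dict, matching what
            # do_randomize, do_sort, and do_zip_all return above.
            return block_list, {}

        super().__init__("noop", None, do_noop)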

python/ray/data/tests/test_dataset.py

@@ -3975,9 +3975,13 @@ def test_random_shuffle(shutdown_only, pipelined, use_push_based_shuffle):
     r2 = range(100, parallelism=1).random_shuffle().take(999)
     assert r1 != r2, (r1, r2)
 
-    r1 = range(100).random_shuffle(num_blocks=1).take(999)
-    r2 = range(100).random_shuffle(num_blocks=1).take(999)
-    assert r1 != r2, (r1, r2)
+    # TODO(swang): fix this
+    if not use_push_based_shuffle:
+        if not pipelined:
+            assert range(100).random_shuffle(num_blocks=1).num_blocks() == 1
+        r1 = range(100).random_shuffle(num_blocks=1).take(999)
+        r2 = range(100).random_shuffle(num_blocks=1).take(999)
+        assert r1 != r2, (r1, r2)
 
     r0 = range(100, parallelism=5).take(999)
     r1 = range(100, parallelism=5).random_shuffle(seed=0).take(999)
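
The new guard mirrors the NotImplementedError added in RandomShuffleStage: with use_push_based_shuffle enabled, passing num_blocks raises once the shuffle runs, so the strengthened assertions only apply to the simple-shuffle path. A sketch of the failure mode the test now skips (assumed behavior under this commit, where datasets execute eagerly):

import ray
from ray.data.context import DatasetContext

DatasetContext.get_current().use_push_based_shuffle = True
# Expected to raise NotImplementedError("Push-based shuffle doesn't
# support setting num_blocks yet.") when the shuffle executes:
ray.data.range(100).random_shuffle(num_blocks=1)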