[data] Stage fusion optimizations, off by default (#22373)

This PR adds the following stage fusion optimizations (off by default). In a later PR, I plan to enable them by default for DatasetPipelines.
- Stage fusion: Whether to fuse compatible OneToOne stages.
- Read stage fusion: Whether to fuse read stages into downstream OneToOne stages. This is accomplished by rewriting the read stage (LazyBlockList) into a transformation over a collection of read tasks (BlockList -> MapBatches(do_read)); a sketch follows this list.
- Shuffle stage fusion: Whether to fuse compatible OneToOne stages into shuffle stages that support specifying a map-side block UDF.
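
As a sketch of the read-stage rewrite (simplified from _rewrite_read_stage in plan.py below; the do_read name is illustrative):

    def do_read(block):
        # Each input "block" wraps a single read task; executing the
        # task yields the actual data blocks.
        [read_task] = block
        yield from read_task._read_fn()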

Stages are considered compatible if their compute strategy is the same ("tasks" vs. "actors") and they have the same Ray remote args. Currently, the PR ignores the remote args of read tasks; this will be fixed in a follow-up (I didn't want to change the read task defaults here).
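
The new flags live on DatasetContext (see context.py below). A minimal sketch of opting in, matching what the new tests do:

    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    ctx.optimize_fuse_stages = True  # master switch; off by default in this PR
    ctx.optimize_fuse_read_stages = True  # fuse the read stage into downstream map stages
    ctx.optimize_fuse_shuffle_stages = True  # fuse map stages into shuffle map tasks
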
Eric Liang · 2022-02-16 21:08:27 -08:00 · commit 786c5759de
14 changed files with 505 additions and 49 deletions


@@ -18,6 +18,17 @@ DEFAULT_BLOCK_SPLITTING_ENABLED = False
 # TODO (kfstorm): Remove this once stable.
 DEFAULT_ENABLE_PANDAS_BLOCK = True

+# Whether to enable stage-fusion optimizations for dataset pipelines.
+# TODO(ekl): enable this by default when ready.
+DEFAULT_OPTIMIZE_FUSE_STAGES = False
+
+# Whether to furthermore fuse read stages. When this is enabled, data will also be
+# re-read from the base dataset in each repetition of a DatasetPipeline.
+DEFAULT_OPTIMIZE_FUSE_READ_STAGES = True
+
+# Whether to furthermore fuse prior map tasks with shuffle stages.
+DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES = True
+

 @DeveloperAPI
 class DatasetContext:
@@ -33,12 +44,22 @@ class DatasetContext:
         block_splitting_enabled: bool,
         target_max_block_size: int,
         enable_pandas_block: bool,
+        optimize_fuse_stages: bool,
+        optimize_fuse_read_stages: bool,
+        optimize_fuse_shuffle_stages: bool,
     ):
         """Private constructor (use get_current() instead)."""
         self.block_owner = block_owner
         self.block_splitting_enabled = block_splitting_enabled
         self.target_max_block_size = target_max_block_size
         self.enable_pandas_block = enable_pandas_block
+        self.optimize_fuse_stages = optimize_fuse_stages
+        self.optimize_fuse_read_stages = (
+            optimize_fuse_stages and optimize_fuse_read_stages
+        )
+        self.optimize_fuse_shuffle_stages = (
+            optimize_fuse_stages and optimize_fuse_shuffle_stages
+        )

     @staticmethod
     def get_current() -> "DatasetContext":
@@ -57,6 +78,9 @@ class DatasetContext:
             block_splitting_enabled=DEFAULT_BLOCK_SPLITTING_ENABLED,
             target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE,
             enable_pandas_block=DEFAULT_ENABLE_PANDAS_BLOCK,
+            optimize_fuse_stages=DEFAULT_OPTIMIZE_FUSE_STAGES,
+            optimize_fuse_read_stages=DEFAULT_OPTIMIZE_FUSE_READ_STAGES,
+            optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
         )
         if (


@@ -2498,29 +2498,59 @@ Dict[str, List[str]]]): The names of the columns
         """
         from ray.data.dataset_pipeline import DatasetPipeline

+        # If optimizations are enabled, rewrite the read stage into a OneToOneStage
+        # to enable fusion with downstream map stages.
+        ctx = DatasetContext.get_current()
+        if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            blocks, read_stage = self._plan._rewrite_read_stage()
+            outer_stats = DatasetStats(stages={}, parent=None)
+        else:
+            blocks = self._plan.execute()
+            read_stage = None
+            outer_stats = self._plan.stats()
+
         if times is not None and times < 1:
             raise ValueError("`times` must be >= 1, got {}".format(times))
+        uuid = self._get_uuid()

         class Iterator:
-            def __init__(self, ds: "Dataset[T]"):
-                self._ds = ds
+            def __init__(self, blocks):
+                self._blocks = blocks
                 self._i = 0

             def __next__(self) -> "Dataset[T]":
                 if times and self._i >= times:
                     raise StopIteration
-                self._ds._set_epoch(self._i)
+                epoch = self._i
+                blocks = self._blocks
                 self._i += 1
-                return lambda: self._ds.fully_executed()
+
+                def gen():
+                    ds = Dataset(
+                        ExecutionPlan(blocks, outer_stats, dataset_uuid=uuid),
+                        epoch,
+                        lazy=False,
+                    )
+                    ds._set_uuid(uuid)
+                    return ds
+
+                return gen

         class Iterable:
-            def __init__(self, ds: "Dataset[T]"):
-                self._ds = ds
+            def __init__(self, blocks):
+                self._blocks = blocks

             def __iter__(self):
-                return Iterator(self._ds)
+                return Iterator(self._blocks)

-        return DatasetPipeline(Iterable(self), length=times or float("inf"))
+        pipe = DatasetPipeline(Iterable(blocks), length=times or float("inf"))
+        if read_stage:
+            pipe = pipe.foreach_window(
+                lambda ds, read_stage=read_stage: Dataset(
+                    ds._plan.with_stage(read_stage), ds._epoch, True
+                )
+            )
+        return pipe

     def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]":
         raise DeprecationWarning(
@@ -2576,8 +2606,16 @@ Dict[str, List[str]]]): The names of the columns
         """
        from ray.data.dataset_pipeline import DatasetPipeline

-        blocks = self._plan.execute()
-        outer_stats = self._plan.stats()
+        # If optimizations are enabled, rewrite the read stage into a OneToOneStage
+        # to enable fusion with downstream map stages.
+        ctx = DatasetContext.get_current()
+        if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            blocks, read_stage = self._plan._rewrite_read_stage()
+            outer_stats = DatasetStats(stages={}, parent=None)
+        else:
+            blocks = self._plan.execute()
+            read_stage = None
+            outer_stats = self._plan.stats()

         class Iterator:
             def __init__(self, splits, epoch):
@@ -2607,7 +2645,14 @@ Dict[str, List[str]]]): The names of the columns
                 return Iterator(self._splits, self._epoch)

         it = Iterable(blocks, self._epoch)
-        return DatasetPipeline(it, length=len(it._splits))
+        pipe = DatasetPipeline(it, length=len(it._splits))
+        if read_stage:
+            pipe = pipe.foreach_window(
+                lambda ds, read_stage=read_stage: Dataset(
+                    ds._plan.with_stage(read_stage), ds._epoch, True
+                )
+            )
+        return pipe

     @DeveloperAPI
     def get_internal_block_refs(self) -> List[ObjectRef[Block]]:


@@ -1,4 +1,5 @@
 import inspect
 import time
 from typing import (
     Any,
@@ -21,7 +22,9 @@ from ray.data.impl.pipeline_executor import (
 )
 from ray.data.row import TableRow
 from ray.data.impl import progress_bar
-from ray.data.impl.stats import DatasetPipelineStats
+from ray.data.impl.block_list import BlockList
+from ray.data.impl.plan import ExecutionPlan
+from ray.data.impl.stats import DatasetPipelineStats, DatasetStats
 from ray.util.annotations import PublicAPI, DeveloperAPI

 if TYPE_CHECKING:
@@ -80,6 +83,7 @@ class DatasetPipeline(Generic[T]):
         """
         self._base_iterable = base_iterable
         self._stages = stages or []
+        self._optimized_stages = None
         self._length = length
         self._progress_bars = progress_bars
         self._uuid = None  # For testing only.
@@ -610,6 +614,7 @@ class DatasetPipeline(Generic[T]):
         if self._executed[0]:
             raise RuntimeError("Pipeline cannot be read multiple times.")
         self._executed[0] = True
+        self._optimize_stages()
         return PipelineExecutor(self)

     @DeveloperAPI
@@ -681,6 +686,31 @@ class DatasetPipeline(Generic[T]):
     def _set_uuid(self, uuid: str) -> None:
         self._uuid = uuid

+    def _optimize_stages(self):
+        """Optimize this pipeline, fusing stages together as possible."""
+        context = DatasetContext.get_current()
+
+        if not context.optimize_fuse_stages:
+            self._optimized_stages = self._stages
+            return
+
+        dummy_ds = Dataset(
+            ExecutionPlan(BlockList([], []), DatasetStats(stages={}, parent=None)),
+            0,
+            True,
+        )
+        for stage in self._stages:
+            dummy_ds = stage(dummy_ds)
+        dummy_ds._plan._optimize()
+        optimized_stages = []
+        for stage in dummy_ds._plan._stages:
+            optimized_stages.append(
+                lambda ds, stage=stage: Dataset(
+                    ds._plan.with_stage(stage), ds._epoch, True
+                )
+            )
+        self._optimized_stages = optimized_stages
+

 for method in PER_DATASET_OPS:


@@ -132,6 +132,7 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
         open_stream_args: Optional[Dict[str, Any]] = None,
+        # TODO(ekl) deprecate this once read fusion is available.
         _block_udf: Optional[Callable[[Block], Block]] = None,
         **reader_args,
     ) -> List[ReadTask]:


@@ -51,20 +51,13 @@ class LazyBlockList(BlockList):
         )

     def clear(self) -> None:
-        self._block_partitions = None
+        self._block_partitions = [None for _ in self._block_partitions]
+        # TODO(ekl) we might also want to clear this in some cases.
+        # self._calls = None

     def _check_if_cleared(self) -> None:
-        if self._block_partitions is None:
-            raise ValueError(
-                "This Dataset's blocks have been moved, which means that you "
-                "can no longer use this Dataset."
-            )
+        pass  # LazyBlockList can always be re-computed.

     # Note: does not force execution prior to splitting.
     def split(self, split_size: int) -> List["LazyBlockList"]:
-        self._check_if_cleared()
         num_splits = math.ceil(len(self._calls) / split_size)
         calls = np.array_split(self._calls, num_splits)
         meta = np.array_split(self._metadata, num_splits)
@@ -76,7 +69,6 @@ class LazyBlockList(BlockList):

     # Note: does not force execution prior to division.
     def divide(self, part_idx: int) -> ("LazyBlockList", "LazyBlockList"):
-        self._check_if_cleared()
         left = LazyBlockList(
             self._calls[:part_idx],
             self._metadata[:part_idx],
@@ -98,7 +90,6 @@ class LazyBlockList(BlockList):
         self,
     ) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
         context = DatasetContext.get_current()
-        self._check_if_cleared()
         outer = self

         class Iter:
@@ -126,7 +117,6 @@ class LazyBlockList(BlockList):
     def _iter_block_partitions(
         self,
     ) -> Iterator[Tuple[ObjectRef[MaybeBlockPartition], BlockPartitionMetadata]]:
-        self._check_if_cleared()
         outer = self

         class Iter:
@@ -148,7 +138,6 @@ class LazyBlockList(BlockList):
         return Iter()

     def _get_or_compute(self, i: int) -> ObjectRef[MaybeBlockPartition]:
-        self._check_if_cleared()
         assert i < len(self._calls), i
         # Check if we need to compute more block_partitions.
         if not self._block_partitions[i]:


@@ -31,7 +31,7 @@ class PipelineExecutor:
     def __init__(self, pipeline: "DatasetPipeline[T]"):
         self._pipeline: "DatasetPipeline[T]" = pipeline
         self._stages: List[ObjectRef[Dataset[Any]]] = [None] * (
-            len(self._pipeline._stages) + 1
+            len(self._pipeline._optimized_stages) + 1
         )
         self._stage_runners = [_StageRunner.remote() for _ in self._stages]
         self._iter = iter(self._pipeline._base_iterable)
@@ -86,7 +86,7 @@ class PipelineExecutor:
             if is_last:
                 output = result
             else:
-                fn = self._pipeline._stages[i]
+                fn = self._pipeline._optimized_stages[i]
                 self._stages[i + 1] = self._stage_runners[i].run.remote(
                     lambda: fn(result), DatasetContext.get_current()
                 )
@@ -115,6 +115,7 @@ class PipelineSplitExecutorCoordinator:
         context: DatasetContext,
     ):
         DatasetContext._set_current(context)
+        pipeline._optimize_stages()
         self.executor = PipelineExecutor(pipeline)
         self.n = n
         self.splitter = splitter


@@ -1,27 +1,46 @@
-from ray.data.block import Block
-from ray.data.impl.block_list import BlockList
-from ray.data.impl.compute import get_compute
-from ray.data.impl.stats import DatasetStats
-from typing import Callable, Tuple, Optional, Union, TYPE_CHECKING
+from typing import Callable, Tuple, Optional, Union, Iterable, TYPE_CHECKING
 import uuid

 if TYPE_CHECKING:
     import pyarrow

+import ray
+from ray.data.context import DatasetContext
+from ray.data.block import Block
+from ray.data.impl.block_list import BlockList
+from ray.data.impl.compute import get_compute
+from ray.data.impl.stats import DatasetStats
+from ray.data.impl.lazy_block_list import LazyBlockList
+

 class ExecutionPlan:
-    def __init__(self, in_blocks: BlockList, stats: DatasetStats):
+    """A lazy execution plan for a Dataset."""
+
+    def __init__(self, in_blocks: BlockList, stats: DatasetStats, dataset_uuid=None):
+        """Create a plan with no transformation stages.
+
+        Args:
+            in_blocks: Base list of blocks.
+            stats: Stats for the base blocks.
+        """
         self._in_blocks = in_blocks
         self._out_blocks = None
         self._in_stats = stats
         self._out_stats = None
         self._stages = []
-        self._dataset_uuid = uuid.uuid4().hex
+        self._dataset_uuid = dataset_uuid or uuid.uuid4().hex
         if not stats.dataset_uuid:
             stats.dataset_uuid = self._dataset_uuid

-    def with_stage(self, stage: "Stage"):
+    def with_stage(self, stage: "Stage") -> "ExecutionPlan":
+        """Return a copy of this plan with the given stage appended.
+
+        Args:
+            stage: The stage to append.
+
+        Returns:
+            A new ExecutionPlan with this stage appended.
+        """
         if self._out_blocks:
             copy = ExecutionPlan(self._out_blocks, self._out_stats)
             copy._stages = [stage]
@@ -32,6 +51,7 @@ class ExecutionPlan:
         return copy

     def initial_num_blocks(self) -> int:
+        """Get the estimated number of blocks after applying all plan stages."""
         if self._out_blocks:
             return self._out_blocks.initial_num_blocks()
         for stage in self._stages[::-1]:
@@ -42,6 +62,14 @@ class ExecutionPlan:
     def schema(
         self, fetch_if_missing: bool = False
     ) -> Union[type, "pyarrow.lib.Schema"]:
+        """Get the schema after applying all plan stages.
+
+        Args:
+            fetch_if_missing: Whether to execute the plan to fetch the schema.
+
+        Returns:
+            The schema of the output dataset.
+        """
         if self._stages:
             if fetch_if_missing:
                 self.execute()
@@ -60,6 +88,13 @@ class ExecutionPlan:
         return blocks.ensure_schema_for_first_block()

     def meta_count(self) -> Optional[int]:
+        """Get the number of rows after applying all plan stages if possible.
+
+        This method will never trigger any computation.
+
+        Returns:
+            The number of records of the result Dataset, or None.
+        """
         if self._stages:
             blocks = self._out_blocks
         else:
@@ -71,11 +106,21 @@ class ExecutionPlan:
         return None

     def execute(self, clear_input_blocks: bool = True) -> BlockList:
+        """Execute this plan.
+
+        Args:
+            clear_input_blocks: Whether to assume ownership of the input blocks,
+                allowing them to be dropped from memory during execution.
+
+        Returns:
+            The blocks of the output dataset.
+        """
         # TODO: add optimizations:
         # 1. task fusion of OneToOne
         # 2. task fusion of OneToOne to AlltoAll
         # 3. clear input blocks
         if self._out_blocks is None:
+            self._optimize()
             blocks = self._in_blocks
             stats = self._in_stats
             for stage in self._stages:
@@ -92,15 +137,91 @@ class ExecutionPlan:
         return self._out_blocks

     def clear(self) -> None:
+        """Clear all cached block references of this plan, including input blocks.
+
+        This will render the plan un-executable unless the root is a LazyBlockList."""
+        self._in_blocks.clear()
         self._out_blocks = None
         self._out_stats = None

     def stats(self) -> DatasetStats:
+        """Return stats for this plan, forcing execution if needed."""
         self.execute()
         return self._out_stats

+    def _optimize(self) -> None:
+        """Apply stage fusion optimizations, updating this plan."""
+        context = DatasetContext.get_current()
+        if context.optimize_fuse_stages:
+            if context.optimize_fuse_read_stages:
+                self._rewrite_read_stages()
+            self._fuse_one_to_one_stages()
+
+    def _rewrite_read_stages(self) -> None:
+        """Rewrites read stages into one-to-one stages."""
+        if self._stages and self._has_read_stage():
+            block_list, stage = self._rewrite_read_stage()
+            self._in_blocks = block_list
+            self._in_stats = DatasetStats(stages={}, parent=None)
+            self._stages.insert(0, stage)
+
+    def _has_read_stage(self) -> bool:
+        """Whether this plan has a read stage for its input."""
+        return isinstance(self._in_blocks, LazyBlockList) and hasattr(
+            self._in_blocks, "_read_tasks"
+        )
+
+    def _is_read_stage(self) -> bool:
+        """Whether this plan is a bare read stage."""
+        return self._has_read_stage() and not self._stages
+
+    def _rewrite_read_stage(self) -> Tuple[BlockList, "Stage"]:
+        """Rewrite the read stage to a OneToOne stage over read tasks as input.
+
+        For example, suppose the plan was [Read -> MapBatches(Fn)]. These stages cannot
+        be fused, since read stages are handled specially.
+
+        After rewriting to [GetReadTasks -> MapBatches(DoRead) -> MapBatches(Fn)],
+        now we can fuse the latter two MapBatches stages into a single OneToOne stage:
+        [GetReadTasks -> MapBatches(DoRead -> Fn)].
+        """
+        # Generate the "GetReadTasks" stage blocks.
+        blocks = []
+        metadata = []
+        for i, read_task in enumerate(self._in_blocks._read_tasks):
+            blocks.append(ray.put([read_task]))
+            metadata.append(self._in_blocks._metadata[i])
+        block_list = BlockList(blocks, metadata)
+
+        def block_fn(block: Block) -> Iterable[Block]:
+            [read_task] = block
+            for tmp1 in read_task._read_fn():
+                yield tmp1
+
+        # TODO(ekl): add num_cpus properly here and make the read default num_cpus=1.
+        return block_list, OneToOneStage("read", block_fn, None, {})
+
+    def _fuse_one_to_one_stages(self) -> None:
+        """Fuses compatible one-to-one stages."""
+        optimized_stages = []
+        prev_stage = None
+        for stage in self._stages:
+            if prev_stage is None:
+                prev_stage = stage
+            elif stage.can_fuse(prev_stage):
+                prev_stage = stage.fuse(prev_stage)
+            else:
+                optimized_stages.append(prev_stage)
+                prev_stage = stage
+        if prev_stage:
+            optimized_stages.append(prev_stage)
+            prev_stage = None
+        self._stages = optimized_stages
+

 class Stage:
+    """Represents a Dataset transform stage (e.g., map or shuffle)."""
+
     def __init__(self, name: str, num_blocks: Optional[int]):
         self.name = name
         self.num_blocks = num_blocks
@@ -108,10 +229,21 @@ class Stage:
     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool
     ) -> Tuple[BlockList, dict]:
+        """Execute this stage against the given blocks."""
+        raise NotImplementedError
+
+    def can_fuse(self, other: "Stage") -> bool:
+        """Return whether this can be fused with another stage."""
+        raise NotImplementedError
+
+    def fuse(self, other: "Stage") -> "Stage":
+        """Fuse this stage with a compatible stage."""
         raise NotImplementedError


 class OneToOneStage(Stage):
+    """A stage that transforms blocks independently (e.g., map or filter)."""
+
     def __init__(
         self,
         name: str,
@@ -121,9 +253,30 @@ class OneToOneStage(Stage):
     ):
         super().__init__(name, None)
         self.block_fn = block_fn
-        self.compute = compute
+        self.compute = compute or "tasks"
         self.ray_remote_args = ray_remote_args

+    def can_fuse(self, prev: Stage):
+        if not isinstance(prev, OneToOneStage):
+            return False
+        if prev.compute != self.compute:
+            return False
+        if prev.ray_remote_args != self.ray_remote_args:
+            return False
+        return True
+
+    def fuse(self, prev: Stage):
+        name = prev.name + "->" + self.name
+        fn1 = prev.block_fn
+        fn2 = self.block_fn
+
+        def block_fn(block: Block) -> Iterable[Block]:
+            for tmp1 in fn1(block):
+                for tmp2 in fn2(tmp1):
+                    yield tmp2
+
+        return OneToOneStage(name, block_fn, self.compute, self.ray_remote_args)
+
     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool
     ) -> Tuple[BlockList, dict]:
@@ -136,6 +289,8 @@ class OneToOneStage(Stage):

 class AllToAllStage(Stage):
+    """A stage that transforms blocks holistically (e.g., shuffle)."""
+
     def __init__(
         self,
         name: str,
@@ -149,6 +304,26 @@ class AllToAllStage(Stage):
         self.supports_block_udf = supports_block_udf
         self.block_udf = block_udf

+    def can_fuse(self, prev: Stage):
+        context = DatasetContext.get_current()
+        # TODO(ekl) also support fusing shuffle stages to subsequent 1:1 stages.
+        if not context.optimize_fuse_shuffle_stages:
+            return False
+        if not self.supports_block_udf:
+            return False
+        if not isinstance(prev, OneToOneStage):
+            return False
+        if prev.compute != "tasks":
+            return False
+        if prev.ray_remote_args:
+            return False
+        return True
+
+    def fuse(self, prev: Stage):
+        assert self.supports_block_udf
+        name = prev.name + "->" + self.name
+        return AllToAllStage(name, self.num_blocks, self.fn, True, prev.block_fn)
+
     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool
     ) -> Tuple[BlockList, dict]:


@@ -114,14 +114,14 @@ def _shuffle_map(
     stats = BlockExecStats.builder()
     if block_udf:
         # TODO(ekl) note that this effectively disables block splitting.
-        pieces = list(block_udf(block))
-        if len(pieces) > 1:
-            builder = BlockAccessor.for_block(pieces[0]).builder()
-            for p in pieces:
-                builder.add_block(p)
+        blocks = list(block_udf(block))
+        if len(blocks) > 1:
+            builder = BlockAccessor.for_block(blocks[0]).builder()
+            for b in blocks:
+                builder.add_block(b)
             block = builder.build()
         else:
-            block = pieces[0]
+            block = blocks[0]
     block = BlockAccessor.for_block(block)

     # Randomize the distribution of records to blocks.


@@ -116,6 +116,10 @@ def sort_impl(
     for j in range(num_reducers):
         ret = merge_sorted_blocks.remote(key, descending, *map_results[:, j].tolist())
         reduce_results.append(ret)
+
+    # Early release memory.
+    del map_results
+
     merge_bar = ProgressBar("Sort Merge", len(reduce_results))
     merge_bar.block_until_complete([ret[0] for ret in reduce_results])
     merge_bar.close()


@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import List, Optional, Dict, Set, Tuple, Union
+from typing import List, Optional, Set, Dict, Tuple, Union
 import time
 import collections
 import numpy as np
@@ -54,8 +54,14 @@ class _DatasetStatsBuilder:
     def build_multistage(
         self, stages: Dict[str, List[BlockMetadata]]
     ) -> "DatasetStats":
+        stage_infos = {}
+        for i, (k, v) in enumerate(stages.items()):
+            if i == 0:
+                stage_infos[self.stage_name + "_" + k] = v
+            else:
+                stage_infos[self.stage_name.split("->")[-1] + "_" + k] = v
         stats = DatasetStats(
-            stages={self.stage_name + "_" + k: v for k, v in stages.items()},
+            stages=stage_infos,
             parent=self.parent,
         )
         stats.time_total_s = time.perf_counter() - self.start_time


@@ -277,6 +277,8 @@ def read_datasource(
         metadata.append(task.get_metadata())

     block_list = LazyBlockList(calls, metadata)
+    # TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
+    block_list._read_tasks = read_tasks

     # Get the schema from the first block synchronously.
     if metadata and metadata[0].schema is None:


@@ -3186,9 +3186,7 @@ def test_random_shuffle(shutdown_only, pipelined):
         with pytest.raises(RuntimeError):
             ds = ds.map(lambda x: x).take(999)
     else:
-        # Source dataset should be unusable if not pipelining.
-        with pytest.raises(ValueError):
-            ds = ds.map(lambda x: x).take(999)
+        ds = ds.map(lambda x: x).take(999)

     r2 = range(100).random_shuffle(_move=True).take(999)
     assert r1 != r2, (r1, r2)


@@ -0,0 +1,182 @@
+import pytest
+import pandas as pd
+import os
+
+import ray
+from ray.data.context import DatasetContext
+from ray.data.datasource.csv_datasource import CSVDatasource
+
+from ray.tests.conftest import *  # noqa
+
+
+def expect_stages(pipe, num_stages_expected, stage_names):
+    stats = pipe.stats()
+    for name in stage_names:
+        name = " " + name + ":"
+        assert name in stats, (name, stats)
+    assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages
+
+
+def test_optimize_fuse(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+
+    def build_pipe():
+        pipe = ray.data.range(3).repeat(2)
+        pipe = pipe.map_batches(lambda x: x)
+        pipe = pipe.map_batches(lambda x: x)
+        pipe = pipe.random_shuffle_each_window()
+        results = [sorted(p.take()) for p in pipe.iter_epochs()]
+        assert results == [[0, 1, 2], [0, 1, 2]], results
+        return pipe
+
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = True
+    context.optimize_fuse_shuffle_stages = True
+    expect_stages(
+        build_pipe(),
+        1,
+        ["read->map_batches->map_batches->random_shuffle_map", "random_shuffle_reduce"],
+    )
+
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = False
+    context.optimize_fuse_shuffle_stages = True
+    expect_stages(
+        build_pipe(),
+        1,
+        [
+            "read",
+            "map_batches->map_batches->random_shuffle_map",
+            "random_shuffle_reduce",
+        ],
+    )
+
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = False
+    context.optimize_fuse_shuffle_stages = False
+    expect_stages(
+        build_pipe(),
+        2,
+        [
+            "read",
+            "map_batches->map_batches",
+            "random_shuffle_map",
+            "random_shuffle_reduce",
+        ],
+    )
+
+    context.optimize_fuse_stages = False
+    context.optimize_fuse_read_stages = False
+    context.optimize_fuse_shuffle_stages = False
+    expect_stages(
+        build_pipe(),
+        3,
+        [
+            "read",
+            "map_batches",
+            "map_batches",
+            "random_shuffle_map",
+            "random_shuffle_reduce",
+        ],
+    )
+
+
+def test_optimize_incompatible_stages(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = True
+    context.optimize_fuse_shuffle_stages = True
+
+    pipe = ray.data.range(3).repeat(2)
+    pipe = pipe.map_batches(lambda x: x, compute="actors")
+    pipe = pipe.map_batches(lambda x: x, compute="tasks")
+    pipe = pipe.random_shuffle_each_window()
+    pipe.take()
+    expect_stages(
+        pipe,
+        3,
+        [
+            "read",
+            "map_batches",
+            "map_batches->random_shuffle_map",
+            "random_shuffle_reduce",
+        ],
+    )
+
+    pipe = ray.data.range(3).repeat(2)
+    pipe = pipe.map_batches(lambda x: x, compute="tasks")
+    pipe = pipe.map_batches(lambda x: x, num_cpus=0.75)
+    pipe = pipe.random_shuffle_each_window()
+    pipe.take()
+    expect_stages(
+        pipe,
+        3,
+        [
+            "read->map_batches",
+            "map_batches",
+            "random_shuffle_map",
+            "random_shuffle_reduce",
+        ],
+    )
+
+
+@ray.remote
+class Counter:
+    def __init__(self):
+        self.value = 0
+
+    def increment(self):
+        self.value += 1
+        return self.value
+
+    def get(self):
+        return self.value
+
+    def reset(self):
+        self.value = 0
+
+
+class MySource(CSVDatasource):
+    def __init__(self, counter):
+        self.counter = counter
+
+    def _read_stream(self, f, path: str, **reader_args):
+        count = self.counter.increment.remote()
+        ray.get(count)
+        for block in CSVDatasource._read_stream(self, f, path, **reader_args):
+            yield block
+
+
+def test_optimize_reread_base_data(ray_start_regular_shared, local_path):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = True
+    context.optimize_fuse_shuffle_stages = True
+
+    # Re-read on.
+    N = 4
+    df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
+    path1 = os.path.join(local_path, "test1.csv")
+    df1.to_csv(path1, index=False, storage_options={})
+    counter = Counter.remote()
+    source = MySource(counter)
+    ds1 = ray.data.read_datasource(source, parallelism=1, paths=path1)
+    pipe = ds1.repeat(N)
+    pipe.take()
+    num_reads = ray.get(counter.get.remote())
+    assert num_reads == N + 1, num_reads
+
+    # Re-read off.
+    context.optimize_fuse_read_stages = False
+    ray.get(counter.reset.remote())
+    ds1 = ray.data.read_datasource(source, parallelism=1, paths=path1)
+    pipe = ds1.repeat(N)
+    pipe.take()
+    num_reads = ray.get(counter.get.remote())
+    assert num_reads == 1, num_reads
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))


@@ -139,7 +139,6 @@ def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
     ds.write_parquet(str(tmp_path))
     ds = ray.data.read_parquet(str(tmp_path)).map(lambda x: x)
     stats = canonicalize(ds.stats())
-    print(stats)
     assert (
         stats
         == """Stage Z read: N/N blocks executed in T