[data] Stage fusion optimizations, off by default (#22373)

This PR adds the following stage fusion optimizations (off by default). In a later PR, I plan to enable them by default for DatasetPipelines.
- Stage fusion: Whether to fuse compatible OneToOne stages.
- Read stage fusion: Whether to fuse read stages into downstream OneToOne stages. This is accomplished by rewriting the read stage (LazyBlockList) into a transformation over a collection of read tasks (BlockList -> MapBatches(do_read)).
- Shuffle stage fusion: Whether to fuse compatible OneToOne stages into shuffle stages that support specifying a map-side block UDF.

Stages are considered compatible if they use the same compute strategy ("tasks" vs. "actors") and the same Ray remote args. Currently, the PR ignores the remote args of read tasks; this will be addressed in a follow-up (I didn't want to change the read task defaults here).
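
For reference, a minimal usage sketch (mirroring `test_optimize_fuse` added in this PR) showing how the new flags are toggled on `DatasetContext` and what the fused stages look like:

```python
import ray
from ray.data.context import DatasetContext

# Stage fusion is off by default; the read/shuffle sub-flags only take
# effect when optimize_fuse_stages itself is enabled.
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = True
context.optimize_fuse_shuffle_stages = True

# read -> map_batches -> map_batches -> random_shuffle, repeated twice.
pipe = ray.data.range(3).repeat(2)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.random_shuffle_each_window()
pipe.take()

# With all three flags on, the stats report a single fused map stage,
# "read->map_batches->map_batches->random_shuffle_map", followed by
# "random_shuffle_reduce" (see test_optimize_fuse below).
print(pipe.stats())
```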
Eric Liang committed 2022-02-16 21:08:27 -08:00 (commit 786c5759de, parent e10a2fbcf9)
14 changed files with 505 additions and 49 deletions


@@ -18,6 +18,17 @@ DEFAULT_BLOCK_SPLITTING_ENABLED = False
# TODO (kfstorm): Remove this once stable.
DEFAULT_ENABLE_PANDAS_BLOCK = True
# Whether to enable stage-fusion optimizations for dataset pipelines.
# TODO(ekl): enable this by default when ready.
DEFAULT_OPTIMIZE_FUSE_STAGES = False
# Whether to further fuse read stages. When this is enabled, data will also be
# re-read from the base dataset in each repetition of a DatasetPipeline.
DEFAULT_OPTIMIZE_FUSE_READ_STAGES = True
# Whether to further fuse prior map tasks with shuffle stages.
DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES = True
@DeveloperAPI
class DatasetContext:
@@ -33,12 +44,22 @@ class DatasetContext:
block_splitting_enabled: bool,
target_max_block_size: int,
enable_pandas_block: bool,
optimize_fuse_stages: bool,
optimize_fuse_read_stages: bool,
optimize_fuse_shuffle_stages: bool,
):
"""Private constructor (use get_current() instead)."""
self.block_owner = block_owner
self.block_splitting_enabled = block_splitting_enabled
self.target_max_block_size = target_max_block_size
self.enable_pandas_block = enable_pandas_block
self.optimize_fuse_stages = optimize_fuse_stages
self.optimize_fuse_read_stages = (
optimize_fuse_stages and optimize_fuse_read_stages
)
self.optimize_fuse_shuffle_stages = (
optimize_fuse_stages and optimize_fuse_shuffle_stages
)
@staticmethod
def get_current() -> "DatasetContext":
@@ -57,6 +78,9 @@ class DatasetContext:
block_splitting_enabled=DEFAULT_BLOCK_SPLITTING_ENABLED,
target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE,
enable_pandas_block=DEFAULT_ENABLE_PANDAS_BLOCK,
optimize_fuse_stages=DEFAULT_OPTIMIZE_FUSE_STAGES,
optimize_fuse_read_stages=DEFAULT_OPTIMIZE_FUSE_READ_STAGES,
optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
)
if (


@@ -2498,29 +2498,59 @@ Dict[str, List[str]]]): The names of the columns
"""
from ray.data.dataset_pipeline import DatasetPipeline
# If optimizations are enabled, rewrite the read stage into a OneToOneStage
# to enable fusion with downstream map stages.
ctx = DatasetContext.get_current()
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
blocks, read_stage = self._plan._rewrite_read_stage()
outer_stats = DatasetStats(stages={}, parent=None)
else:
blocks = self._plan.execute()
read_stage = None
outer_stats = self._plan.stats()
if times is not None and times < 1:
raise ValueError("`times` must be >= 1, got {}".format(times))
uuid = self._get_uuid()
class Iterator:
def __init__(self, ds: "Dataset[T]"):
self._ds = ds
def __init__(self, blocks):
self._blocks = blocks
self._i = 0
def __next__(self) -> "Dataset[T]":
if times and self._i >= times:
raise StopIteration
self._ds._set_epoch(self._i)
epoch = self._i
blocks = self._blocks
self._i += 1
return lambda: self._ds.fully_executed()
def gen():
ds = Dataset(
ExecutionPlan(blocks, outer_stats, dataset_uuid=uuid),
epoch,
lazy=False,
)
ds._set_uuid(uuid)
return ds
return gen
class Iterable:
def __init__(self, ds: "Dataset[T]"):
self._ds = ds
def __init__(self, blocks):
self._blocks = blocks
def __iter__(self):
return Iterator(self._ds)
return Iterator(self._blocks)
return DatasetPipeline(Iterable(self), length=times or float("inf"))
pipe = DatasetPipeline(Iterable(blocks), length=times or float("inf"))
if read_stage:
pipe = pipe.foreach_window(
lambda ds, read_stage=read_stage: Dataset(
ds._plan.with_stage(read_stage), ds._epoch, True
)
)
return pipe
def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]":
raise DeprecationWarning(
@@ -2576,8 +2606,16 @@ Dict[str, List[str]]]): The names of the columns
"""
from ray.data.dataset_pipeline import DatasetPipeline
blocks = self._plan.execute()
outer_stats = self._plan.stats()
# If optimizations are enabled, rewrite the read stage into a OneToOneStage
# to enable fusion with downstream map stages.
ctx = DatasetContext.get_current()
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
blocks, read_stage = self._plan._rewrite_read_stage()
outer_stats = DatasetStats(stages={}, parent=None)
else:
blocks = self._plan.execute()
read_stage = None
outer_stats = self._plan.stats()
class Iterator:
def __init__(self, splits, epoch):
@@ -2607,7 +2645,14 @@ Dict[str, List[str]]]): The names of the columns
return Iterator(self._splits, self._epoch)
it = Iterable(blocks, self._epoch)
return DatasetPipeline(it, length=len(it._splits))
pipe = DatasetPipeline(it, length=len(it._splits))
if read_stage:
pipe = pipe.foreach_window(
lambda ds, read_stage=read_stage: Dataset(
ds._plan.with_stage(read_stage), ds._epoch, True
)
)
return pipe
@DeveloperAPI
def get_internal_block_refs(self) -> List[ObjectRef[Block]]:


@@ -1,4 +1,5 @@
import inspect
import time
from typing import (
Any,
@@ -21,7 +22,9 @@ from ray.data.impl.pipeline_executor import (
)
from ray.data.row import TableRow
from ray.data.impl import progress_bar
from ray.data.impl.stats import DatasetPipelineStats
from ray.data.impl.block_list import BlockList
from ray.data.impl.plan import ExecutionPlan
from ray.data.impl.stats import DatasetPipelineStats, DatasetStats
from ray.util.annotations import PublicAPI, DeveloperAPI
if TYPE_CHECKING:
@@ -80,6 +83,7 @@ class DatasetPipeline(Generic[T]):
"""
self._base_iterable = base_iterable
self._stages = stages or []
self._optimized_stages = None
self._length = length
self._progress_bars = progress_bars
self._uuid = None # For testing only.
@@ -610,6 +614,7 @@ class DatasetPipeline(Generic[T]):
if self._executed[0]:
raise RuntimeError("Pipeline cannot be read multiple times.")
self._executed[0] = True
self._optimize_stages()
return PipelineExecutor(self)
@DeveloperAPI
@@ -681,6 +686,31 @@ class DatasetPipeline(Generic[T]):
def _set_uuid(self, uuid: str) -> None:
self._uuid = uuid
def _optimize_stages(self):
"""Optimize this pipeline, fusing stages together as possible."""
context = DatasetContext.get_current()
if not context.optimize_fuse_stages:
self._optimized_stages = self._stages
return
dummy_ds = Dataset(
ExecutionPlan(BlockList([], []), DatasetStats(stages={}, parent=None)),
0,
True,
)
for stage in self._stages:
dummy_ds = stage(dummy_ds)
dummy_ds._plan._optimize()
optimized_stages = []
for stage in dummy_ds._plan._stages:
optimized_stages.append(
lambda ds, stage=stage: Dataset(
ds._plan.with_stage(stage), ds._epoch, True
)
)
self._optimized_stages = optimized_stages
for method in PER_DATASET_OPS:


@@ -132,6 +132,7 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
open_stream_args: Optional[Dict[str, Any]] = None,
# TODO(ekl) deprecate this once read fusion is available.
_block_udf: Optional[Callable[[Block], Block]] = None,
**reader_args,
) -> List[ReadTask]:


@@ -51,20 +51,13 @@ class LazyBlockList(BlockList):
)
def clear(self) -> None:
self._block_partitions = None
# TODO(ekl) we might also want to clear this in some cases.
# self._calls = None
self._block_partitions = [None for _ in self._block_partitions]
def _check_if_cleared(self) -> None:
if self._block_partitions is None:
raise ValueError(
"This Dataset's blocks have been moved, which means that you "
"can no longer use this Dataset."
)
pass # LazyBlockList can always be re-computed.
# Note: does not force execution prior to splitting.
def split(self, split_size: int) -> List["LazyBlockList"]:
self._check_if_cleared()
num_splits = math.ceil(len(self._calls) / split_size)
calls = np.array_split(self._calls, num_splits)
meta = np.array_split(self._metadata, num_splits)
@@ -76,7 +69,6 @@ class LazyBlockList(BlockList):
# Note: does not force execution prior to division.
def divide(self, part_idx: int) -> ("LazyBlockList", "LazyBlockList"):
self._check_if_cleared()
left = LazyBlockList(
self._calls[:part_idx],
self._metadata[:part_idx],
@@ -98,7 +90,6 @@ class LazyBlockList(BlockList):
self,
) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
context = DatasetContext.get_current()
self._check_if_cleared()
outer = self
class Iter:
@@ -126,7 +117,6 @@ class LazyBlockList(BlockList):
def _iter_block_partitions(
self,
) -> Iterator[Tuple[ObjectRef[MaybeBlockPartition], BlockPartitionMetadata]]:
self._check_if_cleared()
outer = self
class Iter:
@@ -148,7 +138,6 @@ class LazyBlockList(BlockList):
return Iter()
def _get_or_compute(self, i: int) -> ObjectRef[MaybeBlockPartition]:
self._check_if_cleared()
assert i < len(self._calls), i
# Check if we need to compute more block_partitions.
if not self._block_partitions[i]:


@@ -31,7 +31,7 @@ class PipelineExecutor:
def __init__(self, pipeline: "DatasetPipeline[T]"):
self._pipeline: "DatasetPipeline[T]" = pipeline
self._stages: List[ObjectRef[Dataset[Any]]] = [None] * (
len(self._pipeline._stages) + 1
len(self._pipeline._optimized_stages) + 1
)
self._stage_runners = [_StageRunner.remote() for _ in self._stages]
self._iter = iter(self._pipeline._base_iterable)
@@ -86,7 +86,7 @@ class PipelineExecutor:
if is_last:
output = result
else:
fn = self._pipeline._stages[i]
fn = self._pipeline._optimized_stages[i]
self._stages[i + 1] = self._stage_runners[i].run.remote(
lambda: fn(result), DatasetContext.get_current()
)
@@ -115,6 +115,7 @@ class PipelineSplitExecutorCoordinator:
context: DatasetContext,
):
DatasetContext._set_current(context)
pipeline._optimize_stages()
self.executor = PipelineExecutor(pipeline)
self.n = n
self.splitter = splitter


@@ -1,27 +1,46 @@
from ray.data.block import Block
from ray.data.impl.block_list import BlockList
from ray.data.impl.compute import get_compute
from ray.data.impl.stats import DatasetStats
from typing import Callable, Tuple, Optional, Union, TYPE_CHECKING
from typing import Callable, Tuple, Optional, Union, Iterable, TYPE_CHECKING
import uuid
if TYPE_CHECKING:
import pyarrow
import ray
from ray.data.context import DatasetContext
from ray.data.block import Block
from ray.data.impl.block_list import BlockList
from ray.data.impl.compute import get_compute
from ray.data.impl.stats import DatasetStats
from ray.data.impl.lazy_block_list import LazyBlockList
class ExecutionPlan:
def __init__(self, in_blocks: BlockList, stats: DatasetStats):
"""A lazy execution plan for a Dataset."""
def __init__(self, in_blocks: BlockList, stats: DatasetStats, dataset_uuid=None):
"""Create a plan with no transformation stages.
Args:
in_blocks: Base list of blocks.
stats: Stats for the base blocks.
"""
self._in_blocks = in_blocks
self._out_blocks = None
self._in_stats = stats
self._out_stats = None
self._stages = []
self._dataset_uuid = uuid.uuid4().hex
self._dataset_uuid = dataset_uuid or uuid.uuid4().hex
if not stats.dataset_uuid:
stats.dataset_uuid = self._dataset_uuid
def with_stage(self, stage: "Stage"):
def with_stage(self, stage: "Stage") -> "ExecutionPlan":
"""Return a copy of this plan with the given stage appended.
Args:
stage: The stage to append.
Returns:
A new ExecutionPlan with this stage appended.
"""
if self._out_blocks:
copy = ExecutionPlan(self._out_blocks, self._out_stats)
copy._stages = [stage]
@@ -32,6 +51,7 @@ class ExecutionPlan:
return copy
def initial_num_blocks(self) -> int:
"""Get the estimated number of blocks after applying all plan stages."""
if self._out_blocks:
return self._out_blocks.initial_num_blocks()
for stage in self._stages[::-1]:
@@ -42,6 +62,14 @@ class ExecutionPlan:
def schema(
self, fetch_if_missing: bool = False
) -> Union[type, "pyarrow.lib.Schema"]:
"""Get the schema after applying all plan stages.
Args:
fetch_if_missing: Whether to execute the plan to fetch the schema.
Returns:
The schema of the output dataset.
"""
if self._stages:
if fetch_if_missing:
self.execute()
@@ -60,6 +88,13 @@ class ExecutionPlan:
return blocks.ensure_schema_for_first_block()
def meta_count(self) -> Optional[int]:
"""Get the number of rows after applying all plan stages if possible.
This method will never trigger any computation.
Returns:
The number of records of the result Dataset, or None.
"""
if self._stages:
blocks = self._out_blocks
else:
@@ -71,11 +106,21 @@ class ExecutionPlan:
return None
def execute(self, clear_input_blocks: bool = True) -> BlockList:
"""Execute this plan.
Args:
clear_input_blocks: Whether to assume ownership of the input blocks,
allowing them to be dropped from memory during execution.
Returns:
The blocks of the output dataset.
"""
# TODO: add optimizations:
# 1. task fusion of OneToOne
# 2. task fusion of OneToOne to AlltoAll
# 3. clear input blocks
if self._out_blocks is None:
self._optimize()
blocks = self._in_blocks
stats = self._in_stats
for stage in self._stages:
@@ -92,15 +137,91 @@ class ExecutionPlan:
return self._out_blocks
def clear(self) -> None:
"""Clear all cached block references of this plan, including input blocks.
This will render the plan un-executable unless the root is a LazyBlockList."""
self._in_blocks.clear()
self._out_blocks = None
self._out_stats = None
def stats(self) -> DatasetStats:
"""Return stats for this plan, forcing execution if needed."""
self.execute()
return self._out_stats
def _optimize(self) -> None:
"""Apply stage fusion optimizations, updating this plan."""
context = DatasetContext.get_current()
if context.optimize_fuse_stages:
if context.optimize_fuse_read_stages:
self._rewrite_read_stages()
self._fuse_one_to_one_stages()
def _rewrite_read_stages(self) -> None:
"""Rewrites read stages into one-to-one stages."""
if self._stages and self._has_read_stage():
block_list, stage = self._rewrite_read_stage()
self._in_blocks = block_list
self._in_stats = DatasetStats(stages={}, parent=None)
self._stages.insert(0, stage)
def _has_read_stage(self) -> bool:
"""Whether this plan has a read stage for its input."""
return isinstance(self._in_blocks, LazyBlockList) and hasattr(
self._in_blocks, "_read_tasks"
)
def _is_read_stage(self) -> bool:
"""Whether this plan is a bare read stage."""
return self._has_read_stage() and not self._stages
def _rewrite_read_stage(self) -> Tuple[BlockList, "Stage"]:
"""Rewrite the read stage to a OneToOne stage over read tasks as input.
For example, suppose the plan was [Read -> MapBatches(Fn)]. These stages cannot
be fused, since read stages are handled specially.
After rewriting to [GetReadTasks -> MapBatches(DoRead) -> MapBatches(Fn)],
now we can fuse the latter two MapBatches stages into a single OneToOne stage:
[GetReadTasks -> MapBatches(DoRead -> Fn)].
"""
# Generate the "GetReadTasks" stage blocks.
blocks = []
metadata = []
for i, read_task in enumerate(self._in_blocks._read_tasks):
blocks.append(ray.put([read_task]))
metadata.append(self._in_blocks._metadata[i])
block_list = BlockList(blocks, metadata)
def block_fn(block: Block) -> Iterable[Block]:
[read_task] = block
for tmp1 in read_task._read_fn():
yield tmp1
# TODO(ekl): add num_cpus properly here and make the read default num_cpus=1.
return block_list, OneToOneStage("read", block_fn, None, {})
def _fuse_one_to_one_stages(self) -> None:
"""Fuses compatible one-to-one stages."""
optimized_stages = []
prev_stage = None
for stage in self._stages:
if prev_stage is None:
prev_stage = stage
elif stage.can_fuse(prev_stage):
prev_stage = stage.fuse(prev_stage)
else:
optimized_stages.append(prev_stage)
prev_stage = stage
if prev_stage:
optimized_stages.append(prev_stage)
prev_stage = None
self._stages = optimized_stages
class Stage:
"""Represents a Dataset transform stage (e.g., map or shuffle)."""
def __init__(self, name: str, num_blocks: Optional[int]):
self.name = name
self.num_blocks = num_blocks
@@ -108,10 +229,21 @@ class Stage:
def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:
"""Execute this stage against the given blocks."""
raise NotImplementedError
def can_fuse(self, other: "Stage") -> bool:
"""Return whether this can be fused with another stage."""
raise NotImplementedError
def fuse(self, other: "Stage") -> "Stage":
"""Fuse this stage with a compatible stage."""
raise NotImplementedError
class OneToOneStage(Stage):
"""A stage that transforms blocks independently (e.g., map or filter)."""
def __init__(
self,
name: str,
@@ -121,9 +253,30 @@ class OneToOneStage(Stage):
):
super().__init__(name, None)
self.block_fn = block_fn
self.compute = compute
self.compute = compute or "tasks"
self.ray_remote_args = ray_remote_args
def can_fuse(self, prev: Stage):
if not isinstance(prev, OneToOneStage):
return False
if prev.compute != self.compute:
return False
if prev.ray_remote_args != self.ray_remote_args:
return False
return True
def fuse(self, prev: Stage):
name = prev.name + "->" + self.name
fn1 = prev.block_fn
fn2 = self.block_fn
def block_fn(block: Block) -> Iterable[Block]:
for tmp1 in fn1(block):
for tmp2 in fn2(tmp1):
yield tmp2
return OneToOneStage(name, block_fn, self.compute, self.ray_remote_args)
def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:
@@ -136,6 +289,8 @@ class OneToOneStage(Stage):
class AllToAllStage(Stage):
"""A stage that transforms blocks holistically (e.g., shuffle)."""
def __init__(
self,
name: str,
@@ -149,6 +304,26 @@ class AllToAllStage(Stage):
self.supports_block_udf = supports_block_udf
self.block_udf = block_udf
def can_fuse(self, prev: Stage):
context = DatasetContext.get_current()
# TODO(ekl) also support fusing shuffle stages to subsequent 1:1 stages.
if not context.optimize_fuse_shuffle_stages:
return False
if not self.supports_block_udf:
return False
if not isinstance(prev, OneToOneStage):
return False
if prev.compute != "tasks":
return False
if prev.ray_remote_args:
return False
return True
def fuse(self, prev: Stage):
assert self.supports_block_udf
name = prev.name + "->" + self.name
return AllToAllStage(name, self.num_blocks, self.fn, True, prev.block_fn)
def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:

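The fused `block_fn` built by `OneToOneStage.fuse` above is plain generator composition. A tiny standalone illustration with made-up block functions (not part of this diff):

```python
from typing import Iterable, List

def fn1(block: List[int]) -> Iterable[List[int]]:
    yield [x * 2 for x in block]

def fn2(block: List[int]) -> Iterable[List[int]]:
    yield [x + 1 for x in block]

def fused(block: List[int]) -> Iterable[List[int]]:
    # Same shape as the fused block_fn: every output block of the first
    # function is fed through the second.
    for tmp1 in fn1(block):
        for tmp2 in fn2(tmp1):
            yield tmp2

assert list(fused([1, 2, 3])) == [[3, 5, 7]]
```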

@@ -114,14 +114,14 @@ def _shuffle_map(
stats = BlockExecStats.builder()
if block_udf:
# TODO(ekl) note that this effectively disables block splitting.
pieces = list(block_udf(block))
if len(pieces) > 1:
builder = BlockAccessor.for_block(pieces[0]).builder()
for p in pieces:
builder.add_block(p)
blocks = list(block_udf(block))
if len(blocks) > 1:
builder = BlockAccessor.for_block(blocks[0]).builder()
for b in blocks:
builder.add_block(b)
block = builder.build()
else:
block = pieces[0]
block = blocks[0]
block = BlockAccessor.for_block(block)
# Randomize the distribution of records to blocks.


@@ -116,6 +116,10 @@ def sort_impl(
for j in range(num_reducers):
ret = merge_sorted_blocks.remote(key, descending, *map_results[:, j].tolist())
reduce_results.append(ret)
# Early release memory.
del map_results
merge_bar = ProgressBar("Sort Merge", len(reduce_results))
merge_bar.block_until_complete([ret[0] for ret in reduce_results])
merge_bar.close()


@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import List, Optional, Dict, Set, Tuple, Union
from typing import List, Optional, Set, Dict, Tuple, Union
import time
import collections
import numpy as np
@@ -54,8 +54,14 @@ class _DatasetStatsBuilder:
def build_multistage(
self, stages: Dict[str, List[BlockMetadata]]
) -> "DatasetStats":
stage_infos = {}
for i, (k, v) in enumerate(stages.items()):
if i == 0:
stage_infos[self.stage_name + "_" + k] = v
else:
stage_infos[self.stage_name.split("->")[-1] + "_" + k] = v
stats = DatasetStats(
stages={self.stage_name + "_" + k: v for k, v in stages.items()},
stages=stage_infos,
parent=self.parent,
)
stats.time_total_s = time.perf_counter() - self.start_time


@@ -277,6 +277,8 @@ def read_datasource(
metadata.append(task.get_metadata())
block_list = LazyBlockList(calls, metadata)
# TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
block_list._read_tasks = read_tasks
# Get the schema from the first block synchronously.
if metadata and metadata[0].schema is None:


@@ -3186,9 +3186,7 @@ def test_random_shuffle(shutdown_only, pipelined):
with pytest.raises(RuntimeError):
ds = ds.map(lambda x: x).take(999)
else:
# Source dataset should be unusable if not pipelining.
with pytest.raises(ValueError):
ds = ds.map(lambda x: x).take(999)
ds = ds.map(lambda x: x).take(999)
r2 = range(100).random_shuffle(_move=True).take(999)
assert r1 != r2, (r1, r2)


@@ -0,0 +1,182 @@
import pytest
import pandas as pd
import os
import ray
from ray.data.context import DatasetContext
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.tests.conftest import * # noqa
def expect_stages(pipe, num_stages_expected, stage_names):
stats = pipe.stats()
for name in stage_names:
name = " " + name + ":"
assert name in stats, (name, stats)
assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages
def test_optimize_fuse(ray_start_regular_shared):
context = DatasetContext.get_current()
def build_pipe():
pipe = ray.data.range(3).repeat(2)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.random_shuffle_each_window()
results = [sorted(p.take()) for p in pipe.iter_epochs()]
assert results == [[0, 1, 2], [0, 1, 2]], results
return pipe
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = True
context.optimize_fuse_shuffle_stages = True
expect_stages(
build_pipe(),
1,
["read->map_batches->map_batches->random_shuffle_map", "random_shuffle_reduce"],
)
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = False
context.optimize_fuse_shuffle_stages = True
expect_stages(
build_pipe(),
1,
[
"read",
"map_batches->map_batches->random_shuffle_map",
"random_shuffle_reduce",
],
)
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = False
context.optimize_fuse_shuffle_stages = False
expect_stages(
build_pipe(),
2,
[
"read",
"map_batches->map_batches",
"random_shuffle_map",
"random_shuffle_reduce",
],
)
context.optimize_fuse_stages = False
context.optimize_fuse_read_stages = False
context.optimize_fuse_shuffle_stages = False
expect_stages(
build_pipe(),
3,
[
"read",
"map_batches",
"map_batches",
"random_shuffle_map",
"random_shuffle_reduce",
],
)
def test_optimize_incompatible_stages(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = True
context.optimize_fuse_shuffle_stages = True
pipe = ray.data.range(3).repeat(2)
pipe = pipe.map_batches(lambda x: x, compute="actors")
pipe = pipe.map_batches(lambda x: x, compute="tasks")
pipe = pipe.random_shuffle_each_window()
pipe.take()
expect_stages(
pipe,
3,
[
"read",
"map_batches",
"map_batches->random_shuffle_map",
"random_shuffle_reduce",
],
)
pipe = ray.data.range(3).repeat(2)
pipe = pipe.map_batches(lambda x: x, compute="tasks")
pipe = pipe.map_batches(lambda x: x, num_cpus=0.75)
pipe = pipe.random_shuffle_each_window()
pipe.take()
expect_stages(
pipe,
3,
[
"read->map_batches",
"map_batches",
"random_shuffle_map",
"random_shuffle_reduce",
],
)
@ray.remote
class Counter:
def __init__(self):
self.value = 0
def increment(self):
self.value += 1
return self.value
def get(self):
return self.value
def reset(self):
self.value = 0
class MySource(CSVDatasource):
def __init__(self, counter):
self.counter = counter
def _read_stream(self, f, path: str, **reader_args):
count = self.counter.increment.remote()
ray.get(count)
for block in CSVDatasource._read_stream(self, f, path, **reader_args):
yield block
def test_optimize_reread_base_data(ray_start_regular_shared, local_path):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
context.optimize_fuse_read_stages = True
context.optimize_fuse_shuffle_stages = True
# Re-read on.
N = 4
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
path1 = os.path.join(local_path, "test1.csv")
df1.to_csv(path1, index=False, storage_options={})
counter = Counter.remote()
source = MySource(counter)
ds1 = ray.data.read_datasource(source, parallelism=1, paths=path1)
pipe = ds1.repeat(N)
pipe.take()
num_reads = ray.get(counter.get.remote())
assert num_reads == N + 1, num_reads
# Re-read off.
context.optimize_fuse_read_stages = False
ray.get(counter.reset.remote())
ds1 = ray.data.read_datasource(source, parallelism=1, paths=path1)
pipe = ds1.repeat(N)
pipe.take()
num_reads = ray.get(counter.get.remote())
assert num_reads == 1, num_reads
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))


@@ -139,7 +139,6 @@ def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
ds.write_parquet(str(tmp_path))
ds = ray.data.read_parquet(str(tmp_path)).map(lambda x: x)
stats = canonicalize(ds.stats())
print(stats)
assert (
stats
== """Stage Z read: N/N blocks executed in T