Enable stage fusion by default for dataset pipelines (#22476)

This PR enables stage fusion by default for dataset pipelines. This also requires:
1. Removing the num_cpus=0.5 default for the read stage, so that the read stage can be fused with downstream stages.
2. Removing _spread_resource_prefix (not supported for now); with fusion enabled, tasks are spread by the default SPREAD scheduling strategy instead.
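
For reference, a minimal sketch (not part of the diff) of the behavior this changes, using the DatasetContext flag and the pipeline API exercised by the updated tests below:

    import ray
    from ray.data.context import DatasetContext

    # optimize_fuse_stages now defaults to True; it can still be turned off
    # per driver if fusion needs to be disabled.
    ctx = DatasetContext.get_current()
    ctx.optimize_fuse_stages = True

    # The read is now represented as a real stage in the execution plan (so it
    # can fuse with downstream stages), which is why the updated tests expect
    # one more stage for a bare window.
    pipe = ray.data.range(10).window(blocks_per_window=1)
    print(pipe)  # DatasetPipeline(num_windows=10, num_stages=2)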
Eric Liang 2022-02-23 17:34:05 -08:00 committed by GitHub
parent a62a9c38fb
commit e15a419028
13 changed files with 164 additions and 102 deletions


@@ -19,8 +19,7 @@ DEFAULT_BLOCK_SPLITTING_ENABLED = False
 DEFAULT_ENABLE_PANDAS_BLOCK = True

 # Whether to enable stage-fusion optimizations for dataset pipelines.
-# TODO(ekl): enable this by default when ready.
-DEFAULT_OPTIMIZE_FUSE_STAGES = False
+DEFAULT_OPTIMIZE_FUSE_STAGES = True

 # Whether to furthermore fuse read stages. When this is enabled, data will also be
 # re-read from the base dataset in each repetition of a DatasetPipeline.

@@ -54,12 +53,8 @@ class DatasetContext:
         self.target_max_block_size = target_max_block_size
         self.enable_pandas_block = enable_pandas_block
         self.optimize_fuse_stages = optimize_fuse_stages
-        self.optimize_fuse_read_stages = (
-            optimize_fuse_stages and optimize_fuse_read_stages
-        )
-        self.optimize_fuse_shuffle_stages = (
-            optimize_fuse_stages and optimize_fuse_shuffle_stages
-        )
+        self.optimize_fuse_read_stages = optimize_fuse_read_stages
+        self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages

     @staticmethod
     def get_current() -> "DatasetContext":


@@ -463,13 +463,21 @@ class Dataset(Generic[T]):

         if shuffle:

-            def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
+            def do_shuffle(
+                block_list, clear_input_blocks: bool, block_udf, remote_args
+            ):
                 if clear_input_blocks:
                     blocks = block_list.copy()
                     block_list.clear()
                 else:
                     blocks = block_list
-                return simple_shuffle(blocks, block_udf, num_blocks)
+                return simple_shuffle(
+                    blocks,
+                    block_udf,
+                    num_blocks,
+                    map_ray_remote_args=remote_args,
+                    reduce_ray_remote_args=remote_args,
+                )

             plan = self._plan.with_stage(
                 AllToAllStage(

@@ -479,7 +487,7 @@
         else:

-            def do_fast_repartition(block_list, clear_input_blocks: bool, _):
+            def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
                 if clear_input_blocks:
                     blocks = block_list.copy()
                     block_list.clear()

@@ -524,7 +532,7 @@
             The shuffled dataset.
         """

-        def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
+        def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
             num_blocks = block_list.executed_num_blocks()  # Blocking.
             if num_blocks == 0:
                 return block_list, {}

@@ -540,6 +548,8 @@
                 random_shuffle=True,
                 random_seed=seed,
                 _spread_resource_prefix=_spread_resource_prefix,
+                map_ray_remote_args=remote_args,
+                reduce_ray_remote_args=remote_args,
             )
             return new_blocks, stage_info

@@ -1380,7 +1390,7 @@
             A new, sorted dataset.
         """

-        def do_sort(block_list, clear_input_blocks: bool, block_udf):
+        def do_sort(block_list, clear_input_blocks: bool, *_):
             # Handle empty dataset.
             if block_list.initial_num_blocks() == 0:
                 return block_list, {}

@@ -1424,7 +1434,7 @@
             comes from the first dataset and v comes from the second.
         """

-        def do_zip_all(block_list, clear_input_blocks: bool, block_udf):
+        def do_zip_all(block_list, clear_input_blocks: bool, *_):
             blocks1 = block_list.get_blocks()
             blocks2 = other.get_internal_block_refs()

@@ -2681,11 +2691,24 @@ Dict[str, List[str]]]): The names of the columns

         This can be used to read all blocks into memory. By default, Datasets
         doesn't read blocks from the datasource until the first transform.
+
+        Returns:
+            A Dataset with all blocks fully materialized in memory.
         """
         blocks = self.get_internal_block_refs()
         bar = ProgressBar("Force reads", len(blocks))
         bar.block_until_complete(blocks)
-        return self
+        ds = Dataset(
+            ExecutionPlan(
+                BlockList(blocks, self._plan.execute().get_metadata()),
+                self._plan.stats(),
+                dataset_uuid=self._get_uuid(),
+            ),
+            self._epoch,
+            lazy=False,
+        )
+        ds._set_uuid(self._get_uuid())
+        return ds

     @DeveloperAPI
     def stats(self) -> str:
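
A hedged usage sketch of the change above, which makes the force-read path return a new, eagerly materialized Dataset instead of `self`. The method name is not visible in this hunk; `fully_executed()` is assumed here:

    import ray

    ds = ray.data.range(1000)           # blocks are not read until needed
    materialized = ds.fully_executed()  # assumed method name; forces all reads
    # The returned Dataset wraps the already-computed blocks and keeps the
    # original dataset UUID, so iterating it again does not re-run the read stage.
    assert materialized.count() == 1000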


@@ -394,10 +394,6 @@ class DatasetPipeline(Generic[T]):
         This operation is only allowed for pipelines of a finite length. An
         error will be raised for pipelines of infinite length.

-        Transformations prior to the call to ``repeat()`` are evaluated once.
-        Transformations done on the repeated pipeline are evaluated on each
-        loop of the pipeline over the base pipeline.
-
         Note that every repeat of the pipeline is considered an "epoch" for
         the purposes of ``iter_epochs()``. If there are multiple repeat calls,
         the latest repeat takes precedence for the purpose of defining epochs.

@@ -424,10 +420,15 @@ class DatasetPipeline(Generic[T]):
                 # Still going through the original pipeline.
                 if self._original_iter:
                     try:
-                        res = next(self._original_iter)
-                        res._set_epoch(0)
-                        self._results.append(res)
-                        return lambda: res
+                        make_ds = next(self._original_iter)
+                        self._results.append(make_ds)
+
+                        def gen():
+                            res = make_ds()
+                            res._set_epoch(0)
+                            return res
+
+                        return gen
                     except StopIteration:
                         self._original_iter = None
                         # Calculate the cursor limit.

@@ -437,10 +438,16 @@ class DatasetPipeline(Generic[T]):
                             self._max_i = float("inf")

                # Going through a repeat of the pipeline.
                if self._i < self._max_i:
-                    res = self._results[self._i % len(self._results)]
-                    res._set_epoch(1 + self._i // len(self._results))
+                    make_ds = self._results[self._i % len(self._results)]
+                    epoch = 1 + self._i // len(self._results)
+
+                    def gen():
+                        res = make_ds()
+                        res._set_epoch(epoch)
+                        return res
+
                     self._i += 1
-                    return lambda: res
+                    return gen
                 else:
                     raise StopIteration

@@ -458,7 +465,11 @@ class DatasetPipeline(Generic[T]):
         else:
             length = None

-        return DatasetPipeline(RepeatIterable(self.iter_datasets()), length=length)
+        return DatasetPipeline(
+            RepeatIterable(iter(self._base_iterable)),
+            stages=self._stages.copy(),
+            length=length,
+        )

     def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
         """Return the schema of the dataset pipeline.

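The repeat() changes above hand out dataset factories rather than cached Dataset objects, so each repetition re-runs the per-window stages (including the re-read when read fusion is enabled). A small sketch of the unchanged public API, for orientation:

    import ray

    pipe = ray.data.range(5).window(blocks_per_window=1).repeat(2)
    # Every repetition counts as one epoch for iter_epochs().
    for epoch_pipe in pipe.iter_epochs():
        print(sorted(epoch_pipe.iter_rows()))  # [0, 1, 2, 3, 4], printed twice
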

@@ -56,7 +56,7 @@ class GroupedDataset(Generic[T]):
             If groupby key is ``None`` then the key part of return is omitted.
         """

-        def do_agg(blocks, clear_input_blocks: bool, block_udf):
+        def do_agg(blocks, clear_input_blocks: bool, *_):
             # TODO: implement clear_input_blocks
             stage_info = {}
             if len(aggs) == 0:


@@ -12,6 +12,9 @@ from ray.data.impl.compute import get_compute
 from ray.data.impl.stats import DatasetStats
 from ray.data.impl.lazy_block_list import LazyBlockList

+# Scheduling strategy can be inherited from prev stage if not specified.
+INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]
+

 class ExecutionPlan:
     """A lazy execution plan for a Dataset."""

@@ -115,10 +118,6 @@ class ExecutionPlan:
         Returns:
             The blocks of the output dataset.
         """
-        # TODO: add optimizations:
-        # 1. task fusion of OneToOne
-        # 2. task fusion of OneToOne to AlltoAll
-        # 3. clear input blocks
         if self._out_blocks is None:
             self._optimize()
             blocks = self._in_blocks

@@ -186,6 +185,7 @@
         [GetReadTasks -> MapBatches(DoRead -> Fn)].
         """
         # Generate the "GetReadTasks" stage blocks.
+        remote_args = self._in_blocks._read_remote_args
         blocks = []
         metadata = []
         for i, read_task in enumerate(self._in_blocks._read_tasks):

@@ -198,8 +198,7 @@
             for tmp1 in read_task._read_fn():
                 yield tmp1

-        # TODO(ekl): add num_cpus properly here and make the read default num_cpus=1.
-        return block_list, OneToOneStage("read", block_fn, None, {})
+        return block_list, OneToOneStage("read", block_fn, "tasks", remote_args)

     def _fuse_one_to_one_stages(self) -> None:
         """Fuses compatible one-to-one stages."""

@@ -254,14 +253,18 @@ class OneToOneStage(Stage):
         super().__init__(name, None)
         self.block_fn = block_fn
         self.compute = compute or "tasks"
-        self.ray_remote_args = ray_remote_args
+        self.ray_remote_args = ray_remote_args or {}

     def can_fuse(self, prev: Stage):
         if not isinstance(prev, OneToOneStage):
             return False
         if prev.compute != self.compute:
             return False
-        if prev.ray_remote_args != self.ray_remote_args:
+        for key in INHERITABLE_REMOTE_ARGS:
+            remote_args = self.ray_remote_args.copy()
+            if key in prev.ray_remote_args:
+                remote_args[key] = prev.ray_remote_args[key]
+        if prev.ray_remote_args != remote_args:
             return False
         return True

@@ -275,7 +278,7 @@
             for tmp2 in fn2(tmp1):
                 yield tmp2

-        return OneToOneStage(name, block_fn, self.compute, self.ray_remote_args)
+        return OneToOneStage(name, block_fn, prev.compute, prev.ray_remote_args)

     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool

@@ -298,11 +301,13 @@ class AllToAllStage(Stage):
         fn: Callable[[BlockList, bool, Callable], Tuple[BlockList, dict]],
         supports_block_udf: bool = False,
         block_udf=None,
+        remote_args=None,
     ):
         super().__init__(name, num_blocks)
         self.fn = fn
         self.supports_block_udf = supports_block_udf
         self.block_udf = block_udf
+        self.ray_remote_args = remote_args or {}

     def can_fuse(self, prev: Stage):
         context = DatasetContext.get_current()

@@ -315,18 +320,22 @@
             return False
         if prev.compute != "tasks":
             return False
-        if prev.ray_remote_args:
+        if any(k not in INHERITABLE_REMOTE_ARGS for k in prev.ray_remote_args):
             return False
         return True

     def fuse(self, prev: Stage):
         assert self.supports_block_udf
         name = prev.name + "->" + self.name
-        return AllToAllStage(name, self.num_blocks, self.fn, True, prev.block_fn)
+        return AllToAllStage(
+            name, self.num_blocks, self.fn, True, prev.block_fn, prev.ray_remote_args
+        )

     def __call__(
         self, blocks: BlockList, clear_input_blocks: bool
     ) -> Tuple[BlockList, dict]:
-        blocks, stage_info = self.fn(blocks, clear_input_blocks, self.block_udf)
+        blocks, stage_info = self.fn(
+            blocks, clear_input_blocks, self.block_udf, self.ray_remote_args
+        )
         assert isinstance(blocks, BlockList), blocks
         return blocks, stage_info
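
To make the fusion rule above concrete: a downstream stage can only differ from its upstream stage in the remote args listed in INHERITABLE_REMOTE_ARGS; any other difference blocks fusion. A standalone sketch of that check with plain dicts (illustration only, not Ray code):

    INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]

    def remote_args_compatible(prev_args: dict, self_args: dict) -> bool:
        # Mirror OneToOneStage.can_fuse: copy our args, inherit the allowed
        # keys from the previous stage, then require an exact match.
        merged = self_args.copy()
        for key in INHERITABLE_REMOTE_ARGS:
            if key in prev_args:
                merged[key] = prev_args[key]
        return prev_args == merged

    # A read stage scheduled with SPREAD fuses with a map stage that left
    # scheduling_strategy unset (the map inherits SPREAD)...
    assert remote_args_compatible({"scheduling_strategy": "SPREAD"}, {})
    # ...but a differing num_cpus request still prevents fusion.
    assert not remote_args_compatible({"num_cpus": 2}, {})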


@@ -83,3 +83,9 @@ class ProgressBar:

     def __del__(self):
         self.close()
+
+    def __getstate__(self):
+        return {}
+
+    def __setstate__(self, state):
+        self._bar = None  # Progress bar is disabled on remote nodes.


@@ -32,6 +32,7 @@ def simple_shuffle(
     if reduce_ray_remote_args is None:
         reduce_ray_remote_args = {}
     if "scheduling_strategy" not in reduce_ray_remote_args:
+        reduce_ray_remote_args = reduce_ray_remote_args.copy()
         reduce_ray_remote_args["scheduling_strategy"] = "SPREAD"
     input_num_blocks = len(input_blocks)
     if _spread_resource_prefix is not None:


@@ -209,8 +209,10 @@ class DatasetStats:
         out = ""
         if self.parents:
             for p in self.parents:
-                out += p.summary_string(already_printed)
-                out += "\n"
+                parent_sum = p.summary_string(already_printed)
+                if parent_sum:
+                    out += parent_sum
+                    out += "\n"
         first = True
         for stage_name, metadata in self.stages.items():
             stage_uuid = self.dataset_uuid + stage_name


@@ -259,17 +259,16 @@ def read_datasource(
     if ray_remote_args is None:
         ray_remote_args = {}
-    # Increase the read parallelism by default to maximize IO throughput. This
-    # is particularly important when reading from e.g., remote storage.
-    if "num_cpus" not in ray_remote_args:
-        # Note that the too many workers warning triggers at 4x subscription,
-        # so we go at 0.5 to avoid the warning message.
-        ray_remote_args["num_cpus"] = 0.5
     if "scheduling_strategy" not in ray_remote_args:
         ray_remote_args["scheduling_strategy"] = "SPREAD"

     remote_read = cached_remote_fn(remote_read)

     if _spread_resource_prefix is not None:
+        if context.optimize_fuse_stages:
+            logger.warning(
+                "_spread_resource_prefix has no effect when optimize_fuse_stages "
+                "is enabled. Tasks are spread by default."
+            )
         # Use given spread resource prefix for round-robin resource-based
         # scheduling.
         nodes = ray.nodes()

@@ -294,6 +293,7 @@ def read_datasource(
     block_list = LazyBlockList(calls, metadata)
     # TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
     block_list._read_tasks = read_tasks
+    block_list._read_remote_args = ray_remote_args

     # Get the schema from the first block synchronously.
     if metadata and metadata[0].schema is None:
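
Because the read stage now records its remote args on the block list, args passed at read time travel with the (possibly fused) read stage. A hedged sketch, assuming the read API accepts ray_remote_args the same way read_datasource above does; the path is illustrative only:

    import ray

    # The num_cpus=0.5 default is gone; reads use the caller's remote args,
    # plus scheduling_strategy="SPREAD" when not otherwise specified.
    ds = ray.data.read_parquet(
        "/tmp/example_parquet_dir",  # hypothetical path
        ray_remote_args={"num_cpus": 1},
    )
    # With fusion enabled, a following map can execute as a single
    # "read->map" stage that inherits these remote args.
    ds = ds.map(lambda row: row)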


@@ -6,6 +6,7 @@ import pandas as pd
 import numpy as np

 import ray
+from ray.data.context import DatasetContext
 from ray.data.dataset_pipeline import DatasetPipeline
 from ray.tests.conftest import *  # noqa

@@ -89,31 +90,35 @@ def test_cannot_read_twice(ray_start_regular_shared):

 def test_basic_pipeline(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(10)

     pipe = ds.window(blocks_per_window=1)
-    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
     assert pipe.count() == 10

     pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
-    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=4)"
     assert pipe.take() == list(range(10))

     pipe = ds.window(blocks_per_window=999)
-    assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=2)"
     assert pipe.count() == 10

     pipe = ds.repeat(10)
-    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
     assert pipe.count() == 100

     pipe = ds.repeat(10)
     assert pipe.sum() == 450


 def test_window(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(10)
     pipe = ds.window(blocks_per_window=1)
-    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
     pipe = pipe.rewindow(blocks_per_window=3)
     assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
     datasets = list(pipe.iter_datasets())

@@ -125,7 +130,7 @@ def test_window(ray_start_regular_shared):

     ds = ray.data.range(10)
     pipe = ds.window(blocks_per_window=5)
-    assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=2)"
     pipe = pipe.rewindow(blocks_per_window=3)
     assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
     datasets = list(pipe.iter_datasets())

@@ -137,17 +142,19 @@ def test_window(ray_start_regular_shared):

 def test_repeat(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(5)
     pipe = ds.window(blocks_per_window=1)
-    assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=2)"

     pipe = pipe.repeat(2)
-    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
     assert pipe.take() == (list(range(5)) + list(range(5)))

     ds = ray.data.range(5)
     pipe = ds.window(blocks_per_window=1)
     pipe = pipe.repeat()
-    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"
     assert len(pipe.take(99)) == 99

     pipe = ray.data.range(5).repeat()

@@ -163,9 +170,11 @@ def test_from_iterable(ray_start_regular_shared):

 def test_repeat_forever(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(10)
     pipe = ds.repeat()
-    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
+    assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"

     for i, v in enumerate(pipe.iter_rows()):
         assert v == i % 10, (v, i, i % 10)
         if i > 1000:

@@ -212,7 +221,7 @@ def test_schema(ray_start_regular_shared):
 def test_split(ray_start_regular_shared):
     pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)

-    @ray.remote
+    @ray.remote(num_cpus=0)
     def consume(shard, i):
         total = 0
         for row in shard.iter_rows():

@@ -230,7 +239,7 @@ def test_split_at_indices(ray_start_regular_shared):
     n = 8
     pipe = ray.data.range(n).map(lambda x: x + 1).repeat(2)

-    @ray.remote
+    @ray.remote(num_cpus=0)
     def consume(shard, i):
         total = 0
         out = []


@@ -17,11 +17,22 @@ def expect_stages(pipe, num_stages_expected, stage_names):
     assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages


+def test_spread_hint_inherit(ray_start_regular_shared):
+    ds = ray.data.range(10)._experimental_lazy()
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.random_shuffle()
+    for s in ds._plan._stages:
+        assert s.ray_remote_args == {}, s.ray_remote_args
+    ds._plan._optimize()
+    assert len(ds._plan._stages) == 1, ds._plan._stages
+    assert ds._plan._stages[0].ray_remote_args == {"scheduling_strategy": "SPREAD"}
+
+
 def test_optimize_fuse(ray_start_regular_shared):
     context = DatasetContext.get_current()

     def build_pipe():
-        pipe = ray.data.range(3).repeat(2)
+        pipe = ray.data.range(3).window(blocks_per_window=1).repeat(2)
         pipe = pipe.map_batches(lambda x: x)
         pipe = pipe.map_batches(lambda x: x)
         pipe = pipe.random_shuffle_each_window()


@@ -2,6 +2,7 @@ import pytest
 import re

 import ray
+from ray.data.context import DatasetContext
 from ray.tests.conftest import *  # noqa

@@ -16,6 +17,8 @@ def canonicalize(stats: str) -> str:


 def test_dataset_stats_basic(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(1000, parallelism=10)
     ds = ds.map_batches(lambda x: x)
     ds = ds.map(lambda x: x)

@@ -24,14 +27,7 @@ def test_dataset_stats_basic(ray_start_regular_shared):
     stats = canonicalize(ds.stats())
     assert (
         stats
-        == """Stage Z read: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N map_batches: N/N blocks executed in T
+        == """Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total

@@ -56,19 +52,14 @@ Dataset iterator time breakdown:


 def test_dataset_stats_shuffle(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(1000, parallelism=10)
     ds = ds.random_shuffle().repartition(1, shuffle=True)
     stats = canonicalize(ds.stats())
     assert (
         stats
-        == """Stage Z read: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N random_shuffle_map: N/N blocks executed in T
+        == """Stage N read->random_shuffle_map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total

@@ -135,20 +126,15 @@ def test_dataset_stats_from_items(ray_start_regular_shared):


 def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(1000, parallelism=10)
     ds.write_parquet(str(tmp_path))
     ds = ray.data.read_parquet(str(tmp_path)).map(lambda x: x)
     stats = canonicalize(ds.stats())
     assert (
         stats
-        == """Stage Z read: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N map: N/N blocks executed in T
+        == """Stage N read->map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total

@@ -159,6 +145,8 @@ Stage N map: N/N blocks executed in T


 def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(1000, parallelism=10)
     ds = ds.map_batches(lambda x: x)
     pipe = ds.repeat(5)

@@ -169,14 +157,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
     assert (
         stats
         == """== Pipeline Window N ==
-Stage Z read: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N map_batches: N/N blocks executed in T
+Stage N read->map_batches: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total

@@ -198,8 +179,7 @@ Dataset iterator time breakdown:
 * Total time: T

 == Pipeline Window N ==
-Stage Z read: [execution cached]
-Stage N map_batches: [execution cached]
+Stage N read->map_batches: [execution cached]
 Stage N map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total

@@ -215,8 +195,7 @@ Dataset iterator time breakdown:
 * Total time: T

 == Pipeline Window N ==
-Stage Z read: [execution cached]
-Stage N map_batches: [execution cached]
+Stage N read->map_batches: [execution cached]
 Stage N map: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total

@@ -241,6 +220,8 @@ Dataset iterator time breakdown:


 def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
     ds = ray.data.range(1000, parallelism=10)
     pipe = ds.repeat(2)

@@ -255,7 +236,7 @@ def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
     assert (
         canonicalize(stats[0])
         == """== Pipeline Window Z ==
-Stage Z read: N/N blocks executed in T
+Stage N read: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total

@@ -270,7 +251,7 @@ Dataset iterator time breakdown:
 * Total time: T

 == Pipeline Window N ==
-Stage Z read: N/N blocks executed in T
+Stage N read: N/N blocks executed in T
 * Remote wall time: T min, T max, T mean, T total
 * Remote cpu time: T min, T max, T mean, T total
 * Output num rows: N min, N max, N mean, N total


@@ -47,6 +47,8 @@ parser.add_argument("--num-workers", type=int, default=16)
 parser.add_argument("--mock-train-step-time", type=float, default=1.0)
 parser.add_argument("--num-files", type=int, default=30)
 parser.add_argument("--num-windows", type=int, default=1)
+parser.add_argument("--manual-windows", type=bool, default=False)
+parser.add_argument("--parallelism", type=int, default=400)

 SIZE_50_G = 30  # 49.17GB
 SIZE_100_G = 62  # 101.62GB

@@ -218,8 +220,15 @@ def create_torch_iterator(split, batch_size, rank=None):
     return torch_iterator


-def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
-    if num_windows > 1:
+def create_dataset(
+    files,
+    num_workers=4,
+    epochs=50,
+    num_windows=1,
+    manual_windowing=False,
+    parallelism=400,
+):
+    if num_windows > 1 and manual_windowing:
         num_rows = ray.data.read_parquet(
             files
         ).count()  # This should only read Parquet metadata.

@@ -247,7 +256,10 @@ def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
         pipe = pipe.random_shuffle_each_window()
         pipe_shards = pipe.split_at_indices(split_indices)
     else:
-        ds = ray.data.read_parquet(files)
+        ds = ray.data.read_parquet(files, parallelism=parallelism)
+        if num_windows > 1:
+            window_size = max(ds.num_blocks() // num_windows, 1)
+            ds = ds.window(blocks_per_window=window_size)
         pipe = ds.repeat(epochs)
         pipe = pipe.random_shuffle_each_window()
         pipe_shards = pipe.split(num_workers, equal=True)

@@ -285,6 +297,8 @@ if __name__ == "__main__":
         num_workers=args.num_workers,
         epochs=args.epochs,
         num_windows=args.num_windows,
+        manual_windowing=args.manual_windows,
+        parallelism=args.parallelism,
     )

     if args.debug: