Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
Enable stage fusion by default for dataset pipelines (#22476)
This PR enables stage fusion by default for dataset pipelines. This also requires:
1. Removing the num_cpus=0.5 default for the read stage, so that the read stage can be fused with downstream stages.
2. Removing spread_resource_prefix (not supported for now).
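For orientation, a minimal sketch of how the new behavior can be exercised from user code (assuming a running Ray cluster at this version of the API; the flags and the fused "read->map_batches" stage name come from the hunks and tests below):

import ray
from ray.data.context import DatasetContext

ctx = DatasetContext.get_current()
ctx.optimize_fuse_stages = True       # the new default introduced by this PR
ctx.optimize_fuse_read_stages = True  # allow the read stage itself to be fused

ds = ray.data.range(1000, parallelism=10).map_batches(lambda x: x)
# With fusion enabled, the read and map_batches stages run as a single task per
# block and appear as one "read->map_batches" stage in the stats output.
print(ds.stats())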
This commit is contained in:
parent a62a9c38fb
commit e15a419028
13 changed files with 164 additions and 102 deletions
@@ -19,8 +19,7 @@ DEFAULT_BLOCK_SPLITTING_ENABLED = False
DEFAULT_ENABLE_PANDAS_BLOCK = True

# Whether to enable stage-fusion optimizations for dataset pipelines.
# TODO(ekl): enable this by default when ready.
DEFAULT_OPTIMIZE_FUSE_STAGES = False
DEFAULT_OPTIMIZE_FUSE_STAGES = True

# Whether to furthermore fuse read stages. When this is enabled, data will also be
# re-read from the base dataset in each repetition of a DatasetPipeline.

@@ -54,12 +53,8 @@ class DatasetContext:
self.target_max_block_size = target_max_block_size
self.enable_pandas_block = enable_pandas_block
self.optimize_fuse_stages = optimize_fuse_stages
self.optimize_fuse_read_stages = (
optimize_fuse_stages and optimize_fuse_read_stages
)
self.optimize_fuse_shuffle_stages = (
optimize_fuse_stages and optimize_fuse_shuffle_stages
)
self.optimize_fuse_read_stages = optimize_fuse_read_stages
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages

@staticmethod
def get_current() -> "DatasetContext":
@@ -463,13 +463,21 @@ class Dataset(Generic[T]):
if shuffle:

def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
def do_shuffle(
block_list, clear_input_blocks: bool, block_udf, remote_args
):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
blocks = block_list
return simple_shuffle(blocks, block_udf, num_blocks)
return simple_shuffle(
blocks,
block_udf,
num_blocks,
map_ray_remote_args=remote_args,
reduce_ray_remote_args=remote_args,
)

plan = self._plan.with_stage(
AllToAllStage(

@@ -479,7 +487,7 @@ class Dataset(Generic[T]):
else:

def do_fast_repartition(block_list, clear_input_blocks: bool, _):
def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()

@@ -524,7 +532,7 @@ class Dataset(Generic[T]):
The shuffled dataset.
"""

def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
num_blocks = block_list.executed_num_blocks() # Blocking.
if num_blocks == 0:
return block_list, {}

@@ -540,6 +548,8 @@ class Dataset(Generic[T]):
random_shuffle=True,
random_seed=seed,
_spread_resource_prefix=_spread_resource_prefix,
map_ray_remote_args=remote_args,
reduce_ray_remote_args=remote_args,
)
return new_blocks, stage_info

@@ -1380,7 +1390,7 @@ class Dataset(Generic[T]):
A new, sorted dataset.
"""

def do_sort(block_list, clear_input_blocks: bool, block_udf):
def do_sort(block_list, clear_input_blocks: bool, *_):
# Handle empty dataset.
if block_list.initial_num_blocks() == 0:
return block_list, {}

@@ -1424,7 +1434,7 @@ class Dataset(Generic[T]):
comes from the first dataset and v comes from the second.
"""

def do_zip_all(block_list, clear_input_blocks: bool, block_udf):
def do_zip_all(block_list, clear_input_blocks: bool, *_):
blocks1 = block_list.get_blocks()
blocks2 = other.get_internal_block_refs()

@@ -2681,11 +2691,24 @@ Dict[str, List[str]]]): The names of the columns
This can be used to read all blocks into memory. By default, Datasets
doesn't read blocks from the datasource until the first transform.

Returns:
A Dataset with all blocks fully materialized in memory.
"""
blocks = self.get_internal_block_refs()
bar = ProgressBar("Force reads", len(blocks))
bar.block_until_complete(blocks)
return self
ds = Dataset(
ExecutionPlan(
BlockList(blocks, self._plan.execute().get_metadata()),
self._plan.stats(),
dataset_uuid=self._get_uuid(),
),
self._epoch,
lazy=False,
)
ds._set_uuid(self._get_uuid())
return ds

@DeveloperAPI
def stats(self) -> str:
@@ -394,10 +394,6 @@ class DatasetPipeline(Generic[T]):
This operation is only allowed for pipelines of a finite length. An
error will be raised for pipelines of infinite length.

Transformations prior to the call to ``repeat()`` are evaluated once.
Transformations done on the repeated pipeline are evaluated on each
loop of the pipeline over the base pipeline.

Note that every repeat of the pipeline is considered an "epoch" for
the purposes of ``iter_epochs()``. If there are multiple repeat calls,
the latest repeat takes precedence for the purpose of defining epochs.

@@ -424,10 +420,15 @@ class DatasetPipeline(Generic[T]):
# Still going through the original pipeline.
if self._original_iter:
try:
res = next(self._original_iter)
res._set_epoch(0)
self._results.append(res)
return lambda: res
make_ds = next(self._original_iter)
self._results.append(make_ds)

def gen():
res = make_ds()
res._set_epoch(0)
return res

return gen
except StopIteration:
self._original_iter = None
# Calculate the cursor limit.

@@ -437,10 +438,16 @@ class DatasetPipeline(Generic[T]):
self._max_i = float("inf")
# Going through a repeat of the pipeline.
if self._i < self._max_i:
res = self._results[self._i % len(self._results)]
res._set_epoch(1 + self._i // len(self._results))
make_ds = self._results[self._i % len(self._results)]
epoch = 1 + self._i // len(self._results)

def gen():
res = make_ds()
res._set_epoch(epoch)
return res

self._i += 1
return lambda: res
return gen
else:
raise StopIteration

@@ -458,7 +465,11 @@ class DatasetPipeline(Generic[T]):
else:
length = None

return DatasetPipeline(RepeatIterable(self.iter_datasets()), length=length)
return DatasetPipeline(
RepeatIterable(iter(self._base_iterable)),
stages=self._stages.copy(),
length=length,
)

def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
"""Return the schema of the dataset pipeline.
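One consequence of the lazy per-epoch re-creation above, combined with read-stage fusion, is that each repetition of a pipeline re-executes its stages, including the read, rather than replaying cached blocks. A minimal usage sketch, assuming a running Ray cluster; the expected count follows from the semantics shown above rather than captured output:

import ray

# With the new defaults (optimize_fuse_stages and optimize_fuse_read_stages both
# True), every repetition below is one epoch, and each epoch re-runs the fused
# read stage for its windows instead of reusing materialized blocks.
pipe = ray.data.range(8).window(blocks_per_window=2).repeat(2)
for epoch_pipe in pipe.iter_epochs():
    print(epoch_pipe.count())  # expected output: 8, printed once per epoch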
@@ -56,7 +56,7 @@ class GroupedDataset(Generic[T]):
If groupby key is ``None`` then the key part of return is omitted.
"""

def do_agg(blocks, clear_input_blocks: bool, block_udf):
def do_agg(blocks, clear_input_blocks: bool, *_):
# TODO: implement clear_input_blocks
stage_info = {}
if len(aggs) == 0:
@@ -12,6 +12,9 @@ from ray.data.impl.compute import get_compute
from ray.data.impl.stats import DatasetStats
from ray.data.impl.lazy_block_list import LazyBlockList

# Scheduling strategy can be inherited from prev stage if not specified.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


class ExecutionPlan:
"""A lazy execution plan for a Dataset."""

@@ -115,10 +118,6 @@ class ExecutionPlan:
Returns:
The blocks of the output dataset.
"""
# TODO: add optimizations:
# 1. task fusion of OneToOne
# 2. task fusion of OneToOne to AlltoAll
# 3. clear input blocks
if self._out_blocks is None:
self._optimize()
blocks = self._in_blocks

@@ -186,6 +185,7 @@ class ExecutionPlan:
[GetReadTasks -> MapBatches(DoRead -> Fn)].
"""
# Generate the "GetReadTasks" stage blocks.
remote_args = self._in_blocks._read_remote_args
blocks = []
metadata = []
for i, read_task in enumerate(self._in_blocks._read_tasks):

@@ -198,8 +198,7 @@ class ExecutionPlan:
for tmp1 in read_task._read_fn():
yield tmp1

# TODO(ekl): add num_cpus properly here and make the read default num_cpus=1.
return block_list, OneToOneStage("read", block_fn, None, {})
return block_list, OneToOneStage("read", block_fn, "tasks", remote_args)

def _fuse_one_to_one_stages(self) -> None:
"""Fuses compatible one-to-one stages."""

@@ -254,14 +253,18 @@ class OneToOneStage(Stage):
super().__init__(name, None)
self.block_fn = block_fn
self.compute = compute or "tasks"
self.ray_remote_args = ray_remote_args
self.ray_remote_args = ray_remote_args or {}

def can_fuse(self, prev: Stage):
if not isinstance(prev, OneToOneStage):
return False
if prev.compute != self.compute:
return False
if prev.ray_remote_args != self.ray_remote_args:
for key in INHERITABLE_REMOTE_ARGS:
remote_args = self.ray_remote_args.copy()
if key in prev.ray_remote_args:
remote_args[key] = prev.ray_remote_args[key]
if prev.ray_remote_args != remote_args:
return False
return True

@@ -275,7 +278,7 @@ class OneToOneStage(Stage):
for tmp2 in fn2(tmp1):
yield tmp2

return OneToOneStage(name, block_fn, self.compute, self.ray_remote_args)
return OneToOneStage(name, block_fn, prev.compute, prev.ray_remote_args)

def __call__(
self, blocks: BlockList, clear_input_blocks: bool

@@ -298,11 +301,13 @@ class AllToAllStage(Stage):
fn: Callable[[BlockList, bool, Callable], Tuple[BlockList, dict]],
supports_block_udf: bool = False,
block_udf=None,
remote_args=None,
):
super().__init__(name, num_blocks)
self.fn = fn
self.supports_block_udf = supports_block_udf
self.block_udf = block_udf
self.ray_remote_args = remote_args or {}

def can_fuse(self, prev: Stage):
context = DatasetContext.get_current()

@@ -315,18 +320,22 @@ class AllToAllStage(Stage):
return False
if prev.compute != "tasks":
return False
if prev.ray_remote_args:
if any(k not in INHERITABLE_REMOTE_ARGS for k in prev.ray_remote_args):
return False
return True

def fuse(self, prev: Stage):
assert self.supports_block_udf
name = prev.name + "->" + self.name
return AllToAllStage(name, self.num_blocks, self.fn, True, prev.block_fn)
return AllToAllStage(
name, self.num_blocks, self.fn, True, prev.block_fn, prev.ray_remote_args
)

def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:
blocks, stage_info = self.fn(blocks, clear_input_blocks, self.block_udf)
blocks, stage_info = self.fn(
blocks, clear_input_blocks, self.block_udf, self.ray_remote_args
)
assert isinstance(blocks, BlockList), blocks
return blocks, stage_info
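To make the fusion rule above easier to follow outside the diff context, here is a small, self-contained sketch of the remote-args compatibility check; only INHERITABLE_REMOTE_ARGS is taken from the diff, the function and example values are illustrative, and the real logic lives in OneToOneStage.can_fuse and AllToAllStage.can_fuse:

# Scheduling strategy can be inherited from the previous stage if not specified.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]

def remote_args_compatible(prev_args: dict, next_args: dict) -> bool:
    # Two one-to-one stages can fuse if their remote args match, except that the
    # downstream stage may inherit inheritable keys from the upstream stage.
    merged = next_args.copy()
    for key in INHERITABLE_REMOTE_ARGS:
        if key in prev_args:
            merged[key] = prev_args[key]
    return prev_args == merged

# A SPREAD read stage can fuse with a map stage that sets no remote args...
assert remote_args_compatible({"scheduling_strategy": "SPREAD"}, {})
# ...but a difference in a non-inheritable key (e.g. num_gpus) blocks fusion.
assert not remote_args_compatible({"num_gpus": 1}, {})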
@@ -83,3 +83,9 @@ class ProgressBar:
def __del__(self):
self.close()

def __getstate__(self):
return {}

def __setstate__(self, state):
self._bar = None # Progress bar is disabled on remote nodes.
@@ -32,6 +32,7 @@ def simple_shuffle(
if reduce_ray_remote_args is None:
reduce_ray_remote_args = {}
if "scheduling_strategy" not in reduce_ray_remote_args:
reduce_ray_remote_args = reduce_ray_remote_args.copy()
reduce_ray_remote_args["scheduling_strategy"] = "SPREAD"
input_num_blocks = len(input_blocks)
if _spread_resource_prefix is not None:
@@ -209,8 +209,10 @@ class DatasetStats:
out = ""
if self.parents:
for p in self.parents:
out += p.summary_string(already_printed)
out += "\n"
parent_sum = p.summary_string(already_printed)
if parent_sum:
out += parent_sum
out += "\n"
first = True
for stage_name, metadata in self.stages.items():
stage_uuid = self.dataset_uuid + stage_name
@@ -259,17 +259,16 @@ def read_datasource(
if ray_remote_args is None:
ray_remote_args = {}
# Increase the read parallelism by default to maximize IO throughput. This
# is particularly important when reading from e.g., remote storage.
if "num_cpus" not in ray_remote_args:
# Note that the too many workers warning triggers at 4x subscription,
# so we go at 0.5 to avoid the warning message.
ray_remote_args["num_cpus"] = 0.5
if "scheduling_strategy" not in ray_remote_args:
ray_remote_args["scheduling_strategy"] = "SPREAD"
remote_read = cached_remote_fn(remote_read)

if _spread_resource_prefix is not None:
if context.optimize_fuse_stages:
logger.warning(
"_spread_resource_prefix has no effect when optimize_fuse_stages "
"is enabled. Tasks are spread by default."
)
# Use given spread resource prefix for round-robin resource-based
# scheduling.
nodes = ray.nodes()

@@ -294,6 +293,7 @@ def read_datasource(
block_list = LazyBlockList(calls, metadata)
# TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
block_list._read_tasks = read_tasks
block_list._read_remote_args = ray_remote_args

# Get the schema from the first block synchronously.
if metadata and metadata[0].schema is None:
@@ -6,6 +6,7 @@ import pandas as pd
import numpy as np

import ray
from ray.data.context import DatasetContext
from ray.data.dataset_pipeline import DatasetPipeline

from ray.tests.conftest import * # noqa

@@ -89,31 +90,35 @@ def test_cannot_read_twice(ray_start_regular_shared):
def test_basic_pipeline(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)

pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.count() == 10

pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=4)"
assert pipe.take() == list(range(10))

pipe = ds.window(blocks_per_window=999)
assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=2)"
assert pipe.count() == 10

pipe = ds.repeat(10)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.count() == 100
pipe = ds.repeat(10)
assert pipe.sum() == 450


def test_window(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())

@@ -125,7 +130,7 @@ def test_window(ray_start_regular_shared):
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=5)
assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=2)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())

@@ -137,17 +142,19 @@ def test_window(ray_start_regular_shared):
def test_repeat(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=2)"
pipe = pipe.repeat(2)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.take() == (list(range(5)) + list(range(5)))

ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
pipe = pipe.repeat()
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"
assert len(pipe.take(99)) == 99

pipe = ray.data.range(5).repeat()

@@ -163,9 +170,11 @@ def test_from_iterable(ray_start_regular_shared):
def test_repeat_forever(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)
pipe = ds.repeat()
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"
for i, v in enumerate(pipe.iter_rows()):
assert v == i % 10, (v, i, i % 10)
if i > 1000:

@@ -212,7 +221,7 @@ def test_schema(ray_start_regular_shared):
def test_split(ray_start_regular_shared):
pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)

@ray.remote
@ray.remote(num_cpus=0)
def consume(shard, i):
total = 0
for row in shard.iter_rows():

@@ -230,7 +239,7 @@ def test_split_at_indices(ray_start_regular_shared):
n = 8
pipe = ray.data.range(n).map(lambda x: x + 1).repeat(2)

@ray.remote
@ray.remote(num_cpus=0)
def consume(shard, i):
total = 0
out = []
@@ -17,11 +17,22 @@ def expect_stages(pipe, num_stages_expected, stage_names):
assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages


def test_spread_hint_inherit(ray_start_regular_shared):
ds = ray.data.range(10)._experimental_lazy()
ds = ds.map(lambda x: x + 1)
ds = ds.random_shuffle()
for s in ds._plan._stages:
assert s.ray_remote_args == {}, s.ray_remote_args
ds._plan._optimize()
assert len(ds._plan._stages) == 1, ds._plan._stages
assert ds._plan._stages[0].ray_remote_args == {"scheduling_strategy": "SPREAD"}


def test_optimize_fuse(ray_start_regular_shared):
context = DatasetContext.get_current()

def build_pipe():
pipe = ray.data.range(3).repeat(2)
pipe = ray.data.range(3).window(blocks_per_window=1).repeat(2)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.random_shuffle_each_window()
@@ -2,6 +2,7 @@ import pytest
import re

import ray
from ray.data.context import DatasetContext
from ray.tests.conftest import * # noqa

@@ -16,6 +17,8 @@ def canonicalize(stats: str) -> str:
def test_dataset_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)
ds = ds.map(lambda x: x)

@@ -24,14 +27,7 @@ def test_dataset_stats_basic(ray_start_regular_shared):
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

Stage N map_batches: N/N blocks executed in T
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total

@@ -56,19 +52,14 @@ Dataset iterator time breakdown:
def test_dataset_stats_shuffle(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.random_shuffle().repartition(1, shuffle=True)
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

Stage N random_shuffle_map: N/N blocks executed in T
== """Stage N read->random_shuffle_map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total

@@ -135,20 +126,15 @@ def test_dataset_stats_from_items(ray_start_regular_shared):
def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds.write_parquet(str(tmp_path))
ds = ray.data.read_parquet(str(tmp_path)).map(lambda x: x)
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

Stage N map: N/N blocks executed in T
== """Stage N read->map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total

@@ -159,6 +145,8 @@ Stage N map: N/N blocks executed in T
def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)
pipe = ds.repeat(5)

@@ -169,14 +157,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
assert (
stats
== """== Pipeline Window N ==
Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

Stage N map_batches: N/N blocks executed in T
Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total

@@ -198,8 +179,7 @@ Dataset iterator time breakdown:
* Total time: T

== Pipeline Window N ==
Stage Z read: [execution cached]
Stage N map_batches: [execution cached]
Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total

@@ -215,8 +195,7 @@ Dataset iterator time breakdown:
* Total time: T

== Pipeline Window N ==
Stage Z read: [execution cached]
Stage N map_batches: [execution cached]
Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total

@@ -241,6 +220,8 @@ Dataset iterator time breakdown:
def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
pipe = ds.repeat(2)

@@ -255,7 +236,7 @@ def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
assert (
canonicalize(stats[0])
== """== Pipeline Window Z ==
Stage Z read: N/N blocks executed in T
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total

@@ -270,7 +251,7 @@ Dataset iterator time breakdown:
* Total time: T

== Pipeline Window N ==
Stage Z read: N/N blocks executed in T
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@@ -47,6 +47,8 @@ parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--mock-train-step-time", type=float, default=1.0)
parser.add_argument("--num-files", type=int, default=30)
parser.add_argument("--num-windows", type=int, default=1)
parser.add_argument("--manual-windows", type=bool, default=False)
parser.add_argument("--parallelism", type=int, default=400)

SIZE_50_G = 30 # 49.17GB
SIZE_100_G = 62 # 101.62GB

@@ -218,8 +220,15 @@ def create_torch_iterator(split, batch_size, rank=None):
return torch_iterator


def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
if num_windows > 1:
def create_dataset(
files,
num_workers=4,
epochs=50,
num_windows=1,
manual_windowing=False,
parallelism=400,
):
if num_windows > 1 and manual_windowing:
num_rows = ray.data.read_parquet(
files
).count() # This should only read Parquet metadata.

@@ -247,7 +256,10 @@ def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
pipe = pipe.random_shuffle_each_window()
pipe_shards = pipe.split_at_indices(split_indices)
else:
ds = ray.data.read_parquet(files)
ds = ray.data.read_parquet(files, parallelism=parallelism)
if num_windows > 1:
window_size = max(ds.num_blocks() // num_windows, 1)
ds = ds.window(blocks_per_window=window_size)
pipe = ds.repeat(epochs)
pipe = pipe.random_shuffle_each_window()
pipe_shards = pipe.split(num_workers, equal=True)

@@ -285,6 +297,8 @@ if __name__ == "__main__":
num_workers=args.num_workers,
epochs=args.epochs,
num_windows=args.num_windows,
manual_windowing=args.manual_windows,
parallelism=args.parallelism,
)

if args.debug: