Enable stage fusion by default for dataset pipelines (#22476)

This PR enables stage fusion for dataset pipelines (see the sketch below for toggling the new default). Doing so also requires:
1. Removing the num_cpus=0.5 default from read tasks so that the read stage can be fused with downstream stages.
2. Removing _spread_resource_prefix for reads (not supported when stage fusion is enabled; tasks are spread by default).
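
A minimal sketch of the new default, using only the `DatasetContext` flags that appear in this diff (fusion happens when the plan executes, so it shows up in the stage stats rather than in the pipeline repr):

```python
import ray
from ray.data.context import DatasetContext

ctx = DatasetContext.get_current()
ctx.optimize_fuse_stages = True   # new default; set to False to restore the old behavior

pipe = ray.data.range(10).window(blocks_per_window=1).map(lambda x: x)
pipe.take()                       # execute the pipeline
print(pipe.stats())               # with fusion enabled, stats report a fused "read->map" stage
```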
Eric Liang 2022-02-23 17:34:05 -08:00 committed by GitHub
parent a62a9c38fb
commit e15a419028
13 changed files with 164 additions and 102 deletions


@ -19,8 +19,7 @@ DEFAULT_BLOCK_SPLITTING_ENABLED = False
DEFAULT_ENABLE_PANDAS_BLOCK = True
# Whether to enable stage-fusion optimizations for dataset pipelines.
# TODO(ekl): enable this by default when ready.
DEFAULT_OPTIMIZE_FUSE_STAGES = False
DEFAULT_OPTIMIZE_FUSE_STAGES = True
# Whether to furthermore fuse read stages. When this is enabled, data will also be
# re-read from the base dataset in each repetition of a DatasetPipeline.
@ -54,12 +53,8 @@ class DatasetContext:
self.target_max_block_size = target_max_block_size
self.enable_pandas_block = enable_pandas_block
self.optimize_fuse_stages = optimize_fuse_stages
self.optimize_fuse_read_stages = (
optimize_fuse_stages and optimize_fuse_read_stages
)
self.optimize_fuse_shuffle_stages = (
optimize_fuse_stages and optimize_fuse_shuffle_stages
)
self.optimize_fuse_read_stages = optimize_fuse_read_stages
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
@staticmethod
def get_current() -> "DatasetContext":
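
Because the read stage can now be fused into each window, a repeated pipeline re-reads the datasource on every repetition (as the comment above notes). A minimal sketch, assuming only the two flags from this hunk, of keeping general fusion while opting out of read fusion:

```python
from ray.data.context import DatasetContext

ctx = DatasetContext.get_current()
ctx.optimize_fuse_stages = True        # keep fusing map/shuffle stages
ctx.optimize_fuse_read_stages = False  # execute the read eagerly once instead of
                                       # re-reading the datasource per repetition
```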


@ -463,13 +463,21 @@ class Dataset(Generic[T]):
if shuffle:
def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
def do_shuffle(
block_list, clear_input_blocks: bool, block_udf, remote_args
):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
else:
blocks = block_list
return simple_shuffle(blocks, block_udf, num_blocks)
return simple_shuffle(
blocks,
block_udf,
num_blocks,
map_ray_remote_args=remote_args,
reduce_ray_remote_args=remote_args,
)
plan = self._plan.with_stage(
AllToAllStage(
@ -479,7 +487,7 @@ class Dataset(Generic[T]):
else:
def do_fast_repartition(block_list, clear_input_blocks: bool, _):
def do_fast_repartition(block_list, clear_input_blocks: bool, *_):
if clear_input_blocks:
blocks = block_list.copy()
block_list.clear()
@ -524,7 +532,7 @@ class Dataset(Generic[T]):
The shuffled dataset.
"""
def do_shuffle(block_list, clear_input_blocks: bool, block_udf):
def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args):
num_blocks = block_list.executed_num_blocks() # Blocking.
if num_blocks == 0:
return block_list, {}
@ -540,6 +548,8 @@ class Dataset(Generic[T]):
random_shuffle=True,
random_seed=seed,
_spread_resource_prefix=_spread_resource_prefix,
map_ray_remote_args=remote_args,
reduce_ray_remote_args=remote_args,
)
return new_blocks, stage_info
@ -1380,7 +1390,7 @@ class Dataset(Generic[T]):
A new, sorted dataset.
"""
def do_sort(block_list, clear_input_blocks: bool, block_udf):
def do_sort(block_list, clear_input_blocks: bool, *_):
# Handle empty dataset.
if block_list.initial_num_blocks() == 0:
return block_list, {}
@ -1424,7 +1434,7 @@ class Dataset(Generic[T]):
comes from the first dataset and v comes from the second.
"""
def do_zip_all(block_list, clear_input_blocks: bool, block_udf):
def do_zip_all(block_list, clear_input_blocks: bool, *_):
blocks1 = block_list.get_blocks()
blocks2 = other.get_internal_block_refs()
@ -2681,11 +2691,24 @@ Dict[str, List[str]]]): The names of the columns
This can be used to read all blocks into memory. By default, Datasets
doesn't read blocks from the datasource until the first transform.
Returns:
A Dataset with all blocks fully materialized in memory.
"""
blocks = self.get_internal_block_refs()
bar = ProgressBar("Force reads", len(blocks))
bar.block_until_complete(blocks)
return self
ds = Dataset(
ExecutionPlan(
BlockList(blocks, self._plan.execute().get_metadata()),
self._plan.stats(),
dataset_uuid=self._get_uuid(),
),
self._epoch,
lazy=False,
)
ds._set_uuid(self._get_uuid())
return ds
@DeveloperAPI
def stats(self) -> str:
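
The hunk above rewrites what appears to be `Dataset.fully_executed()` (the method signature sits outside the hunk, so the name is an assumption) to return a new `Dataset` backed by the materialized blocks instead of returning `self`. A hedged usage sketch:

```python
import ray

ds = ray.data.range(1000, parallelism=10).map(lambda x: x + 1)

# Assumption: the rewritten method shown above is Dataset.fully_executed().
materialized = ds.fully_executed()  # all blocks computed and held in memory
print(materialized.stats())         # per-stage stats for the materialized dataset
```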


@ -394,10 +394,6 @@ class DatasetPipeline(Generic[T]):
This operation is only allowed for pipelines of a finite length. An
error will be raised for pipelines of infinite length.
Transformations prior to the call to ``repeat()`` are evaluated once.
Transformations done on the repeated pipeline are evaluated on each
loop of the pipeline over the base pipeline.
Note that every repeat of the pipeline is considered an "epoch" for
the purposes of ``iter_epochs()``. If there are multiple repeat calls,
the latest repeat takes precedence for the purpose of defining epochs.
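
A minimal sketch of the epoch semantics described above, using only APIs that appear elsewhere in this diff (`window`, `repeat`, `iter_epochs`, `iter_datasets`, `iter_rows`):

```python
import ray

pipe = ray.data.range(5).window(blocks_per_window=1).repeat(2)

# Each repetition is one epoch; every epoch yields the same 5 rows.
for epoch_idx, epoch_pipe in enumerate(pipe.iter_epochs()):
    rows = [row for ds in epoch_pipe.iter_datasets() for row in ds.iter_rows()]
    assert sorted(rows) == list(range(5)), (epoch_idx, rows)
```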
@ -424,10 +420,15 @@ class DatasetPipeline(Generic[T]):
# Still going through the original pipeline.
if self._original_iter:
try:
res = next(self._original_iter)
res._set_epoch(0)
self._results.append(res)
return lambda: res
make_ds = next(self._original_iter)
self._results.append(make_ds)
def gen():
res = make_ds()
res._set_epoch(0)
return res
return gen
except StopIteration:
self._original_iter = None
# Calculate the cursor limit.
@ -437,10 +438,16 @@ class DatasetPipeline(Generic[T]):
self._max_i = float("inf")
# Going through a repeat of the pipeline.
if self._i < self._max_i:
res = self._results[self._i % len(self._results)]
res._set_epoch(1 + self._i // len(self._results))
make_ds = self._results[self._i % len(self._results)]
epoch = 1 + self._i // len(self._results)
def gen():
res = make_ds()
res._set_epoch(epoch)
return res
self._i += 1
return lambda: res
return gen
else:
raise StopIteration
@ -458,7 +465,11 @@ class DatasetPipeline(Generic[T]):
else:
length = None
return DatasetPipeline(RepeatIterable(self.iter_datasets()), length=length)
return DatasetPipeline(
RepeatIterable(iter(self._base_iterable)),
stages=self._stages.copy(),
length=length,
)
def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
"""Return the schema of the dataset pipeline.


@ -56,7 +56,7 @@ class GroupedDataset(Generic[T]):
If groupby key is ``None`` then the key part of return is omitted.
"""
def do_agg(blocks, clear_input_blocks: bool, block_udf):
def do_agg(blocks, clear_input_blocks: bool, *_):
# TODO: implement clear_input_blocks
stage_info = {}
if len(aggs) == 0:


@ -12,6 +12,9 @@ from ray.data.impl.compute import get_compute
from ray.data.impl.stats import DatasetStats
from ray.data.impl.lazy_block_list import LazyBlockList
# Scheduling strategy can be inherited from prev stage if not specified.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]
class ExecutionPlan:
"""A lazy execution plan for a Dataset."""
@ -115,10 +118,6 @@ class ExecutionPlan:
Returns:
The blocks of the output dataset.
"""
# TODO: add optimizations:
# 1. task fusion of OneToOne
# 2. task fusion of OneToOne to AlltoAll
# 3. clear input blocks
if self._out_blocks is None:
self._optimize()
blocks = self._in_blocks
@ -186,6 +185,7 @@ class ExecutionPlan:
[GetReadTasks -> MapBatches(DoRead -> Fn)].
"""
# Generate the "GetReadTasks" stage blocks.
remote_args = self._in_blocks._read_remote_args
blocks = []
metadata = []
for i, read_task in enumerate(self._in_blocks._read_tasks):
@ -198,8 +198,7 @@ class ExecutionPlan:
for tmp1 in read_task._read_fn():
yield tmp1
# TODO(ekl): add num_cpus properly here and make the read default num_cpus=1.
return block_list, OneToOneStage("read", block_fn, None, {})
return block_list, OneToOneStage("read", block_fn, "tasks", remote_args)
def _fuse_one_to_one_stages(self) -> None:
"""Fuses compatible one-to-one stages."""
@ -254,14 +253,18 @@ class OneToOneStage(Stage):
super().__init__(name, None)
self.block_fn = block_fn
self.compute = compute or "tasks"
self.ray_remote_args = ray_remote_args
self.ray_remote_args = ray_remote_args or {}
def can_fuse(self, prev: Stage):
if not isinstance(prev, OneToOneStage):
return False
if prev.compute != self.compute:
return False
if prev.ray_remote_args != self.ray_remote_args:
for key in INHERITABLE_REMOTE_ARGS:
remote_args = self.ray_remote_args.copy()
if key in prev.ray_remote_args:
remote_args[key] = prev.ray_remote_args[key]
if prev.ray_remote_args != remote_args:
return False
return True
@ -275,7 +278,7 @@ class OneToOneStage(Stage):
for tmp2 in fn2(tmp1):
yield tmp2
return OneToOneStage(name, block_fn, self.compute, self.ray_remote_args)
return OneToOneStage(name, block_fn, prev.compute, prev.ray_remote_args)
def __call__(
self, blocks: BlockList, clear_input_blocks: bool
@ -298,11 +301,13 @@ class AllToAllStage(Stage):
fn: Callable[[BlockList, bool, Callable], Tuple[BlockList, dict]],
supports_block_udf: bool = False,
block_udf=None,
remote_args=None,
):
super().__init__(name, num_blocks)
self.fn = fn
self.supports_block_udf = supports_block_udf
self.block_udf = block_udf
self.ray_remote_args = remote_args or {}
def can_fuse(self, prev: Stage):
context = DatasetContext.get_current()
@ -315,18 +320,22 @@ class AllToAllStage(Stage):
return False
if prev.compute != "tasks":
return False
if prev.ray_remote_args:
if any(k not in INHERITABLE_REMOTE_ARGS for k in prev.ray_remote_args):
return False
return True
def fuse(self, prev: Stage):
assert self.supports_block_udf
name = prev.name + "->" + self.name
return AllToAllStage(name, self.num_blocks, self.fn, True, prev.block_fn)
return AllToAllStage(
name, self.num_blocks, self.fn, True, prev.block_fn, prev.ray_remote_args
)
def __call__(
self, blocks: BlockList, clear_input_blocks: bool
) -> Tuple[BlockList, dict]:
blocks, stage_info = self.fn(blocks, clear_input_blocks, self.block_udf)
blocks, stage_info = self.fn(
blocks, clear_input_blocks, self.block_udf, self.ray_remote_args
)
assert isinstance(blocks, BlockList), blocks
return blocks, stage_info
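
Taken together, the new fusion rules are: two one-to-one stages fuse when their compute strategy matches and their remote args agree once the inheritable keys (currently just `scheduling_strategy`) are carried over from the upstream stage; an all-to-all stage additionally requires the upstream stage to run as tasks with only inheritable remote args. A standalone sketch of that compatibility check (an illustration, not the actual implementation):

```python
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]

def remote_args_compatible(prev_args: dict, next_args: dict) -> bool:
    # The downstream stage may inherit the inheritable keys from the upstream
    # stage; everything else must match exactly for fusion to be allowed.
    merged = next_args.copy()
    for key in INHERITABLE_REMOTE_ARGS:
        if key in prev_args:
            merged[key] = prev_args[key]
    return prev_args == merged

# A SPREAD read stage can fuse with a plain map stage:
assert remote_args_compatible({"scheduling_strategy": "SPREAD"}, {})
# Differing non-inheritable args (e.g. num_cpus) block fusion:
assert not remote_args_compatible({"num_cpus": 2}, {})
```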


@ -83,3 +83,9 @@ class ProgressBar:
def __del__(self):
self.close()
def __getstate__(self):
return {}
def __setstate__(self, state):
self._bar = None # Progress bar is disabled on remote nodes.


@ -32,6 +32,7 @@ def simple_shuffle(
if reduce_ray_remote_args is None:
reduce_ray_remote_args = {}
if "scheduling_strategy" not in reduce_ray_remote_args:
reduce_ray_remote_args = reduce_ray_remote_args.copy()
reduce_ray_remote_args["scheduling_strategy"] = "SPREAD"
input_num_blocks = len(input_blocks)
if _spread_resource_prefix is not None:
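
The added `copy()` keeps `simple_shuffle` from mutating the caller's dict when it injects the SPREAD default, which matters now that the same remote-args dict can be passed for both the map and reduce sides. A small sketch of the pattern, with a hypothetical helper name:

```python
def with_default(remote_args, key, value):
    # Copy before mutating so the caller's dict is left untouched.
    remote_args = dict(remote_args or {})
    remote_args.setdefault(key, value)
    return remote_args

user_args = {"num_cpus": 1}
reduce_args = with_default(user_args, "scheduling_strategy", "SPREAD")
assert user_args == {"num_cpus": 1}                    # unchanged
assert reduce_args["scheduling_strategy"] == "SPREAD"  # default applied to the copy
```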


@ -209,8 +209,10 @@ class DatasetStats:
out = ""
if self.parents:
for p in self.parents:
out += p.summary_string(already_printed)
out += "\n"
parent_sum = p.summary_string(already_printed)
if parent_sum:
out += parent_sum
out += "\n"
first = True
for stage_name, metadata in self.stages.items():
stage_uuid = self.dataset_uuid + stage_name


@ -259,17 +259,16 @@ def read_datasource(
if ray_remote_args is None:
ray_remote_args = {}
# Increase the read parallelism by default to maximize IO throughput. This
# is particularly important when reading from e.g., remote storage.
if "num_cpus" not in ray_remote_args:
# Note that the too many workers warning triggers at 4x subscription,
# so we go at 0.5 to avoid the warning message.
ray_remote_args["num_cpus"] = 0.5
if "scheduling_strategy" not in ray_remote_args:
ray_remote_args["scheduling_strategy"] = "SPREAD"
remote_read = cached_remote_fn(remote_read)
if _spread_resource_prefix is not None:
if context.optimize_fuse_stages:
logger.warning(
"_spread_resource_prefix has no effect when optimize_fuse_stages "
"is enabled. Tasks are spread by default."
)
# Use given spread resource prefix for round-robin resource-based
# scheduling.
nodes = ray.nodes()
@ -294,6 +293,7 @@ def read_datasource(
block_list = LazyBlockList(calls, metadata)
# TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
block_list._read_tasks = read_tasks
block_list._read_remote_args = ray_remote_args
# Get the schema from the first block synchronously.
if metadata and metadata[0].schema is None:
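
With the implicit `num_cpus=0.5` gone, read tasks fall back to Ray's normal task defaults plus SPREAD scheduling, and any `ray_remote_args` passed to the read API are recorded on the block list for the rewritten read stage. A hedged sketch (the Parquet path is a placeholder):

```python
import ray

ds = ray.data.read_parquet(
    "/path/to/data",                  # placeholder; any existing Parquet dataset
    ray_remote_args={"num_cpus": 1},  # explicit override now that 0.5 is no longer implied
)
ds = ds.map(lambda row: row)
```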


@ -6,6 +6,7 @@ import pandas as pd
import numpy as np
import ray
from ray.data.context import DatasetContext
from ray.data.dataset_pipeline import DatasetPipeline
from ray.tests.conftest import * # noqa
@ -89,31 +90,35 @@ def test_cannot_read_twice(ray_start_regular_shared):
def test_basic_pipeline(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.count() == 10
pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=4)"
assert pipe.take() == list(range(10))
pipe = ds.window(blocks_per_window=999)
assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=2)"
assert pipe.count() == 10
pipe = ds.repeat(10)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.count() == 100
pipe = ds.repeat(10)
assert pipe.sum() == 450
def test_window(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())
@ -125,7 +130,7 @@ def test_window(ray_start_regular_shared):
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=5)
assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=2)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())
@ -137,17 +142,19 @@ def test_window(ray_start_regular_shared):
def test_repeat(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=2)"
pipe = pipe.repeat(2)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=2)"
assert pipe.take() == (list(range(5)) + list(range(5)))
ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
pipe = pipe.repeat()
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"
assert len(pipe.take(99)) == 99
pipe = ray.data.range(5).repeat()
@ -163,9 +170,11 @@ def test_from_iterable(ray_start_regular_shared):
def test_repeat_forever(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(10)
pipe = ds.repeat()
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=2)"
for i, v in enumerate(pipe.iter_rows()):
assert v == i % 10, (v, i, i % 10)
if i > 1000:
@ -212,7 +221,7 @@ def test_schema(ray_start_regular_shared):
def test_split(ray_start_regular_shared):
pipe = ray.data.range(3).map(lambda x: x + 1).repeat(10)
@ray.remote
@ray.remote(num_cpus=0)
def consume(shard, i):
total = 0
for row in shard.iter_rows():
@ -230,7 +239,7 @@ def test_split_at_indices(ray_start_regular_shared):
n = 8
pipe = ray.data.range(n).map(lambda x: x + 1).repeat(2)
@ray.remote
@ray.remote(num_cpus=0)
def consume(shard, i):
total = 0
out = []


@ -17,11 +17,22 @@ def expect_stages(pipe, num_stages_expected, stage_names):
assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages
def test_spread_hint_inherit(ray_start_regular_shared):
ds = ray.data.range(10)._experimental_lazy()
ds = ds.map(lambda x: x + 1)
ds = ds.random_shuffle()
for s in ds._plan._stages:
assert s.ray_remote_args == {}, s.ray_remote_args
ds._plan._optimize()
assert len(ds._plan._stages) == 1, ds._plan._stages
assert ds._plan._stages[0].ray_remote_args == {"scheduling_strategy": "SPREAD"}
def test_optimize_fuse(ray_start_regular_shared):
context = DatasetContext.get_current()
def build_pipe():
pipe = ray.data.range(3).repeat(2)
pipe = ray.data.range(3).window(blocks_per_window=1).repeat(2)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.map_batches(lambda x: x)
pipe = pipe.random_shuffle_each_window()


@ -2,6 +2,7 @@ import pytest
import re
import ray
from ray.data.context import DatasetContext
from ray.tests.conftest import * # noqa
@ -16,6 +17,8 @@ def canonicalize(stats: str) -> str:
def test_dataset_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)
ds = ds.map(lambda x: x)
@ -24,14 +27,7 @@ def test_dataset_stats_basic(ray_start_regular_shared):
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
Stage N map_batches: N/N blocks executed in T
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@ -56,19 +52,14 @@ Dataset iterator time breakdown:
def test_dataset_stats_shuffle(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.random_shuffle().repartition(1, shuffle=True)
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
Stage N random_shuffle_map: N/N blocks executed in T
== """Stage N read->random_shuffle_map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@ -135,20 +126,15 @@ def test_dataset_stats_from_items(ray_start_regular_shared):
def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds.write_parquet(str(tmp_path))
ds = ray.data.read_parquet(str(tmp_path)).map(lambda x: x)
stats = canonicalize(ds.stats())
assert (
stats
== """Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
Stage N map: N/N blocks executed in T
== """Stage N read->map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@ -159,6 +145,8 @@ Stage N map: N/N blocks executed in T
def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
ds = ds.map_batches(lambda x: x)
pipe = ds.repeat(5)
@ -169,14 +157,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
assert (
stats
== """== Pipeline Window N ==
Stage Z read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
Stage N map_batches: N/N blocks executed in T
Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@ -198,8 +179,7 @@ Dataset iterator time breakdown:
* Total time: T
== Pipeline Window N ==
Stage Z read: [execution cached]
Stage N map_batches: [execution cached]
Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
@ -215,8 +195,7 @@ Dataset iterator time breakdown:
* Total time: T
== Pipeline Window N ==
Stage Z read: [execution cached]
Stage N map_batches: [execution cached]
Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
@ -241,6 +220,8 @@ Dataset iterator time breakdown:
def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
context = DatasetContext.get_current()
context.optimize_fuse_stages = True
ds = ray.data.range(1000, parallelism=10)
pipe = ds.repeat(2)
@ -255,7 +236,7 @@ def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
assert (
canonicalize(stats[0])
== """== Pipeline Window Z ==
Stage Z read: N/N blocks executed in T
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total
@ -270,7 +251,7 @@ Dataset iterator time breakdown:
* Total time: T
== Pipeline Window N ==
Stage Z read: N/N blocks executed in T
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Output num rows: N min, N max, N mean, N total


@ -47,6 +47,8 @@ parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--mock-train-step-time", type=float, default=1.0)
parser.add_argument("--num-files", type=int, default=30)
parser.add_argument("--num-windows", type=int, default=1)
parser.add_argument("--manual-windows", type=bool, default=False)
parser.add_argument("--parallelism", type=int, default=400)
SIZE_50_G = 30 # 49.17GB
SIZE_100_G = 62 # 101.62GB
@ -218,8 +220,15 @@ def create_torch_iterator(split, batch_size, rank=None):
return torch_iterator
def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
if num_windows > 1:
def create_dataset(
files,
num_workers=4,
epochs=50,
num_windows=1,
manual_windowing=False,
parallelism=400,
):
if num_windows > 1 and manual_windowing:
num_rows = ray.data.read_parquet(
files
).count() # This should only read Parquet metadata.
@ -247,7 +256,10 @@ def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
pipe = pipe.random_shuffle_each_window()
pipe_shards = pipe.split_at_indices(split_indices)
else:
ds = ray.data.read_parquet(files)
ds = ray.data.read_parquet(files, parallelism=parallelism)
if num_windows > 1:
window_size = max(ds.num_blocks() // num_windows, 1)
ds = ds.window(blocks_per_window=window_size)
pipe = ds.repeat(epochs)
pipe = pipe.random_shuffle_each_window()
pipe_shards = pipe.split(num_workers, equal=True)
@ -285,6 +297,8 @@ if __name__ == "__main__":
num_workers=args.num_workers,
epochs=args.epochs,
num_windows=args.num_windows,
manual_windowing=args.manual_windows,
parallelism=args.parallelism,
)
if args.debug: