[data] Add support for repeating and re-windowing a DatasetPipeline (#19091)

Eric Liang 2021-10-06 20:13:43 -07:00 committed by GitHub
parent 1ed5f622c2
commit 86cbe3e833
10 changed files with 364 additions and 57 deletions

@ -6,17 +6,13 @@ Overview
Datasets execute their transformations synchronously in blocking calls. However, it can be useful to overlap dataset computations with output. This can be done with a `DatasetPipeline <package-ref.html#datasetpipeline-api>`__.
A DatasetPipeline is a unified iterator over a (potentially infinite) sequence of Ray Datasets. Conceptually it is similar to a `Spark DStream <https://spark.apache.org/docs/latest/streaming-programming-guide.html#discretized-streams-dstreams>`__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.).
A DatasetPipeline is a unified iterator over a (potentially infinite) sequence of Ray Datasets, each of which represents a *window* over the original data. Conceptually it is similar to a `Spark DStream <https://spark.apache.org/docs/latest/streaming-programming-guide.html#discretized-streams-dstreams>`__, but manages execution over a bounded amount of source data instead of an unbounded stream. Ray computes each dataset window on-demand and stitches their output together into a single logical data iterator. DatasetPipeline implements most of the same transformation and output methods as Datasets (e.g., map, filter, split, iter_rows, to_torch, etc.).
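For a flavor of the API, here is a minimal sketch (the synthetic dataset and window size below are chosen purely for illustration):
.. code-block:: python
import ray

# Stream a 200-block dataset through the pipeline, 20 blocks at a time.
pipe = ray.data.range(10000, parallelism=200).window(blocks_per_window=20)
# Rows can be read incrementally while later windows are still being computed.
total = 0
for row in pipe.iter_rows():
    total += row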
Creating a DatasetPipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~
A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.window``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example:
.. tip::
The "window size" of a pipeline is defined as the number of blocks per Dataset in the pipeline.
.. code-block:: python
import ray
@ -36,14 +32,14 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu
# -> Dataset(num_blocks=200, num_rows=1000000, schema=<class 'int'>)
pipe = base.window(blocks_per_window=10)
print(pipe)
# -> DatasetPipeline(length=20, num_stages=1)
# -> DatasetPipeline(num_windows=20, num_stages=1)
# Applying transforms to pipelines adds more pipeline stages.
pipe = pipe.map(func1)
pipe = pipe.map(func2)
pipe = pipe.map(func3)
print(pipe)
# -> DatasetPipeline(length=20, num_stages=4)
# -> DatasetPipeline(num_windows=20, num_stages=4)
# Output can be pulled from the pipeline concurrently with its execution.
num_rows = 0
@ -73,6 +69,48 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato
splits = ray.data.range(1000, parallelism=200).split(20)
pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits])
Per-Window Transformations
~~~~~~~~~~~~~~~~~~~~~~~~~~
While most Dataset operations are per-row (e.g., map, filter), some operations apply to the Dataset as a whole (e.g., sort, shuffle). When applied to a pipeline, holistic transforms like shuffle are applied separately to each window in the pipeline:
.. code-block:: python
# Example of randomly shuffling each window of a pipeline.
ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows()
# ->
# === Window 0 ===
# 4
# 3
# 1
# 0
# 2
# === Window 1 ===
# 2
# 1
# 4
# 0
# 3
You can also apply arbitrary transformations to each window using ``DatasetPipeline.foreach_window()``:
.. code-block:: python
# Equivalent transformation using .foreach_window()
ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows()
# ->
# === Window 0 ===
# 1
# 0
# 4
# 2
# 3
# === Window 1 ===
# 4
# 2
# 0
# 3
# 1
Example: Pipelined Batch Inference
----------------------------------
@ -158,7 +196,7 @@ Transformations made to the Dataset prior to the call to ``.repeat()`` are
pipe: DatasetPipeline = ray.data \
.read_datasource(...) \
.repeat() \
.random_shuffle()
.random_shuffle_each_window()
@ray.remote(num_gpus=1)
def train_func(pipe: DatasetPipeline):
@ -187,7 +225,7 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel
pipe: DatasetPipeline = ray.data \
.read_parquet("s3://bucket/dir") \
.repeat() \
.random_shuffle()
.random_shuffle_each_window()
@ray.remote(num_gpus=1)
class TrainingWorker:
@ -204,3 +242,55 @@ Similar to how you can ``.split()`` a Dataset, you can also split a DatasetPipel
**Pipeline**:
.. image:: dataset-repeat-2.svg
Changing Pipeline Structure
---------------------------
Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference between the two examples: the former repeats a windowed pipeline that has a base window size of ``k``, while the latter re-windows a pipeline whose initial window size is ``ds.num_blocks()``. The latter may produce windows that span multiple copies of the same original data:
.. code-block:: python
# Window followed by repeat.
ray.data.range(5) \
.window(blocks_per_window=2) \
.repeat(2) \
.show_windows()
# ->
# === Window 0 ===
# 0
# 1
# === Window 1 ===
# 2
# 3
# === Window 2 ===
# 4
# === Window 3 ===
# 0
# 1
# === Window 4 ===
# 2
# 3
# === Window 5 ===
# 4
# Repeat followed by window.
ray.data.range(5) \
.repeat(2) \
.rewindow(blocks_per_window=2) \
.show_windows()
# ->
# === Window 0 ===
# 0
# 1
# === Window 1 ===
# 2
# 3
# === Window 2 ===
# 4
# 0
# === Window 3 ===
# 1
# 2
# === Window 4 ===
# 3
# 4

@ -277,8 +277,8 @@ Papers
:caption: Ray Data
data/dataset.rst
data/dataset-tensor-support.rst
data/dataset-pipeline.rst
data/dataset-tensor-support.rst
data/package-ref.rst
data/dask-on-ray.rst
data/mars-on-ray.rst
@ -366,7 +366,7 @@ Papers
.. toctree::
:hidden:
:maxdepth: -1
:caption: Contributing
:caption: Contributor Guide
getting-involved.rst
development.rst

@ -704,13 +704,16 @@ class Dataset(Generic[T]):
return splits
def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
def union(self, *other: List["Dataset[T]"],
preserve_order: bool = False) -> "Dataset[T]":
"""Combine this dataset with others of the same type.
Args:
other: List of datasets to combine with this one. The datasets
must have the same schema as this dataset, otherwise the
behavior is undefined.
preserve_order: Whether to preserve the order of the data blocks.
This may trigger eager loading of data from disk.
Returns:
A new dataset holding the union of their data.
@ -725,6 +728,11 @@ class Dataset(Generic[T]):
for ds in datasets:
bl = ds._blocks
if isinstance(bl, LazyBlockList):
if preserve_order:
# Force evaluation of blocks, which preserves order since
# then we don't need to move evaluated blocks to the front
# of LazyBlockList.
list(bl)
for block, meta in zip(bl._blocks, bl._metadata):
blocks.append(block)
metadata.append(meta)
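For illustration, a minimal usage sketch of the new ``preserve_order`` flag (the two small datasets below are arbitrary placeholders):
import ray

a = ray.data.range(4, parallelism=2)
b = ray.data.range(4, parallelism=2)

# With preserve_order=True, lazily loaded blocks are evaluated up front, so the
# combined dataset keeps its blocks in argument order instead of moving
# already-computed blocks to the front of the list.
combined = a.union(b, preserve_order=True)
assert combined.num_blocks() == 4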
@ -1642,7 +1650,7 @@ class Dataset(Generic[T]):
def __iter__(self):
return Iterator(self._ds)
return DatasetPipeline(Iterable(self), length=times)
return DatasetPipeline(Iterable(self), length=times or float("inf"))
def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]":
raise DeprecationWarning("Use .window(blocks_per_window=n) instead of "
@ -1678,11 +1686,11 @@ class Dataset(Generic[T]):
>>> # Create an inference pipeline.
>>> ds = ray.data.read_binary_files(dir)
>>> pipe = ds.window(blocks_per_window=10).map(infer)
DatasetPipeline(num_stages=2, length=40)
DatasetPipeline(num_windows=40, num_stages=2)
>>> # The higher the stage parallelism, the shorter the pipeline.
>>> pipe = ds.window(blocks_per_window=20).map(infer)
DatasetPipeline(num_stages=2, length=20)
DatasetPipeline(num_windows=20, num_stages=2)
>>> # Outputs can be incrementally read from the pipeline.
>>> for item in pipe.iter_rows():
@ -1777,6 +1785,10 @@ class Dataset(Generic[T]):
right = None
return left, right
def _divide(self, block_idx: int) -> ("Dataset[T]", "Dataset[T]"):
left, right = self._blocks.divide(block_idx)
return Dataset(left), Dataset(right)
def __repr__(self) -> str:
schema = self.schema()
if schema is None:
@ -1792,8 +1804,6 @@ class Dataset(Generic[T]):
schema_str = ", ".join(schema_str)
schema_str = "{" + schema_str + "}"
count = self._meta_count()
if count is None:
count = "?"
return "Dataset(num_blocks={}, num_rows={}, schema={})".format(
len(self._blocks), count, schema_str)

@ -1,7 +1,7 @@
import functools
import time
from typing import Any, Callable, List, Iterator, Iterable, Generic, Union, \
TYPE_CHECKING
Optional, TYPE_CHECKING
import ray
from ray.data.dataset import Dataset, T, U, BatchType
@ -13,13 +13,15 @@ from ray.util.annotations import PublicAPI, DeveloperAPI
if TYPE_CHECKING:
import pyarrow
# Operations that can be naively applied per dataset in the pipeline.
# Operations that can be naively applied per dataset row in the pipeline.
PER_DATASET_OPS = [
"map", "map_batches", "flat_map", "filter", "repartition",
"random_shuffle", "sort", "write_json", "write_csv", "write_parquet",
"write_datasource"
"map", "map_batches", "flat_map", "filter", "write_json", "write_csv",
"write_parquet", "write_datasource"
]
# Operations that apply to each dataset holistically in the pipeline.
HOLISTIC_PER_DATASET_OPS = ["repartition", "random_shuffle", "sort"]
# Similar to above but we should force evaluation immediately.
PER_DATASET_OUTPUT_OPS = [
"write_json", "write_csv", "write_parquet", "write_datasource"
@ -240,32 +242,123 @@ class DatasetPipeline(Generic[T]):
for idx in range(n)
]
def window(self, *, blocks_per_window: int) -> "DatasetPipeline[T]":
def rewindow(self, *, blocks_per_window: int) -> "DatasetPipeline[T]":
"""Change the windowing (blocks per dataset) of this pipeline.
Changes the windowing of this pipeline to the specified size. For
example, if the current pipeline has two blocks per dataset, and
`.window(blocks_per_window=4)` is requested, adjacent datasets will
`.rewindow(blocks_per_window=4)` is requested, adjacent datasets will
be merged until each dataset is 4 blocks. If
`.window(blocks_per_window=1)` was requested the datasets will
be split into smaller windows.
`.rewindow(blocks_per_window=1)` was requested, the datasets will be
split into smaller windows.
Args:
blocks_per_window: The new target blocks per window.
"""
raise NotImplementedError
class WindowIterator:
def __init__(self, original_iter):
self._original_iter = original_iter
self._buffer: Optional[Dataset[T]] = None
def __next__(self) -> Dataset[T]:
try:
# Merge windows until we meet the requested window size.
if self._buffer is None:
self._buffer = next(self._original_iter)
while self._buffer.num_blocks() < blocks_per_window:
self._buffer = self._buffer.union(
next(self._original_iter), preserve_order=True)
# Slice off the left-most chunk and return it.
res, self._buffer = self._buffer._divide(blocks_per_window)
assert res.num_blocks() <= blocks_per_window, res
return lambda: res
except StopIteration:
# Return the left-over data as a single window.
if self._buffer and self._buffer.num_blocks() > 0:
res = self._buffer
assert res.num_blocks() <= blocks_per_window, res
self._buffer = None
return lambda: res
else:
raise
class WindowIterable:
def __init__(self, original_iter):
self._original_iter = original_iter
def __iter__(self):
return WindowIterator(self._original_iter)
return DatasetPipeline(
WindowIterable(self.iter_datasets()), length=None)
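As an illustrative sketch of the merging behavior described in the docstring (block and window counts are arbitrary):
import ray

# Four 2-block windows, merged into 4-block windows by rewindow().
pipe = ray.data.range(8, parallelism=8).window(blocks_per_window=2)
merged = pipe.rewindow(blocks_per_window=4)
for ds in merged.iter_datasets():
    print(ds.num_blocks())  # -> 4, then 4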
def repeat(self, times: int = None) -> "DatasetPipeline[T]":
"""Repeat this pipeline a given number or times, or indefinitely.
This operation is only allowed for pipelines of a finite length. An
error will be raised for pipelines of infinite or unknown length.
error will be raised for pipelines of infinite length.
Transformations prior to the call to ``repeat()`` are evaluated once.
Transformations done on the repeated pipeline are evaluated on each
loop of the pipeline over the base pipeline.
Args:
times: The number of times to loop over this pipeline, or None
to repeat indefinitely.
"""
raise NotImplementedError
if self._length == float("inf"):
raise ValueError("Cannot repeat a pipeline of infinite length.")
class RepeatIterator:
def __init__(self, original_iter):
self._original_iter = original_iter
# Holds results to repeat.
self._results = []
# Incrementing cursor over results.
self._i = 0
# This is calculated later.
self._max_i = None
def __next__(self) -> Dataset[T]:
# Still going through the original pipeline.
if self._original_iter:
try:
res = next(self._original_iter)
self._results.append(res)
return lambda: res
except StopIteration:
self._original_iter = None
# Calculate the cursor limit.
if times:
self._max_i = len(self._results) * (times - 1)
else:
self._max_i = float("inf")
# Going through a repeat of the pipeline.
if self._i < self._max_i:
res = self._results[self._i % len(self._results)]
self._i += 1
return lambda: res
else:
raise StopIteration
class RepeatIterable:
def __init__(self, original_iter):
self._original_iter = original_iter
def __iter__(self):
return RepeatIterator(self._original_iter)
if not times:
length = float("inf")
elif times and self._length:
length = times * self._length
else:
length = None
return DatasetPipeline(
RepeatIterable(self.iter_datasets()), length=length)
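A small sketch of the resulting pipeline lengths (the window and repeat counts are arbitrary):
import ray

pipe = ray.data.range(4).window(blocks_per_window=2).repeat(3)
print(pipe)  # -> DatasetPipeline(num_windows=6, num_stages=1)

endless = ray.data.range(4).repeat()  # length is float("inf")
# endless.repeat()  # would raise ValueError, since the pipeline is infinite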
def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
"""Return the schema of the dataset pipeline.
@ -314,6 +407,19 @@ class DatasetPipeline(Generic[T]):
total += elem
return total
def show_windows(self, limit_per_dataset: int = 10) -> None:
"""Print up to the given number of records from each window/dataset.
This is helpful as a debugging tool for understanding the structure of
dataset pipelines.
Args:
limit_per_dataset: Rows to print per window/dataset.
"""
for i, ds in enumerate(self.iter_datasets()):
print("=== Window {} ===".format(i))
ds.show(limit_per_dataset)
@DeveloperAPI
def iter_datasets(self) -> Iterator[Dataset[T]]:
"""Iterate over the output datasets of this pipeline.
@ -327,9 +433,9 @@ class DatasetPipeline(Generic[T]):
return PipelineExecutor(self)
@DeveloperAPI
def foreach_dataset(self, fn: Callable[[Dataset[T]], Dataset[U]]
) -> "DatasetPipeline[U]":
"""Apply a transform to each dataset in this pipeline.
def foreach_window(self, fn: Callable[[Dataset[T]], Dataset[U]]
) -> "DatasetPipeline[U]":
"""Apply a transform to each dataset/window in this pipeline.
Args:
fn: The function to transform each dataset with.
@ -346,6 +452,10 @@ class DatasetPipeline(Generic[T]):
self._progress_bars,
_executed=self._executed)
def foreach_dataset(self, *a, **kw) -> None:
raise DeprecationWarning(
"`foreach_dataset` has been renamed to `foreach_window`.")
@staticmethod
def from_iterable(iterable: Iterable[Callable[[], Dataset[T]]],
) -> "DatasetPipeline[T]":
@ -362,7 +472,7 @@ class DatasetPipeline(Generic[T]):
return DatasetPipeline(iterable, length=length)
def __repr__(self) -> str:
return "DatasetPipeline(length={}, num_stages={})".format(
return "DatasetPipeline(num_windows={}, num_stages={})".format(
self._length, 1 + len(self._stages))
def __str__(self) -> str:
@ -382,7 +492,7 @@ for method in PER_DATASET_OPS:
@functools.wraps(delegate)
def impl(self, *args, **kwargs):
return self.foreach_dataset(
return self.foreach_window(
lambda ds: getattr(ds, method)(*args, **kwargs))
if impl.__annotations__.get("return"):
@ -393,6 +503,33 @@ for method in PER_DATASET_OPS:
setattr(DatasetPipeline, method, make_impl(method))
for method in HOLISTIC_PER_DATASET_OPS:
def make_impl(method):
delegate = getattr(Dataset, method)
@functools.wraps(delegate)
def impl(self, *args, **kwargs):
return self.foreach_window(
lambda ds: getattr(ds, method)(*args, **kwargs))
if impl.__annotations__.get("return"):
impl.__annotations__["return"] = impl.__annotations__[
"return"].replace("Dataset", "DatasetPipeline")
return impl
def deprecation_warning(method: str):
def impl(*a, **kw):
raise DeprecationWarning(
"`{}` has been renamed to `{}_each_window`.".format(
method, method))
return impl
setattr(DatasetPipeline, method, deprecation_warning(method))
setattr(DatasetPipeline, method + "_each_window", make_impl(method))
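For illustration, the renamed per-window variants in use (mirroring the updated tests; the repartition width is arbitrary):
import ray

pipe = ray.data.range(10).repeat(2)
pipe = pipe.repartition_each_window(1)  # holistic op, applied to each window
# pipe.repartition(1)  # the old name now raises DeprecationWarning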
for method in PER_DATASET_OUTPUT_OPS:
def make_impl(method):

@ -42,6 +42,13 @@ class BlockList(Iterable[ObjectRef[Block]]):
output.append(BlockList(b.tolist(), m.tolist()))
return output
def divide(self, block_idx: int) -> ("BlockList", "BlockList"):
self._check_if_cleared()
return (BlockList(self._blocks[:block_idx],
self._metadata[:block_idx]),
BlockList(self._blocks[block_idx:],
self._metadata[block_idx:]))
def __len__(self):
self._check_if_cleared()
return len(self._blocks)

@ -28,6 +28,7 @@ class LazyBlockList(BlockList[T]):
self._calls = None
def split(self, split_size: int) -> List["LazyBlockList"]:
# TODO(ekl) isn't this not copying already computed blocks?
self._check_if_cleared()
num_splits = math.ceil(len(self._calls) / split_size)
calls = np.array_split(self._calls, num_splits)
@ -37,6 +38,18 @@ class LazyBlockList(BlockList[T]):
output.append(LazyBlockList(c.tolist(), m.tolist()))
return output
def divide(self, block_idx: int) -> ("BlockList", "BlockList"):
self._check_if_cleared()
left = self.copy()
right = self.copy()
left._calls = left._calls[:block_idx]
left._blocks = left._blocks[:block_idx]
left._metadata = left._metadata[:block_idx]
right._calls = right._calls[block_idx:]
right._blocks = right._blocks[block_idx:]
right._metadata = right._metadata[block_idx:]
return left, right
def __len__(self):
self._check_if_cleared()
return len(self._calls)

@ -27,12 +27,15 @@ class PipelineExecutor:
self._iter = iter(self._pipeline._base_iterable)
self._stages[0] = pipeline_stage.remote(next(self._iter))
if self._pipeline._length and self._pipeline._length != float("inf"):
length = self._pipeline._length
else:
length = 1
if self._pipeline._progress_bars:
self._bars = [
ProgressBar(
"Stage {}".format(i),
self._pipeline._length or 1,
position=i) for i in range(len(self._stages))
ProgressBar("Stage {}".format(i), length, position=i)
for i in range(len(self._stages))
]
else:
self._bars = None

@ -150,7 +150,7 @@ def test_transform_failure(shutdown_only):
def mapper(x):
time.sleep(x)
assert False
raise ValueError("oops")
return x
with pytest.raises(ray.exceptions.RayTaskError):
@ -723,7 +723,7 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path):
ds.write_numpy(data_path, filesystem=fs)
ds = ray.data.read_numpy(data_path, filesystem=fs)
assert str(ds) == (
"Dataset(num_blocks=2, num_rows=?, "
"Dataset(num_blocks=2, num_rows=None, "
"schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})")
assert str(ds.take(2)) == \
"[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]"
@ -736,7 +736,7 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1))
ds = ray.data.read_numpy(path)
assert str(ds) == (
"Dataset(num_blocks=1, num_rows=?, "
"Dataset(num_blocks=1, num_rows=None, "
"schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})")
assert str(ds.take(2)) == \
"[ArrowRow({'value': array([0])}), ArrowRow({'value': array([1])})]"
@ -2522,7 +2522,9 @@ def test_random_shuffle(shutdown_only, pipelined):
def range(n, parallelism=200):
ds = ray.data.range(n, parallelism=parallelism)
if pipelined:
return ds.repeat(2)
pipe = ds.repeat(2)
pipe.random_shuffle = pipe.random_shuffle_each_window
return pipe
else:
return ds
@ -2692,7 +2694,7 @@ def test_dataset_retry_exceptions(ray_start_regular, local_path):
def _read_file(self, f: "pa.NativeFile", path: str, **reader_args):
count = self.counter.increment.remote()
if ray.get(count) == 1:
raise ValueError()
raise ValueError("oops")
else:
return CSVDatasource._read_file(self, f, path, **reader_args)
@ -2700,7 +2702,7 @@ def test_dataset_retry_exceptions(ray_start_regular, local_path):
**writer_args):
count = self.counter.increment.remote()
if ray.get(count) == 1:
raise ValueError()
raise ValueError("oops")
else:
CSVDatasource._write_block(self, f, block, **writer_args)
@ -2720,7 +2722,7 @@ def test_dataset_retry_exceptions(ray_start_regular, local_path):
def flaky_mapper(x):
count = counter.increment.remote()
if ray.get(count) == 1:
raise ValueError()
raise ValueError("oops")
else:
return ray.get(count)

@ -53,24 +53,69 @@ def test_basic_pipeline(ray_start_regular_shared):
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert pipe.count() == 10
pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
assert str(pipe) == "DatasetPipeline(length=10, num_stages=3)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=3)"
assert pipe.take() == list(range(10))
pipe = ds.window(blocks_per_window=999)
assert str(pipe) == "DatasetPipeline(length=1, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=1, num_stages=1)"
assert pipe.count() == 10
pipe = ds.repeat(10)
assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert pipe.count() == 100
pipe = ds.repeat(10)
assert pipe.sum() == 450
def test_window(ray_start_regular_shared):
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())
assert len(datasets) == 4
assert datasets[0].take() == [0, 1, 2]
assert datasets[1].take() == [3, 4, 5]
assert datasets[2].take() == [6, 7, 8]
assert datasets[3].take() == [9]
ds = ray.data.range(10)
pipe = ds.window(blocks_per_window=5)
assert str(pipe) == "DatasetPipeline(num_windows=2, num_stages=1)"
pipe = pipe.rewindow(blocks_per_window=3)
assert str(pipe) == "DatasetPipeline(num_windows=None, num_stages=1)"
datasets = list(pipe.iter_datasets())
assert len(datasets) == 4
assert datasets[0].take() == [0, 1, 2]
assert datasets[1].take() == [3, 4, 5]
assert datasets[2].take() == [6, 7, 8]
assert datasets[3].take() == [9]
def test_repeat(ray_start_regular_shared):
ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
assert str(pipe) == "DatasetPipeline(num_windows=5, num_stages=1)"
pipe = pipe.repeat(2)
assert str(pipe) == "DatasetPipeline(num_windows=10, num_stages=1)"
assert pipe.take() == (list(range(5)) + list(range(5)))
ds = ray.data.range(5)
pipe = ds.window(blocks_per_window=1)
pipe = pipe.repeat()
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
assert len(pipe.take(99)) == 99
pipe = ray.data.range(5).repeat()
with pytest.raises(ValueError):
pipe.repeat()
def test_from_iterable(ray_start_regular_shared):
pipe = DatasetPipeline.from_iterable(
[lambda: ray.data.range(3), lambda: ray.data.range(2)])
@ -80,7 +125,7 @@ def test_from_iterable(ray_start_regular_shared):
def test_repeat_forever(ray_start_regular_shared):
ds = ray.data.range(10)
pipe = ds.repeat()
assert str(pipe) == "DatasetPipeline(length=None, num_stages=1)"
assert str(pipe) == "DatasetPipeline(num_windows=inf, num_stages=1)"
for i, v in enumerate(pipe.iter_rows()):
assert v == i % 10, (v, i, i % 10)
if i > 1000:
@ -89,11 +134,11 @@ def test_repeat_forever(ray_start_regular_shared):
def test_repartition(ray_start_regular_shared):
pipe = ray.data.range(10).repeat(10)
assert pipe.repartition(1).sum() == 450
assert pipe.repartition_each_window(1).sum() == 450
pipe = ray.data.range(10).repeat(10)
assert pipe.repartition(10).sum() == 450
assert pipe.repartition_each_window(10).sum() == 450
pipe = ray.data.range(10).repeat(10)
assert pipe.repartition(100).sum() == 450
assert pipe.repartition_each_window(100).sum() == 450
def test_iter_batches(ray_start_regular_shared):
@ -113,9 +158,9 @@ def test_iter_datasets(ray_start_regular_shared):
assert len(ds) == 2
def test_foreach_dataset(ray_start_regular_shared):
def test_foreach_window(ray_start_regular_shared):
pipe = ray.data.range(5).window(blocks_per_window=2)
pipe = pipe.foreach_dataset(lambda ds: ds.map(lambda x: x * 2))
pipe = pipe.foreach_window(lambda ds: ds.map(lambda x: x * 2))
assert pipe.take() == [0, 2, 4, 6, 8]

@ -244,12 +244,12 @@ def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
i * num_rows // num_windows // num_workers
for i in range(1, num_workers)
]
pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
pipe_shards = pipe.split_at_indices(split_indices)
else:
ds = ray.data.read_parquet(files, _spread_resource_prefix="node:")
pipe = ds.repeat(epochs)
pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
pipe_shards = pipe.split(num_workers, equal=True)
return pipe_shards