[data] Preserve epoch by default when using rewindow() (#19359)

Eric Liang 2021-10-14 09:17:36 -07:00 committed by GitHub
parent 4edb3c4746
commit 13d4ad6100
3 changed files with 45 additions and 16 deletions


@ -78,7 +78,9 @@ It's common in ML training to want to divide data ingest into epochs, or repetit
.. code-block:: python
pipe = ray.data.range(5).repeat(3).random_shuffle_each_window()
pipe = ray.data.from_items([0, 1, 2, 3, 4]) \
.repeat(3) \
.random_shuffle_each_window()
for i, epoch in enumerate(pipe.iter_epochs()):
print("Epoch {}", i)
for row in epoch.iter_rows():
@ -113,7 +115,10 @@ While most Dataset operations are per-row (e.g., map, filter), some operations a
.. code-block:: python
# Example of randomly shuffling each window of a pipeline.
ray.data.range(5).repeat(2).random_shuffle_each_window().show_windows()
ray.data.from_items([0, 1, 2, 3, 4]) \
.repeat(2) \
.random_shuffle_each_window() \
.show_windows()
# ->
# ----- Epoch 0 ------
# === Window 0 ===
@ -135,7 +140,10 @@ You can also apply arbitrary transformations to each window using ``DatasetPipel
.. code-block:: python
# Equivalent transformation using .foreach_window()
ray.data.range(5).repeat(2).foreach_window(lambda w: w.random_shuffle()).show_windows()
ray.data.from_items([0, 1, 2, 3, 4]) \
.repeat(2) \
.foreach_window(lambda w: w.random_shuffle()) \
.show_windows()
# ->
# ----- Epoch 0 ------
# === Window 0 ===
@ -336,12 +344,12 @@ See :ref:`the SGD User Guide <sgd-dataset-pipeline>` for more details.
Changing Pipeline Structure
---------------------------
Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference in the two examples: the former is repeating a windowed pipeline that has a base window size of ``k``, while the latter is re-windowing a pipeline of initial window size of ``ds.num_blocks()``. The latter may produce windows that span multiple copies of the same original data:
Sometimes, you may want to change the structure of an existing pipeline. For example, after generating a pipeline with ``ds.window(k)``, you may want to repeat that windowed pipeline ``n`` times. This can be done with ``ds.window(k).repeat(n)``. As another example, suppose you have a repeating pipeline generated with ``ds.repeat(n)``. The windowing of that pipeline can be changed with ``ds.repeat(n).rewindow(k)``. Note the subtle difference in the two examples: the former is repeating a windowed pipeline that has a base window size of ``k``, while the latter is re-windowing a pipeline of initial window size of ``ds.num_blocks()``. The latter may produce windows that span multiple copies of the same original data if ``preserve_epoch=False`` is set:
.. code-block:: python
# Window followed by repeat.
ray.data.range(5) \
ray.data.from_items([0, 1, 2, 3, 4]) \
.window(blocks_per_window=2) \
.repeat(2) \
.show_windows()
@ -365,12 +373,12 @@ Sometimes, you may want to change the structure of an existing pipeline. For exa
# === Window 5 ===
# 4
# Repeat followed by window. Note that epoch 1 contains some leftover
# data from the tail end of epoch 0, since re-windowing can merge windows
# across epochs.
ray.data.range(5) \
# Repeat followed by window. Since preserve_epoch=True, at epoch boundaries
# windows may be smaller than the target size. If it were set to False, all
# windows except the last would be the target size.
ray.data.from_items([0, 1, 2, 3, 4]) \
.repeat(2) \
.rewindow(blocks_per_window=2) \
.rewindow(blocks_per_window=2, preserve_epoch=True) \
.show_windows()
# ->
# ------ Epoch 0 ------
@ -380,13 +388,14 @@ Sometimes, you may want to change the structure of an existing pipeline. For exa
# === Window 1 ===
# 2
# 3
# ------ Epoch 1 ------
# === Window 2 ===
# 4
# 0
# ------ Epoch 1 ------
# === Window 3 ===
# 0
# 1
# 2
# === Window 4 ===
# 2
# 3
# === Window 5 ===
# 4
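For contrast, here is a sketch (not taken from this commit's docs) of what the same pipeline would print with ``preserve_epoch=False``, which keeps the previous behavior of merging windows across epoch boundaries:
.. code-block:: python
# Repeat followed by window, allowing windows to span epochs.
ray.data.from_items([0, 1, 2, 3, 4]) \
.repeat(2) \
.rewindow(blocks_per_window=2, preserve_epoch=False) \
.show_windows()
# ->
# ------ Epoch 0 ------
# === Window 0 ===
# 0
# 1
# === Window 1 ===
# 2
# 3
# ------ Epoch 1 ------
# === Window 2 ===
# 4
# 0
# === Window 3 ===
# 1
# 2
# === Window 4 ===
# 3
# 4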


@ -242,7 +242,8 @@ class DatasetPipeline(Generic[T]):
for idx in range(n)
]
def rewindow(self, *, blocks_per_window: int) -> "DatasetPipeline[T]":
def rewindow(self, *, blocks_per_window: int,
preserve_epoch: bool = True) -> "DatasetPipeline[T]":
"""Change the windowing (blocks per dataset) of this pipeline.
Changes the windowing of this pipeline to the specified size. For
@ -254,6 +255,8 @@ class DatasetPipeline(Generic[T]):
Args:
blocks_per_window: The new target blocks per window.
preserve_epoch: Whether to preserve epoch boundaries. If set to
False, then windows can contain data from two adjacent epochs.
"""
class WindowIterator:
@ -267,8 +270,14 @@ class DatasetPipeline(Generic[T]):
if self._buffer is None:
self._buffer = next(self._original_iter)
while self._buffer.num_blocks() < blocks_per_window:
self._buffer = self._buffer.union(
next(self._original_iter))
next_ds = next(self._original_iter)
if (preserve_epoch and self._buffer._get_epoch() !=
next_ds._get_epoch()):
partial_window = self._buffer
self._buffer = next_ds
return lambda: partial_window
else:
self._buffer = self._buffer.union(next_ds)
# Slice off the left-most chunk and return it.
res, self._buffer = self._buffer._divide(blocks_per_window)
assert res.num_blocks() <= blocks_per_window, res
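To make the epoch-boundary handling above concrete, here is a small self-contained model of the re-windowing loop (a sketch only: real windows are ``Dataset`` objects, modeled here as ``(epoch, blocks)`` pairs, and ``rewindow_model`` is a hypothetical helper, not part of the Ray API):
from typing import Iterator, List, Tuple

# A window is modeled as (epoch, blocks); in Ray each window is a Dataset.
Window = Tuple[int, List[int]]

def rewindow_model(windows: Iterator[Window], blocks_per_window: int,
                   preserve_epoch: bool = True) -> Iterator[Window]:
    """Re-chunk a stream of windows, optionally splitting at epoch boundaries."""
    buf_epoch, buf = None, []
    for epoch, blocks in windows:
        if preserve_epoch and buf and epoch != buf_epoch:
            # Emit the partial window instead of merging across the epoch boundary.
            yield buf_epoch, buf
            buf = []
        buf_epoch = epoch
        buf.extend(blocks)
        while len(buf) >= blocks_per_window:
            yield buf_epoch, buf[:blocks_per_window]
            buf = buf[blocks_per_window:]
    if buf:
        yield buf_epoch, buf

# Two epochs of five 1-block windows, re-windowed into 2-block windows.
src = [(epoch, [i]) for epoch in (0, 1) for i in range(5)]
print(list(rewindow_model(iter(src), 2)))
# -> [(0, [0, 1]), (0, [2, 3]), (0, [4]), (1, [0, 1]), (1, [2, 3]), (1, [4])]
print(list(rewindow_model(iter(src), 2, preserve_epoch=False)))
# -> [(0, [0, 1]), (0, [2, 3]), (1, [4, 0]), (1, [1, 2]), (1, [3, 4])]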


@ -53,6 +53,17 @@ def test_epoch(ray_start_regular_shared):
assert results == [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
# Test preserve_epoch=True.
pipe = ray.data.range(5).repeat(2).rewindow(blocks_per_window=2)
results = [p.take() for p in pipe.iter_epochs()]
assert results == [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
# Test preserve_epoch=False.
pipe = ray.data.range(5).repeat(2).rewindow(
blocks_per_window=2, preserve_epoch=False)
results = [p.take() for p in pipe.iter_epochs()]
assert results == [[0, 1, 2, 3], [4, 0, 1, 2, 3, 4]]
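For reference, a standalone sketch (not part of the test suite) that derives the preserve_epoch=False expectation above by hand, assuming a window that spans two epochs is attributed to the later epoch:
rows = [0, 1, 2, 3, 4] * 2        # two epochs of range(5); one row per block
epoch_of = [0] * 5 + [1] * 5      # epoch label for each row
windows = [rows[i:i + 2] for i in range(0, len(rows), 2)]
window_epochs = [max(epoch_of[i:i + 2]) for i in range(0, len(rows), 2)]
by_epoch = {0: [], 1: []}
for window, epoch in zip(windows, window_epochs):
    by_epoch[epoch].extend(window)
print(by_epoch)  # -> {0: [0, 1, 2, 3], 1: [4, 0, 1, 2, 3, 4]}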
def test_cannot_read_twice(ray_start_regular_shared):
ds = ray.data.range(10)