[Datasets] Fix empty Dataset.iter_batches() when trying to prefetch more blocks than exist in the dataset (#20480)

Before this PR, `ds.iter_batches()` would yield no batches when `prefetch_blocks >= ds.num_blocks()`, since the prefetching sliding window has width `prefetch_blocks + 1` and the old sliding window semantics were to return no windows if `window_size > len(iterable)`. This PR tweaks the sliding window implementation to always return at least one window, even if that window is smaller than the requested window size.
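
To make the behavior change concrete, here is a minimal standalone sketch (the `old_sliding_window`/`new_sliding_window` names are illustrative; their bodies mirror the removed and added implementations in the diff below):

```python
import collections
import itertools


def old_sliding_window(iterable, n):
    # Pre-fix semantics: tee off n iterators and advance the i-th by i
    # steps; zip() stops at the shortest iterator, so n > len(iterable)
    # yields no windows at all.
    iters = itertools.tee(iter(iterable), n)
    for i in range(1, n):
        for it in iters[i:]:
            next(it, None)
    return zip(*iters)


def new_sliding_window(iterable, n):
    # Post-fix semantics: seed a deque with up to n elements and yield it
    # if non-empty, so a short iterable still produces a single window.
    it = iter(iterable)
    window = collections.deque(itertools.islice(it, n), maxlen=n)
    if len(window) > 0:
        yield tuple(window)
    for elem in it:
        window.append(elem)
        yield tuple(window)


print(list(old_sliding_window(range(3), 5)))  # [] -- no windows, no batches
print(list(new_sliding_window(range(3), 5)))  # [(0, 1, 2)] -- one short window
```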
Clark Zinzow, 2021-11-18 17:02:54 -08:00 (committed by GitHub)
parent add2450b92 · commit 462e389791
2 changed files with 61 additions and 25 deletions

python/ray/data/dataset.py

@@ -1832,26 +1832,6 @@ class Dataset(Generic[T]):
            A list of iterators over record batches.
        """
-        def sliding_window(iterable: Iterable, n: int):
-            """Creates an iterator consisting of n-width sliding windows over
-            iterable. The sliding windows are constructed lazily such that an
-            element on the base iterator (iterable) isn't consumed until the
-            first sliding window containing that element is reached.
-
-            Args:
-                iterable: The iterable on which the sliding window will be
-                    created.
-                n: The width of the sliding window.
-
-            Returns:
-                An iterator of n-width windows over iterable.
-            """
-            iters = itertools.tee(iter(iterable), n)
-            for i in range(1, n):
-                for it in iters[i:]:
-                    next(it, None)
-            return zip(*iters)

        def format_batch(batch: Block, format: str) -> BatchType:
            if batch_format == "native":
                return batch
@@ -1875,8 +1855,8 @@ class Dataset(Generic[T]):
                yield format_batch(batcher.next_batch(), batch_format)

        block_window = []  # Handle empty sliding window gracefully.
-        for block_window in sliding_window(self._blocks.iter_blocks(),
-                                           prefetch_blocks + 1):
+        for block_window in _sliding_window(self._blocks.iter_blocks(),
+                                            prefetch_blocks + 1):
            block_window = list(block_window)
            ray.wait(block_window, num_returns=1, fetch_local=True)
            yield from batch_block(block_window[0])
@@ -2511,6 +2491,34 @@ def _block_to_arrow(block: Block):
    return block.to_arrow()


+def _sliding_window(iterable: Iterable, n: int):
+    """Creates an iterator consisting of n-width sliding windows over
+    iterable. The sliding windows are constructed lazily such that an
+    element on the base iterator (iterable) isn't consumed until the
+    first sliding window containing that element is reached.
+
+    If n > len(iterable), then a single len(iterable) window is
+    returned.
+
+    Args:
+        iterable: The iterable on which the sliding window will be
+            created.
+        n: The width of the sliding window.
+
+    Returns:
+        An iterator of n-width windows over iterable.
+        If n > len(iterable), then a single len(iterable) window is
+        returned.
+    """
+    it = iter(iterable)
+    window = collections.deque(itertools.islice(it, n), maxlen=n)
+    if len(window) > 0:
+        yield tuple(window)
+    for elem in it:
+        window.append(elem)
+        yield tuple(window)
+
+
def _split_block(
    block: Block, meta: BlockMetadata, count: int, return_right_half: bool
) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
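
For context, a hedged sketch of the prefetching pattern that `iter_batches()` builds on top of `_sliding_window` (simplified from the hunk above; `prefetching_block_iter` is a hypothetical standalone name, and the tail-drain after the loop is an assumption about code outside the shown hunk):

```python
import ray
from ray.data.dataset import _sliding_window


def prefetching_block_iter(block_refs, prefetch_blocks: int):
    # Each window holds the current block ref plus up to `prefetch_blocks`
    # refs that are fetched to the local node in the background.
    block_window = []  # Handle empty sliding window gracefully.
    for block_window in _sliding_window(block_refs, prefetch_blocks + 1):
        block_window = list(block_window)
        # Block until the head of the window is available locally;
        # fetch_local=True starts pulling the rest of the window too.
        ray.wait(block_window, num_returns=1, fetch_local=True)
        yield ray.get(block_window[0])
    # Assumed tail-drain: the last `prefetch_blocks` blocks never reach
    # the head of a window inside the loop, so yield them here.
    for block_ref in block_window[1:]:
        yield ray.get(block_ref)
```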

python/ray/data/tests/test_dataset.py

@@ -17,7 +17,7 @@ from pytest_lazyfixture import lazy_fixture
import ray
from ray.tests.conftest import *  # noqa
-from ray.data.dataset import Dataset
+from ray.data.dataset import Dataset, _sliding_window
from ray.data.datasource import DummyOutputDatasource
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.block import BlockAccessor
@@ -1762,6 +1762,25 @@ def test_read_binary_files_s3(ray_start_regular_shared):
        assert item == expected


+def test_sliding_window():
+    arr = list(range(10))
+
+    # Test all windows over this iterable.
+    window_sizes = list(range(1, len(arr) + 1))
+    for window_size in window_sizes:
+        windows = list(_sliding_window(arr, window_size))
+        assert len(windows) == len(arr) - window_size + 1
+        assert all(len(window) == window_size for window in windows)
+        assert all(
+            list(window) == arr[i:i + window_size]
+            for i, window in enumerate(windows))
+
+    # Test window size larger than iterable length.
+    windows = list(_sliding_window(arr, 15))
+    assert len(windows) == 1
+    assert list(windows[0]) == arr
+
+
def test_iter_batches_basic(ray_start_regular_shared):
    df1 = pd.DataFrame({"one": [1, 2, 3], "two": [2, 3, 4]})
    df2 = pd.DataFrame({"one": [4, 5, 6], "two": [5, 6, 7]})
@@ -1830,8 +1849,9 @@ def test_iter_batches_basic(ray_start_regular_shared):
        batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))

    # Prefetch.
-    for batch, df in zip(
-            ds.iter_batches(prefetch_blocks=1, batch_format="pandas"), dfs):
+    batches = list(ds.iter_batches(prefetch_blocks=1, batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
        assert isinstance(batch, pd.DataFrame)
        assert batch.equals(df)
@@ -1845,6 +1865,14 @@ def test_iter_batches_basic(ray_start_regular_shared):
    assert pd.concat(
        batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))

+    # Prefetch more than the number of blocks.
+    batches = list(
+        ds.iter_batches(prefetch_blocks=len(dfs), batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
+        assert isinstance(batch, pd.DataFrame)
+        assert batch.equals(df)


def test_iter_batches_grid(ray_start_regular_shared):
    # Tests slicing, batch combining, and partial batch dropping logic over