From 462e389791567faa4461b38e71d9940c682000f3 Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Thu, 18 Nov 2021 17:02:54 -0800
Subject: [PATCH] [Datasets] Fix empty Dataset.iter_batches() when trying to
 prefetch more blocks than exist in the dataset (#20480)

Before this PR, `ds.iter_batches()` would yield no batches if
`prefetch_blocks > ds.num_blocks()` was given, since the sliding window
semantics were to return no windows if `window_size > len(iterable)`.
This PR tweaks the sliding window implementation to always return at
least one window, even if that window is smaller than the given window
size.
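As a doctest-style sketch of the new semantics (the example values here
are illustrative only and are not taken from the test suite;
`_sliding_window` is the private helper added in this PR):

    >>> list(_sliding_window(range(5), 3))
    [(0, 1, 2), (1, 2, 3), (2, 3, 4)]
    >>> list(_sliding_window(range(5), 10))
    [(0, 1, 2, 3, 4)]

The old `tee`-based `sliding_window` would have yielded no windows at
all for the second call, which is exactly why `iter_batches()` came
back empty when over-prefetching.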
---
 python/ray/data/dataset.py            | 52 +++++++++++++++------------
 python/ray/data/tests/test_dataset.py | 34 ++++++++++++++++--
 2 files changed, 61 insertions(+), 25 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 8ac2e5bf2..15cb7e2d1 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -1832,26 +1832,6 @@ class Dataset(Generic[T]):
             A list of iterators over record batches.
         """
 
-        def sliding_window(iterable: Iterable, n: int):
-            """Creates an iterator consisting of n-width sliding windows over
-            iterable. The sliding windows are constructed lazily such that an
-            element on the base iterator (iterable) isn't consumed until the
-            first sliding window containing that element is reached.
-
-            Args:
-                iterable: The iterable on which the sliding window will be
-                    created.
-                n: The width of the sliding window.
-
-            Returns:
-                An iterator of n-width windows over iterable.
-            """
-            iters = itertools.tee(iter(iterable), n)
-            for i in range(1, n):
-                for it in iters[i:]:
-                    next(it, None)
-            return zip(*iters)
-
         def format_batch(batch: Block, format: str) -> BatchType:
             if batch_format == "native":
                 return batch
@@ -1875,8 +1855,8 @@ class Dataset(Generic[T]):
                 yield format_batch(batcher.next_batch(), batch_format)
 
         block_window = []  # Handle empty sliding window gracefully.
-        for block_window in sliding_window(self._blocks.iter_blocks(),
-                                           prefetch_blocks + 1):
+        for block_window in _sliding_window(self._blocks.iter_blocks(),
+                                            prefetch_blocks + 1):
             block_window = list(block_window)
             ray.wait(block_window, num_returns=1, fetch_local=True)
             yield from batch_block(block_window[0])
@@ -2511,6 +2491,34 @@ def _block_to_arrow(block: Block):
     return block.to_arrow()
 
 
+def _sliding_window(iterable: Iterable, n: int):
+    """Creates an iterator consisting of n-width sliding windows over
+    iterable. The sliding windows are constructed lazily such that an
+    element on the base iterator (iterable) isn't consumed until the
+    first sliding window containing that element is reached.
+
+    If n > len(iterable), then a single len(iterable) window is
+    returned.
+
+    Args:
+        iterable: The iterable on which the sliding window will be
+            created.
+        n: The width of the sliding window.
+
+    Returns:
+        An iterator of n-width windows over iterable. If
+        n > len(iterable), then a single len(iterable) window is
+        returned.
+    """
+    it = iter(iterable)
+    window = collections.deque(itertools.islice(it, n), maxlen=n)
+    if len(window) > 0:
+        yield tuple(window)
+    for elem in it:
+        window.append(elem)
+        yield tuple(window)
+
+
 def _split_block(
         block: Block, meta: BlockMetadata, count: int, return_right_half: bool
 ) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 4ab62aa0d..694daee07 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -17,7 +17,7 @@ from pytest_lazyfixture import lazy_fixture
 
 import ray
 from ray.tests.conftest import *  # noqa
-from ray.data.dataset import Dataset
+from ray.data.dataset import Dataset, _sliding_window
 from ray.data.datasource import DummyOutputDatasource
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
@@ -1762,6 +1762,25 @@ def test_read_binary_files_s3(ray_start_regular_shared):
         assert item == expected
 
 
+def test_sliding_window():
+    arr = list(range(10))
+
+    # Test all window sizes that fit within the iterable.
+    window_sizes = list(range(1, len(arr) + 1))
+    for window_size in window_sizes:
+        windows = list(_sliding_window(arr, window_size))
+        assert len(windows) == len(arr) - window_size + 1
+        assert all(len(window) == window_size for window in windows)
+        assert all(
+            list(window) == arr[i:i + window_size]
+            for i, window in enumerate(windows))
+
+    # Test a window size larger than the iterable length.
+    windows = list(_sliding_window(arr, 15))
+    assert len(windows) == 1
+    assert list(windows[0]) == arr
+
+
 def test_iter_batches_basic(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": [2, 3, 4]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": [5, 6, 7]})
@@ -1830,8 +1849,9 @@ def test_iter_batches_basic(ray_start_regular_shared):
         batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))
 
     # Prefetch.
-    for batch, df in zip(
-            ds.iter_batches(prefetch_blocks=1, batch_format="pandas"), dfs):
+    batches = list(ds.iter_batches(prefetch_blocks=1, batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
         assert isinstance(batch, pd.DataFrame)
         assert batch.equals(df)
 
@@ -1845,6 +1865,14 @@ def test_iter_batches_basic(ray_start_regular_shared):
     assert pd.concat(
         batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))
 
+    # Prefetch more than the number of blocks in the dataset.
+    batches = list(
+        ds.iter_batches(prefetch_blocks=len(dfs), batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
+        assert isinstance(batch, pd.DataFrame)
+        assert batch.equals(df)
+
 
 def test_iter_batches_grid(ray_start_regular_shared):
     # Tests slicing, batch combining, and partial batch dropping logic over