[Datasets] Fix empty Dataset.iter_batches() when trying to prefetch more blocks than exist in the dataset (#20480)
Before this PR, `ds.iter_batches()` would yield no batches if `prefetch_blocks > ds.num_blocks()` was given, since the sliding window semantics were to return no windows if `window_size > len(iterable)`. This PR tweaks the sliding window implementation to always return at least one window, even if the one window is smaller than the given window size.
parent add2450b92
commit 462e389791
2 changed files with 61 additions and 25 deletions
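To make the semantics change concrete, here is a standalone, stdlib-only sketch contrasting the two implementations; the names `old_sliding_window` and `new_sliding_window` are illustrative stand-ins for the nested `sliding_window` removed below and the module-level `_sliding_window` added below:

```python
import collections
import itertools
from typing import Iterable


def old_sliding_window(iterable: Iterable, n: int):
    # Pre-fix behavior: tee the iterator n ways, advance the k-th tee by k
    # elements, and zip them together. zip() stops at the first exhausted
    # iterator, so when n > len(iterable) some tees start out exhausted
    # and NO windows are produced.
    iters = itertools.tee(iter(iterable), n)
    for i in range(1, n):
        for it in iters[i:]:
            next(it, None)
    return zip(*iters)


def new_sliding_window(iterable: Iterable, n: int):
    # Post-fix behavior: fill a deque with up to n elements and yield it
    # even if fewer than n elements exist, so a non-empty iterable always
    # produces at least one (possibly short) window.
    it = iter(iterable)
    window = collections.deque(itertools.islice(it, n), maxlen=n)
    if len(window) > 0:
        yield tuple(window)
    for elem in it:
        window.append(elem)
        yield tuple(window)


assert list(old_sliding_window(range(3), 5)) == []           # bug: no windows
assert list(new_sliding_window(range(3), 5)) == [(0, 1, 2)]  # one short window
assert list(new_sliding_window(range(3), 2)) == [(0, 1), (1, 2)]
```

`zip(*iters)` stops as soon as any of the teed iterators is exhausted, which is exactly why the old version yielded nothing once `n` exceeded the iterable's length.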
python/ray/data/dataset.py:

```diff
@@ -1832,26 +1832,6 @@ class Dataset(Generic[T]):
             A list of iterators over record batches.
         """
 
-        def sliding_window(iterable: Iterable, n: int):
-            """Creates an iterator consisting of n-width sliding windows over
-            iterable. The sliding windows are constructed lazily such that an
-            element on the base iterator (iterable) isn't consumed until the
-            first sliding window containing that element is reached.
-
-            Args:
-                iterable: The iterable on which the sliding window will be
-                    created.
-                n: The width of the sliding window.
-
-            Returns:
-                An iterator of n-width windows over iterable.
-            """
-            iters = itertools.tee(iter(iterable), n)
-            for i in range(1, n):
-                for it in iters[i:]:
-                    next(it, None)
-            return zip(*iters)
-
         def format_batch(batch: Block, format: str) -> BatchType:
             if batch_format == "native":
                 return batch
@@ -1875,8 +1855,8 @@ class Dataset(Generic[T]):
                 yield format_batch(batcher.next_batch(), batch_format)
 
-        for block_window in sliding_window(self._blocks.iter_blocks(),
-                                           prefetch_blocks + 1):
+        block_window = []  # Handle empty sliding window gracefully.
+        for block_window in _sliding_window(self._blocks.iter_blocks(),
+                                            prefetch_blocks + 1):
             block_window = list(block_window)
             ray.wait(block_window, num_returns=1, fetch_local=True)
             yield from batch_block(block_window[0])
@@ -2511,6 +2491,34 @@ def _block_to_arrow(block: Block):
     return block.to_arrow()
 
 
+def _sliding_window(iterable: Iterable, n: int):
+    """Creates an iterator consisting of n-width sliding windows over
+    iterable. The sliding windows are constructed lazily such that an
+    element on the base iterator (iterable) isn't consumed until the
+    first sliding window containing that element is reached.
+
+    If n > len(iterable), then a single len(iterable) window is
+    returned.
+
+    Args:
+        iterable: The iterable on which the sliding window will be
+            created.
+        n: The width of the sliding window.
+
+    Returns:
+        An iterator of n-width windows over iterable.
+        If n > len(iterable), then a single len(iterable) window is
+        returned.
+    """
+    it = iter(iterable)
+    window = collections.deque(itertools.islice(it, n), maxlen=n)
+    if len(window) > 0:
+        yield tuple(window)
+    for elem in it:
+        window.append(elem)
+        yield tuple(window)
+
+
 def _split_block(
     block: Block, meta: BlockMetadata, count: int, return_right_half: bool
 ) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
```
python/ray/data/tests/test_dataset.py:

```diff
@@ -17,7 +17,7 @@ from pytest_lazyfixture import lazy_fixture
 import ray
 
 from ray.tests.conftest import *  # noqa
-from ray.data.dataset import Dataset
+from ray.data.dataset import Dataset, _sliding_window
 from ray.data.datasource import DummyOutputDatasource
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
@@ -1762,6 +1762,25 @@ def test_read_binary_files_s3(ray_start_regular_shared):
         assert item == expected
 
 
+def test_sliding_window():
+    arr = list(range(10))
+
+    # Test all windows over this iterable.
+    window_sizes = list(range(1, len(arr) + 1))
+    for window_size in window_sizes:
+        windows = list(_sliding_window(arr, window_size))
+        assert len(windows) == len(arr) - window_size + 1
+        assert all(len(window) == window_size for window in windows)
+        assert all(
+            list(window) == arr[i:i + window_size]
+            for i, window in enumerate(windows))
+
+    # Test window size larger than iterable length.
+    windows = list(_sliding_window(arr, 15))
+    assert len(windows) == 1
+    assert list(windows[0]) == arr
+
+
 def test_iter_batches_basic(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": [2, 3, 4]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": [5, 6, 7]})
@@ -1830,8 +1849,9 @@ def test_iter_batches_basic(ray_start_regular_shared):
         batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))
 
     # Prefetch.
-    for batch, df in zip(
-            ds.iter_batches(prefetch_blocks=1, batch_format="pandas"), dfs):
+    batches = list(ds.iter_batches(prefetch_blocks=1, batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
         assert isinstance(batch, pd.DataFrame)
         assert batch.equals(df)
 
@@ -1845,6 +1865,14 @@ def test_iter_batches_basic(ray_start_regular_shared):
     assert pd.concat(
         batches, ignore_index=True).equals(pd.concat(dfs, ignore_index=True))
 
+    # Prefetch more than number of blocks.
+    batches = list(
+        ds.iter_batches(prefetch_blocks=len(dfs), batch_format="pandas"))
+    assert len(batches) == len(dfs)
+    for batch, df in zip(batches, dfs):
+        assert isinstance(batch, pd.DataFrame)
+        assert batch.equals(df)
+
+
 def test_iter_batches_grid(ray_start_regular_shared):
     # Tests slicing, batch combining, and partial batch dropping logic over
```
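At the API level, the fixed behavior looks like the following usage sketch, modeled on the new test; it assumes the Ray Datasets API of this release, and in particular that `ray.data.from_pandas` accepts a list of DataFrames, which may differ across Ray versions:

```python
import pandas as pd
import ray

df1 = pd.DataFrame({"one": [1, 2, 3], "two": [2, 3, 4]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": [5, 6, 7]})
ds = ray.data.from_pandas([df1, df2])  # a dataset with two blocks

# Prefetching more blocks than the dataset contains no longer yields an
# empty iterator; every batch still comes back, in order.
batches = list(ds.iter_batches(prefetch_blocks=100, batch_format="pandas"))
assert len(batches) == 2
assert batches[0].equals(df1) and batches[1].equals(df2)
```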