Add split_at_indices() (#17990)

Eric Liang 2021-08-20 15:35:22 -07:00 committed by GitHub
parent 05502da271
commit 58e35a21b4
2 changed files with 142 additions and 34 deletions


@@ -375,6 +375,8 @@ class Dataset(Generic[T]):
Time complexity: O(1)
See also: ``Dataset.split_at_indices``
Args:
n: Number of child datasets to return.
equal: Whether to guarantee each split has an equal
@@ -522,6 +524,49 @@ class Dataset(Generic[T]):
for actor in locality_hints
])
def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
"""Split the dataset at the given indices (like np.split).
Examples:
>>> d1, d2, d3 = ray.data.range(10).split_at_indices([2, 5])
>>> d1.take()
[0, 1]
>>> d2.take()
[2, 3, 4]
>>> d3.take()
[5, 6, 7, 8, 9]
Time complexity: O(num splits)
See also: ``Dataset.split``
Args:
indices: List of sorted integers which indicate where the dataset
will be split. If an index exceeds the length of the dataset,
an empty dataset will be returned.
Returns:
The dataset splits.
"""
if len(indices) < 1:
raise ValueError("indices must be at least of length 1")
if sorted(indices) != indices:
raise ValueError("indices must be sorted")
if indices[0] < 0:
raise ValueError("indices must be positive")
rest = self
splits = []
prev = 0
for i in indices:
first, rest = rest._split(i - prev, return_right_half=True)
prev = i
splits.append(first)
splits.append(rest)
return splits
def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
"""Combine this dataset with others of the same type.
@@ -617,32 +662,8 @@ class Dataset(Generic[T]):
The truncated dataset.
"""
get_num_rows = cached_remote_fn(_get_num_rows)
truncate = cached_remote_fn(_truncate, num_returns=2)
count = 0
out_blocks = []
out_metadata = []
for b, m in zip(self._blocks, self._blocks.get_metadata()):
if m.num_rows is None:
num_rows = ray.get(get_num_rows.remote(b))
else:
num_rows = m.num_rows
if count + num_rows < limit:
out_blocks.append(b)
out_metadata.append(m)
elif count + num_rows == limit:
out_blocks.append(b)
out_metadata.append(m)
break
else:
new_block, new_metadata = truncate.remote(b, m, limit - count)
out_blocks.append(new_block)
out_metadata.append(ray.get(new_metadata))
break
count += num_rows
return Dataset(BlockList(out_blocks, out_metadata))
left, _ = self._split(limit, return_right_half=False)
return left
def take(self, limit: int = 20) -> List[T]:
"""Take up to the given number of records from the dataset.
@@ -1460,6 +1481,48 @@ class Dataset(Generic[T]):
"""
return list(self._blocks)
def _split(self, index: int,
return_right_half: bool) -> ("Dataset[T]", "Dataset[T]"):
get_num_rows = cached_remote_fn(_get_num_rows)
split_block = cached_remote_fn(_split_block, num_returns=4)
count = 0
left_blocks = []
left_metadata = []
right_blocks = []
right_metadata = []
for b, m in zip(self._blocks, self._blocks.get_metadata()):
if m.num_rows is None:
num_rows = ray.get(get_num_rows.remote(b))
else:
num_rows = m.num_rows
if count >= index:
if not return_right_half:
break
right_blocks.append(b)
right_metadata.append(m)
elif count + num_rows < index:
left_blocks.append(b)
left_metadata.append(m)
elif count + num_rows == index:
left_blocks.append(b)
left_metadata.append(m)
else:
b0, m0, b1, m1 = split_block.remote(b, m, index - count,
return_right_half)
left_blocks.append(b0)
left_metadata.append(ray.get(m0))
right_blocks.append(b1)
right_metadata.append(ray.get(m1))
count += num_rows
left = Dataset(BlockList(left_blocks, left_metadata))
if return_right_half:
right = Dataset(BlockList(right_blocks, right_metadata))
else:
right = None
return left, right
def __repr__(self) -> str:
schema = self.schema()
if schema is None:
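
The block-walking logic in _split() above can be illustrated with a small in-memory sketch, using plain lists to stand in for Ray block object refs and merging the "< index" and "== index" branches into one "<=" check. The helper name is hypothetical; only the control flow mirrors the method, and the real code defers slicing to the _split_block remote task.

from typing import List, Optional, Tuple

def split_blocks(blocks: List[List[int]], index: int,
                 return_right_half: bool
                 ) -> Tuple[List[List[int]], Optional[List[List[int]]]]:
    count = 0
    left, right = [], []
    for b in blocks:
        num_rows = len(b)
        if count >= index:
            # Everything from here on belongs to the right half.
            if not return_right_half:
                break
            right.append(b)
        elif count + num_rows <= index:
            # Block fits entirely in the left half.
            left.append(b)
        else:
            # The split point falls inside this block: slice it in two.
            left.append(b[:index - count])
            right.append(b[index - count:])
        count += num_rows
    return left, (right if return_right_half else None)

left, right = split_blocks([[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]], 4, True)
assert left == [[0, 1, 2], [3]] and right == [[4, 5], [6, 7, 8, 9]]
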
@@ -1533,15 +1596,27 @@ def _check_is_arrow(block: Block) -> bool:
return isinstance(block, pyarrow.Table)
def _truncate(block: Block, meta: BlockMetadata,
count: int) -> (Block, BlockMetadata):
def _split_block(
block: Block, meta: BlockMetadata, count: int, return_right_half: bool
) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
block = BlockAccessor.for_block(block)
logger.debug("Truncating last block to size: {}".format(count))
new_block = block.slice(0, count, copy=True)
accessor = BlockAccessor.for_block(new_block)
new_meta = BlockMetadata(
num_rows=accessor.num_rows(),
size_bytes=accessor.size_bytes(),
b0 = block.slice(0, count, copy=True)
a0 = BlockAccessor.for_block(b0)
m0 = BlockMetadata(
num_rows=a0.num_rows(),
size_bytes=a0.size_bytes(),
schema=meta.schema,
input_files=meta.input_files)
return new_block, new_meta
if return_right_half:
b1 = block.slice(count, block.num_rows(), copy=True)
a1 = BlockAccessor.for_block(b1)
m1 = BlockMetadata(
num_rows=a1.num_rows(),
size_bytes=a1.size_bytes(),
schema=meta.schema,
input_files=meta.input_files)
else:
b1 = None
m1 = None
return b0, m0, b1, m1
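
The per-block work in _split_block() amounts to taking two contiguous slices of the same block. A list-based sketch of that contract, with a plain slice standing in for BlockAccessor.slice(start, end, copy=True) and a hypothetical helper name:

from typing import List, Optional, Tuple

def split_block_sketch(block: List[int], count: int,
                       return_right_half: bool
                       ) -> Tuple[List[int], Optional[List[int]]]:
    # Left piece: rows [0, count); right piece: rows [count, len(block)),
    # materialized only when the caller actually needs it.
    b0 = block[:count]
    b1 = block[count:] if return_right_half else None
    return b0, b1

assert split_block_sketch([10, 11, 12, 13], 2, True) == ([10, 11], [12, 13])
assert split_block_sketch([10, 11, 12, 13], 2, False) == ([10, 11], None)
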


@@ -856,6 +856,39 @@ def test_union(ray_start_regular_shared):
assert ds2.count() == 210
def test_split_at_indices(ray_start_regular_shared):
ds = ray.data.range(10, parallelism=3)
with pytest.raises(ValueError):
ds.split_at_indices([])
with pytest.raises(ValueError):
ds.split_at_indices([-1])
with pytest.raises(ValueError):
ds.split_at_indices([3, 1])
splits = ds.split_at_indices([5])
r = [s.take() for s in splits]
assert r == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
splits = ds.split_at_indices([2, 5])
r = [s.take() for s in splits]
assert r == [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]
splits = ds.split_at_indices([2, 5, 5, 100])
r = [s.take() for s in splits]
assert r == [[0, 1], [2, 3, 4], [], [5, 6, 7, 8, 9], []]
splits = ds.split_at_indices([100])
r = [s.take() for s in splits]
assert r == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], []]
splits = ds.split_at_indices([0])
r = [s.take() for s in splits]
assert r == [[], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
def test_split(ray_start_regular_shared):
ds = ray.data.range(20, parallelism=10)
assert ds.num_blocks() == 10