Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
Add split_at_indices() (#17990)
Commit 58e35a21b4 (parent 05502da271)
2 changed files with 142 additions and 34 deletions
@@ -375,6 +375,8 @@ class Dataset(Generic[T]):

         Time complexity: O(1)

+        See also: ``Dataset.split_at_indices``
+
         Args:
             n: Number of child datasets to return.
             equal: Whether to guarantee each split has an equal
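As the new cross-reference suggests, ``split(n)`` picks the cut points itself, while ``split_at_indices`` takes them explicitly. A quick usage sketch of the contrast, assuming the ``ray.data`` API shown in this diff (split sizes are up to the library unless ``equal=True``):

    import ray

    ds = ray.data.range(10)
    # split(n) chooses where to cut; without equal=True pieces may be uneven.
    a, b = ds.split(2)
    assert a.count() + b.count() == 10
    # split_at_indices() cuts exactly where you say (added below).
    d1, d2, d3 = ds.split_at_indices([2, 5])
    assert d1.take() == [0, 1]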
@@ -522,6 +524,49 @@ class Dataset(Generic[T]):
                 for actor in locality_hints
             ])

+    def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
+        """Split the dataset at the given indices (like np.split).
+
+        Examples:
+            >>> d1, d2, d3 = ray.data.range(10).split_at_indices([2, 5])
+            >>> d1.take()
+            [0, 1]
+            >>> d2.take()
+            [2, 3, 4]
+            >>> d3.take()
+            [5, 6, 7, 8, 9]
+
+        Time complexity: O(num splits)
+
+        See also: ``Dataset.split``
+
+        Args:
+            indices: List of sorted integers which indicate where the dataset
+                will be split. If an index exceeds the length of the dataset,
+                an empty dataset will be returned.
+
+        Returns:
+            The dataset splits.
+        """
+
+        if len(indices) < 1:
+            raise ValueError("indices must be at least of length 1")
+        if sorted(indices) != indices:
+            raise ValueError("indices must be sorted")
+        if indices[0] < 0:
+            raise ValueError("indices must be non-negative")
+
+        rest = self
+        splits = []
+        prev = 0
+        for i in indices:
+            first, rest = rest._split(i - prev, return_right_half=True)
+            prev = i
+            splits.append(first)
+        splits.append(rest)
+        return splits
+
     def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         """Combine this dataset with others of the same type.
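The loop above always cuts the *remaining* right half, so each global index must be rebased against the previous cut point (``i - prev``). A minimal pure-Python analogue of the same telescoping scheme, with plain lists standing in for datasets (not the Ray implementation itself):

    from typing import List

    def split_at_indices(items: list, indices: List[int]) -> List[list]:
        # Same telescoping scheme as Dataset.split_at_indices: each cut is
        # taken against the remaining right half, so the global index i is
        # rebased by the previous cut point (i - prev).
        rest = items
        splits = []
        prev = 0
        for i in indices:
            first, rest = rest[:i - prev], rest[i - prev:]
            prev = i
            splits.append(first)
        splits.append(rest)  # trailing remainder, as with np.split
        return splits

    assert split_at_indices(list(range(10)), [2, 5]) == \
        [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]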
@@ -617,32 +662,8 @@ class Dataset(Generic[T]):
             The truncated dataset.
         """
-
-        get_num_rows = cached_remote_fn(_get_num_rows)
-        truncate = cached_remote_fn(_truncate, num_returns=2)
-
-        count = 0
-        out_blocks = []
-        out_metadata = []
-        for b, m in zip(self._blocks, self._blocks.get_metadata()):
-            if m.num_rows is None:
-                num_rows = ray.get(get_num_rows.remote(b))
-            else:
-                num_rows = m.num_rows
-            if count + num_rows < limit:
-                out_blocks.append(b)
-                out_metadata.append(m)
-            elif count + num_rows == limit:
-                out_blocks.append(b)
-                out_metadata.append(m)
-                break
-            else:
-                new_block, new_metadata = truncate.remote(b, m, limit - count)
-                out_blocks.append(new_block)
-                out_metadata.append(ray.get(new_metadata))
-                break
-            count += num_rows
-
-        return Dataset(BlockList(out_blocks, out_metadata))
+        left, _ = self._split(limit, return_right_half=False)
+        return left

     def take(self, limit: int = 20) -> List[T]:
         """Take up to the given number of records from the dataset.
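With this change ``limit()`` is just the left half of an internal ``_split()``; the right half is never materialized (``return_right_half=False``). A sanity-check sketch, assuming the ``ray.data`` API used elsewhere in this diff:

    import ray

    ds = ray.data.range(10)
    # limit(n) now delegates to _split(n, return_right_half=False)
    # and keeps only the left piece.
    assert ds.limit(3).take() == [0, 1, 2]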
@@ -1460,6 +1481,48 @@ class Dataset(Generic[T]):
         """
         return list(self._blocks)

+    def _split(self, index: int,
+               return_right_half: bool) -> ("Dataset[T]", "Dataset[T]"):
+        get_num_rows = cached_remote_fn(_get_num_rows)
+        split_block = cached_remote_fn(_split_block, num_returns=4)
+
+        count = 0
+        left_blocks = []
+        left_metadata = []
+        right_blocks = []
+        right_metadata = []
+        for b, m in zip(self._blocks, self._blocks.get_metadata()):
+            if m.num_rows is None:
+                num_rows = ray.get(get_num_rows.remote(b))
+            else:
+                num_rows = m.num_rows
+            if count >= index:
+                if not return_right_half:
+                    break
+                right_blocks.append(b)
+                right_metadata.append(m)
+            elif count + num_rows < index:
+                left_blocks.append(b)
+                left_metadata.append(m)
+            elif count + num_rows == index:
+                left_blocks.append(b)
+                left_metadata.append(m)
+            else:
+                b0, m0, b1, m1 = split_block.remote(b, m, index - count,
+                                                    return_right_half)
+                left_blocks.append(b0)
+                left_metadata.append(ray.get(m0))
+                right_blocks.append(b1)
+                right_metadata.append(ray.get(m1))
+            count += num_rows
+
+        left = Dataset(BlockList(left_blocks, left_metadata))
+        if return_right_half:
+            right = Dataset(BlockList(right_blocks, right_metadata))
+        else:
+            right = None
+        return left, right
+
     def __repr__(self) -> str:
         schema = self.schema()
         if schema is None:
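``_split`` classifies each block against the cut point: entirely left, entirely right, or straddling the cut, in which case the remote ``_split_block`` slices it. A list-of-lists analogue of that classification, with the ``<`` and ``==`` branches merged for brevity:

    # Each block lands wholly left, wholly right, or straddles the cut
    # and is sliced (the remote _split_block call in the real code).
    def split_blocks(blocks, index):
        count, left, right = 0, [], []
        for block in blocks:
            n = len(block)
            if count >= index:            # cut already passed: all right
                right.append(block)
            elif count + n <= index:      # block ends at/before the cut: all left
                left.append(block)
            else:                         # block straddles the cut: slice it
                left.append(block[:index - count])
                right.append(block[index - count:])
            count += n
        return left, right

    assert split_blocks([[0, 1, 2], [3, 4, 5], [6, 7]], 4) == (
        [[0, 1, 2], [3]], [[4, 5], [6, 7]])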
@@ -1533,15 +1596,27 @@ def _check_is_arrow(block: Block) -> bool:
     return isinstance(block, pyarrow.Table)


-def _truncate(block: Block, meta: BlockMetadata,
-              count: int) -> (Block, BlockMetadata):
+def _split_block(
+        block: Block, meta: BlockMetadata, count: int, return_right_half: bool
+) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
     block = BlockAccessor.for_block(block)
     logger.debug("Truncating last block to size: {}".format(count))
-    new_block = block.slice(0, count, copy=True)
-    accessor = BlockAccessor.for_block(new_block)
-    new_meta = BlockMetadata(
-        num_rows=accessor.num_rows(),
-        size_bytes=accessor.size_bytes(),
+    b0 = block.slice(0, count, copy=True)
+    a0 = BlockAccessor.for_block(b0)
+    m0 = BlockMetadata(
+        num_rows=a0.num_rows(),
+        size_bytes=a0.size_bytes(),
         schema=meta.schema,
         input_files=meta.input_files)
-    return new_block, new_meta
+    if return_right_half:
+        b1 = block.slice(count, block.num_rows(), copy=True)
+        a1 = BlockAccessor.for_block(b1)
+        m1 = BlockMetadata(
+            num_rows=a1.num_rows(),
+            size_bytes=a1.size_bytes(),
+            schema=meta.schema,
+            input_files=meta.input_files)
+    else:
+        b1 = None
+        m1 = None
+    return b0, m0, b1, m1
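For Arrow blocks (the ``pyarrow.Table`` case checked by ``_check_is_arrow``), the underlying slicing looks like the sketch below; note that pyarrow's own ``Table.slice`` takes ``(offset, length)`` rather than the ``(start, end, copy)`` signature of the ``BlockAccessor`` wrapper used above:

    import pyarrow as pa

    t = pa.table({"x": [0, 1, 2, 3, 4]})
    # Table.slice(offset, length); length defaults to "through the end".
    left = t.slice(0, 2)   # rows [0, 1]
    right = t.slice(2)     # rows [2, 3, 4]
    assert left.num_rows == 2 and right.num_rows == 3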
@@ -856,6 +856,39 @@ def test_union(ray_start_regular_shared):
     assert ds2.count() == 210


+def test_split_at_indices(ray_start_regular_shared):
+    ds = ray.data.range(10, parallelism=3)
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([])
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([-1])
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([3, 1])
+
+    splits = ds.split_at_indices([5])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+
+    splits = ds.split_at_indices([2, 5])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]
+
+    splits = ds.split_at_indices([2, 5, 5, 100])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1], [2, 3, 4], [], [5, 6, 7, 8, 9], []]
+
+    splits = ds.split_at_indices([100])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], []]
+
+    splits = ds.split_at_indices([0])
+    r = [s.take() for s in splits]
+    assert r == [[], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
+
+
 def test_split(ray_start_regular_shared):
     ds = ray.data.range(20, parallelism=10)
     assert ds.num_blocks() == 10