From 58e35a21b4e57c2e248f54746c725d7eb64866ac Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 20 Aug 2021 15:35:22 -0700
Subject: [PATCH] Add split_at_indices() (#17990)

---
 python/ray/data/dataset.py            | 143 ++++++++++++++++++++------
 python/ray/data/tests/test_dataset.py |  33 ++++++
 2 files changed, 142 insertions(+), 34 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index f5952449c..2aac99055 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -375,6 +375,8 @@ class Dataset(Generic[T]):

         Time complexity: O(1)

+        See also: ``Dataset.split_at_indices``
+
         Args:
             n: Number of child datasets to return.
             equal: Whether to guarantee each split has an equal
@@ -522,6 +524,49 @@ class Dataset(Generic[T]):
             for actor in locality_hints
         ])

+    def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
+        """Split the dataset at the given indices (like np.split).
+
+        Examples:
+            >>> d1, d2, d3 = ray.data.range(10).split_at_indices([2, 5])
+            >>> d1.take()
+            [0, 1]
+            >>> d2.take()
+            [2, 3, 4]
+            >>> d3.take()
+            [5, 6, 7, 8, 9]
+
+        Time complexity: O(num splits)
+
+        See also: ``Dataset.split``
+
+        Args:
+            indices: List of sorted integers which indicate where the dataset
+                will be split. If an index exceeds the length of the dataset,
+                an empty dataset will be returned.
+
+        Returns:
+            The dataset splits.
+        """
+
+        if len(indices) < 1:
+            raise ValueError("indices must be at least of length 1")
+        if sorted(indices) != indices:
+            raise ValueError("indices must be sorted")
+        if indices[0] < 0:
+            raise ValueError("indices must be positive")
+
+        rest = self
+        splits = []
+        prev = 0
+        for i in indices:
+            first, rest = rest._split(i - prev, return_right_half=True)
+            prev = i
+            splits.append(first)
+        splits.append(rest)
+
+        return splits
+
     def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         """Combine this dataset with others of the same type.

@@ -617,32 +662,8 @@ class Dataset(Generic[T]):
             The truncated dataset.
         """

-        get_num_rows = cached_remote_fn(_get_num_rows)
-        truncate = cached_remote_fn(_truncate, num_returns=2)
-
-        count = 0
-        out_blocks = []
-        out_metadata = []
-        for b, m in zip(self._blocks, self._blocks.get_metadata()):
-            if m.num_rows is None:
-                num_rows = ray.get(get_num_rows.remote(b))
-            else:
-                num_rows = m.num_rows
-            if count + num_rows < limit:
-                out_blocks.append(b)
-                out_metadata.append(m)
-            elif count + num_rows == limit:
-                out_blocks.append(b)
-                out_metadata.append(m)
-                break
-            else:
-                new_block, new_metadata = truncate.remote(b, m, limit - count)
-                out_blocks.append(new_block)
-                out_metadata.append(ray.get(new_metadata))
-                break
-            count += num_rows
-
-        return Dataset(BlockList(out_blocks, out_metadata))
+        left, _ = self._split(limit, return_right_half=False)
+        return left

     def take(self, limit: int = 20) -> List[T]:
         """Take up to the given number of records from the dataset.
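For intuition, the loop in split_at_indices() above is repeated two-way splitting with rebased indices: each global index i is cut at i - prev because every _split() call consumes the left piece, so later indices are relative to the remaining right half. A minimal list-based sketch of the same bookkeeping (split_at_indices_sketch is a hypothetical name, not Ray API):

    def split_at_indices_sketch(items, indices):
        # Repeatedly cut the remainder in two, rebasing each global index
        # against the rows already consumed (prev) -- np.split semantics.
        rest, splits, prev = items, [], 0
        for i in indices:
            splits.append(rest[:i - prev])  # left piece for this cut
            rest = rest[i - prev:]          # everything after the cut
            prev = i
        splits.append(rest)                 # trailing split
        return splits

    assert split_at_indices_sketch(list(range(10)), [2, 5]) == \
        [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]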
@@ -1460,6 +1481,48 @@ class Dataset(Generic[T]):
         """
         return list(self._blocks)

+    def _split(self, index: int,
+               return_right_half: bool) -> ("Dataset[T]", "Dataset[T]"):
+        get_num_rows = cached_remote_fn(_get_num_rows)
+        split_block = cached_remote_fn(_split_block, num_returns=4)
+
+        count = 0
+        left_blocks = []
+        left_metadata = []
+        right_blocks = []
+        right_metadata = []
+        for b, m in zip(self._blocks, self._blocks.get_metadata()):
+            if m.num_rows is None:
+                num_rows = ray.get(get_num_rows.remote(b))
+            else:
+                num_rows = m.num_rows
+            if count >= index:
+                if not return_right_half:
+                    break
+                right_blocks.append(b)
+                right_metadata.append(m)
+            elif count + num_rows < index:
+                left_blocks.append(b)
+                left_metadata.append(m)
+            elif count + num_rows == index:
+                left_blocks.append(b)
+                left_metadata.append(m)
+            else:
+                b0, m0, b1, m1 = split_block.remote(b, m, index - count,
+                                                    return_right_half)
+                left_blocks.append(b0)
+                left_metadata.append(ray.get(m0))
+                right_blocks.append(b1)
+                right_metadata.append(ray.get(m1))
+            count += num_rows
+
+        left = Dataset(BlockList(left_blocks, left_metadata))
+        if return_right_half:
+            right = Dataset(BlockList(right_blocks, right_metadata))
+        else:
+            right = None
+        return left, right
+
     def __repr__(self) -> str:
         schema = self.schema()
         if schema is None:
@@ -1533,15 +1596,27 @@ def _check_is_arrow(block: Block) -> bool:
     return isinstance(block, pyarrow.Table)


-def _truncate(block: Block, meta: BlockMetadata,
-              count: int) -> (Block, BlockMetadata):
+def _split_block(
+        block: Block, meta: BlockMetadata, count: int, return_right_half: bool
+) -> (Block, BlockMetadata, Optional[Block], Optional[BlockMetadata]):
     block = BlockAccessor.for_block(block)
     logger.debug("Truncating last block to size: {}".format(count))
-    new_block = block.slice(0, count, copy=True)
-    accessor = BlockAccessor.for_block(new_block)
-    new_meta = BlockMetadata(
-        num_rows=accessor.num_rows(),
-        size_bytes=accessor.size_bytes(),
+    b0 = block.slice(0, count, copy=True)
+    a0 = BlockAccessor.for_block(b0)
+    m0 = BlockMetadata(
+        num_rows=a0.num_rows(),
+        size_bytes=a0.size_bytes(),
         schema=meta.schema,
         input_files=meta.input_files)
-    return new_block, new_meta
+    if return_right_half:
+        b1 = block.slice(count, block.num_rows(), copy=True)
+        a1 = BlockAccessor.for_block(b1)
+        m1 = BlockMetadata(
+            num_rows=a1.num_rows(),
+            size_bytes=a1.size_bytes(),
+            schema=meta.schema,
+            input_files=meta.input_files)
+    else:
+        b1 = None
+        m1 = None
+    return b0, m0, b1, m1
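Block-level intuition for _split() and _split_block() above, as a pure-Python sketch (plain lists stand in for Ray blocks and object refs; split_blocks_sketch is a hypothetical helper, and the two all-left branches of _split are condensed into one <= comparison): whole blocks ending at or before the cut go left, blocks starting at or past it go right, and the single straddling block is sliced at (index - count).

    def split_blocks_sketch(blocks, index):
        left, right, count = [], [], 0
        for block in blocks:
            n = len(block)
            if count >= index:            # starts at or past the cut: all right
                right.append(block)
            elif count + n <= index:      # ends at or before the cut: all left
                left.append(block)
            else:                         # straddles the cut: slice in two
                left.append(block[:index - count])
                right.append(block[index - count:])
            count += n
        return left, right

    assert split_blocks_sketch([[0, 1, 2], [3, 4, 5], [6, 7, 8]], 4) == \
        ([[0, 1, 2], [3]], [[4, 5], [6, 7, 8]])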
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 86255a1ac..14c3dd454 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -856,6 +856,39 @@ def test_union(ray_start_regular_shared):
     assert ds2.count() == 210


+def test_split_at_indices(ray_start_regular_shared):
+    ds = ray.data.range(10, parallelism=3)
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([])
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([-1])
+
+    with pytest.raises(ValueError):
+        ds.split_at_indices([3, 1])
+
+    splits = ds.split_at_indices([5])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+
+    splits = ds.split_at_indices([2, 5])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]
+
+    splits = ds.split_at_indices([2, 5, 5, 100])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1], [2, 3, 4], [], [5, 6, 7, 8, 9], []]
+
+    splits = ds.split_at_indices([100])
+    r = [s.take() for s in splits]
+    assert r == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], []]
+
+    splits = ds.split_at_indices([0])
+    r = [s.take() for s in splits]
+    assert r == [[], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
+
+
 def test_split(ray_start_regular_shared):
     ds = ray.data.range(20, parallelism=10)
     assert ds.num_blocks() == 10
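The edge cases in test_split_at_indices above mirror np.split semantics: duplicate indices yield empty middle splits, and indices past the end yield trailing empty splits. A small standalone usage sketch (assumes a Ray build that includes this patch; ray.init() may be needed first):

    import ray

    ds = ray.data.range(10)
    splits = ds.split_at_indices([2, 5, 5, 100])
    print([s.take() for s in splits])
    # [[0, 1], [2, 3, 4], [], [5, 6, 7, 8, 9], []]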