From 2c152158337975082006947804743d2742402dd6 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 23 Sep 2021 00:12:29 -0700
Subject: [PATCH] Implement zip() function for dataset (#18833)

---
 python/ray/data/block.py              |  4 +++
 python/ray/data/dataset.py            | 51 +++++++++++++++++++++++++++
 python/ray/data/impl/arrow_block.py   | 24 +++++++++++++
 python/ray/data/impl/simple_block.py  | 10 ++++++
 python/ray/data/tests/test_dataset.py | 28 +++++++++++++++
 5 files changed, 117 insertions(+)

diff --git a/python/ray/data/block.py b/python/ray/data/block.py
index a9b570115..35b99780c 100644
--- a/python/ray/data/block.py
+++ b/python/ray/data/block.py
@@ -109,6 +109,10 @@ class BlockAccessor(Generic[T]):
             schema=self.schema(),
             input_files=input_files)
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        """Zip this block with another block of the same type and size."""
+        raise NotImplementedError
+
     @staticmethod
     def builder() -> "BlockBuilder[T]":
         """Create a builder for this block type."""

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 3b75d303e..11d0a13c9 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -655,6 +655,57 @@ class Dataset(Generic[T]):
         """
         return Dataset(sort_impl(self._blocks, key, descending))
 
+    def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
+        """Zip this dataset with the elements of another.
+
+        The datasets must have identical numbers of rows, block types, and
+        block sizes (e.g., one was produced from a ``.map()`` of the other).
+        For Arrow blocks, the schemas will be concatenated, and any duplicate
+        column names disambiguated with _1, _2, etc. suffixes.
+
+        Time complexity: O(dataset size / parallelism)
+
+        Args:
+            other: The dataset to zip with on the right hand side.
+
+        Examples:
+            >>> ds = ray.data.range(5)
+            >>> ds.zip(ds).take()
+            [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
+
+        Returns:
+            A Dataset with (k, v) pairs (or a concatenated Arrow schema),
+            where k comes from the first dataset and v comes from the second.
+        """
+
+        blocks1 = self.get_blocks()
+        blocks2 = other.get_blocks()
+
+        if len(blocks1) != len(blocks2):
+            # TODO(ekl) consider supporting if num_rows are equal.
+            raise ValueError(
+                "Cannot zip dataset of different num blocks: {} vs {}".format(
+                    len(blocks1), len(blocks2)))
+
+        def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
+            b1 = BlockAccessor.for_block(block1)
+            result = b1.zip(block2)
+            br = BlockAccessor.for_block(result)
+            return result, br.get_metadata(input_files=[])
+
+        do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
+
+        blocks = []
+        metadata = []
+        for b1, b2 in zip(blocks1, blocks2):
+            res, meta = do_zip_fn.remote(b1, b2)
+            blocks.append(res)
+            metadata.append(meta)
+
+        # TODO(ekl) it might be nice to have a progress bar here.
+        metadata = ray.get(metadata)
+        return Dataset(BlockList(blocks, metadata))
+
     def limit(self, limit: int) -> "Dataset[T]":
         """Limit the dataset to the first number of records specified.
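As a quick end-to-end illustration of the API added above (a usage sketch,
not part of the patch; it assumes ``ray.init()`` has been called and, as the
new tests below also rely on, that small ``range()`` datasets are created
with one block per row):

    >>> import ray
    >>> ds1 = ray.data.range(5)
    >>> ds2 = ray.data.range(5).map(lambda x: x * 10)
    >>> ds1.zip(ds2).take()
    [(0, 0), (1, 10), (2, 20), (3, 30), (4, 40)]
    >>> ds1.zip(ray.data.range(3))
    Traceback (most recent call last):
      ...
    ValueError: Cannot zip dataset of different num blocks: 5 vs 3

Note that zip() only pairs up blocks positionally and does not repartition,
so the mismatched-block-count case fails fast on the driver before any
do_zip tasks are launched.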
diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index a01d8f499..41c5875bb 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -200,6 +200,30 @@ class ArrowBlockAccessor(BlockAccessor):
     def size_bytes(self) -> int:
         return self._table.nbytes
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        acc = BlockAccessor.for_block(other)
+        if not isinstance(acc, ArrowBlockAccessor):
+            raise ValueError("Cannot zip {} with block of type {}".format(
+                type(self), type(other)))
+        if acc.num_rows() != self.num_rows():
+            raise ValueError(
+                "Cannot zip self (length {}) with block of length {}".format(
+                    self.num_rows(), acc.num_rows()))
+        r = self.to_arrow()
+        s = acc.to_arrow()
+        for col_name in s.column_names:
+            col = s.column(col_name)
+            # Ensure the column names are unique after zip.
+            if col_name in r.column_names:
+                i = 1
+                new_name = col_name
+                while new_name in r.column_names:
+                    new_name = "{}_{}".format(col_name, i)
+                    i += 1
+                col_name = new_name
+            r = r.append_column(col_name, col)
+        return r
+
     @staticmethod
     def builder() -> ArrowBlockBuilder[T]:
         return ArrowBlockBuilder()

diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py
index 8d0131a0c..ba20d1334 100644
--- a/python/ray/data/impl/simple_block.py
+++ b/python/ray/data/impl/simple_block.py
@@ -74,6 +74,16 @@ class SimpleBlockAccessor(BlockAccessor):
         else:
             return None
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        if not isinstance(other, list):
+            raise ValueError("Cannot zip {} with block of type {}".format(
+                type(self), type(other)))
+        if len(other) != len(self._items):
+            raise ValueError(
+                "Cannot zip self (length {}) with block of length {}".format(
+                    len(self._items), len(other)))
+        return list(zip(self._items, other))
+
     @staticmethod
     def builder() -> SimpleBlockBuilder[T]:
         return SimpleBlockBuilder()

diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index c01641dd8..7562e2c5a 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -153,6 +153,34 @@ def test_basic(ray_start_regular_shared, pipelined):
     assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
 
 
+def test_zip(ray_start_regular_shared):
+    ds1 = ray.data.range(5)
+    ds2 = ray.data.range(5).map(lambda x: x + 1)
+    ds = ds1.zip(ds2)
+    assert ds.schema() == tuple
+    assert ds.take() == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
+    with pytest.raises(ValueError):
+        ds.zip(ray.data.range(3))
+
+
+def test_zip_arrow(ray_start_regular_shared):
+    ds1 = ray.data.range_arrow(5).map(lambda r: {"id": r["value"]})
+    ds2 = ray.data.range_arrow(5).map(
+        lambda r: {"a": r["value"] + 1, "b": r["value"] + 2})
+    ds = ds1.zip(ds2)
+    assert "{id: int64, a: int64, b: int64}" in str(ds)
+    assert ds.count() == 5
+    result = [r.as_pydict() for r in ds.take()]
+    assert result[0] == {"id": 0, "a": 1, "b": 2}
+
+    # Test duplicate column names.
+    ds = ds1.zip(ds1).zip(ds1)
+    assert ds.count() == 5
+    assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)
+    result = [r.as_pydict() for r in ds.take()]
+    assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}
+
+
 def test_batch_tensors(ray_start_regular_shared):
     import torch
     ds = ray.data.from_items([torch.tensor([0, 0]) for _ in range(40)])
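For reference, the duplicate-column suffixing exercised by test_zip_arrow can
be reproduced standalone with plain pyarrow. The helper name below is
hypothetical; this is a minimal sketch of the same scheme as the loop in
ArrowBlockAccessor.zip() above, not code from the patch:

    import pyarrow as pa

    def append_with_unique_name(table: pa.Table, name: str, col) -> pa.Table:
        # The first collision gets a "_1" suffix, the next "_2", and so on,
        # until an unused column name is found.
        if name in table.column_names:
            i = 1
            new_name = name
            while new_name in table.column_names:
                new_name = "{}_{}".format(name, i)
                i += 1
            name = new_name
        return table.append_column(name, col)

    left = pa.table({"id": [0, 1, 2]})
    right = pa.table({"id": [3, 4, 5]})
    merged = left
    for name in right.column_names:
        merged = append_with_unique_name(merged, name, right.column(name))
    print(merged.column_names)  # ['id', 'id_1']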