Implement zip() function for dataset (#18833)

This commit is contained in:
Eric Liang 2021-09-23 00:12:29 -07:00 committed by GitHub
parent a96dbd885b
commit 2c15215833
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 0 deletions

View file

@ -109,6 +109,10 @@ class BlockAccessor(Generic[T]):
schema=self.schema(),
input_files=input_files)
def zip(self, other: "Block[T]") -> "Block[T]":
"""Zip this block with another block of the same type and size."""
raise NotImplementedError
@staticmethod
def builder() -> "BlockBuilder[T]":
"""Create a builder for this block type."""

View file

@ -655,6 +655,57 @@ class Dataset(Generic[T]):
"""
return Dataset(sort_impl(self._blocks, key, descending))
def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
"""Zip this dataset with the elements of another.
The datasets must have identical num rows, block types, and block sizes
(e.g., one was produced from a ``.map()`` of another). For Arrow
blocks, the schema will be concatenated, and any duplicate column
names disambiguated with _1, _2, etc. suffixes.
Time complexity: O(dataset size / parallelism)
Args:
other: The dataset to zip with on the right hand side.
Examples:
>>> ds = ray.data.range(5)
>>> ds.zip(ds).take()
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
Returns:
A Dataset with (k, v) pairs (or concatenated Arrow schema) where k
comes from the first dataset and v comes from the second.
"""
blocks1 = self.get_blocks()
blocks2 = other.get_blocks()
if len(blocks1) != len(blocks2):
# TODO(ekl) consider supporting if num_rows are equal.
raise ValueError(
"Cannot zip dataset of different num blocks: {} vs {}".format(
len(blocks1), len(blocks2)))
def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
b1 = BlockAccessor.for_block(block1)
result = b1.zip(block2)
br = BlockAccessor.for_block(result)
return result, br.get_metadata(input_files=[])
do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
blocks = []
metadata = []
for b1, b2 in zip(blocks1, blocks2):
res, meta = do_zip_fn.remote(b1, b2)
blocks.append(res)
metadata.append(meta)
# TODO(ekl) it might be nice to have a progress bar here.
metadata = ray.get(metadata)
return Dataset(BlockList(blocks, metadata))
def limit(self, limit: int) -> "Dataset[T]":
"""Limit the dataset to the first number of records specified.

View file

@ -200,6 +200,30 @@ class ArrowBlockAccessor(BlockAccessor):
def size_bytes(self) -> int:
return self._table.nbytes
def zip(self, other: "Block[T]") -> "Block[T]":
acc = BlockAccessor.for_block(other)
if not isinstance(acc, ArrowBlockAccessor):
raise ValueError("Cannot zip {} with block of type {}".format(
type(self), type(other)))
if acc.num_rows() != self.num_rows():
raise ValueError(
"Cannot zip self (length {}) with block of length {}".format(
self.num_rows(), acc.num_rows()))
r = self.to_arrow()
s = acc.to_arrow()
for col_name in s.column_names:
col = s.column(col_name)
# Ensure the column names are unique after zip.
if col_name in r.column_names:
i = 1
new_name = col_name
while new_name in r.column_names:
new_name = "{}_{}".format(col_name, i)
i += 1
col_name = new_name
r = r.append_column(col_name, col)
return r
@staticmethod
def builder() -> ArrowBlockBuilder[T]:
return ArrowBlockBuilder()

View file

@ -74,6 +74,16 @@ class SimpleBlockAccessor(BlockAccessor):
else:
return None
def zip(self, other: "Block[T]") -> "Block[T]":
if not isinstance(other, list):
raise ValueError("Cannot zip {} with block of type {}".format(
type(self), type(other)))
if len(other) != len(self._items):
raise ValueError(
"Cannot zip self (length {}) with block of length {}".format(
len(self), len(other)))
return list(zip(self._items, other))
@staticmethod
def builder() -> SimpleBlockBuilder[T]:
return SimpleBlockBuilder()

View file

@ -153,6 +153,34 @@ def test_basic(ray_start_regular_shared, pipelined):
assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
def test_zip(ray_start_regular_shared):
ds1 = ray.data.range(5)
ds2 = ray.data.range(5).map(lambda x: x + 1)
ds = ds1.zip(ds2)
assert ds.schema() == tuple
assert ds.take() == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
with pytest.raises(ValueError):
ds.zip(ray.data.range(3))
def test_zip_arrow(ray_start_regular_shared):
ds1 = ray.data.range_arrow(5).map(lambda r: {"id": r["value"]})
ds2 = ray.data.range_arrow(5).map(
lambda r: {"a": r["value"] + 1, "b": r["value"] + 2})
ds = ds1.zip(ds2)
assert "{id: int64, a: int64, b: int64}" in str(ds)
assert ds.count() == 5
result = [r.as_pydict() for r in ds.take()]
assert result[0] == {"id": 0, "a": 1, "b": 2}
# Test duplicate column names.
ds = ds1.zip(ds1).zip(ds1)
assert ds.count() == 5
assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)
result = [r.as_pydict() for r in ds.take()]
assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}
def test_batch_tensors(ray_start_regular_shared):
import torch
ds = ray.data.from_items([torch.tensor([0, 0]) for _ in range(40)])