mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Implement zip() function for dataset (#18833)
This commit is contained in:
parent
a96dbd885b
commit
2c15215833
5 changed files with 117 additions and 0 deletions
|
@ -109,6 +109,10 @@ class BlockAccessor(Generic[T]):
|
|||
schema=self.schema(),
|
||||
input_files=input_files)
|
||||
|
||||
def zip(self, other: "Block[T]") -> "Block[T]":
|
||||
"""Zip this block with another block of the same type and size."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def builder() -> "BlockBuilder[T]":
|
||||
"""Create a builder for this block type."""
|
||||
|
|
|
@ -655,6 +655,57 @@ class Dataset(Generic[T]):
|
|||
"""
|
||||
return Dataset(sort_impl(self._blocks, key, descending))
|
||||
|
||||
def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
|
||||
"""Zip this dataset with the elements of another.
|
||||
|
||||
The datasets must have identical num rows, block types, and block sizes
|
||||
(e.g., one was produced from a ``.map()`` of another). For Arrow
|
||||
blocks, the schema will be concatenated, and any duplicate column
|
||||
names disambiguated with _1, _2, etc. suffixes.
|
||||
|
||||
Time complexity: O(dataset size / parallelism)
|
||||
|
||||
Args:
|
||||
other: The dataset to zip with on the right hand side.
|
||||
|
||||
Examples:
|
||||
>>> ds = ray.data.range(5)
|
||||
>>> ds.zip(ds).take()
|
||||
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
|
||||
|
||||
Returns:
|
||||
A Dataset with (k, v) pairs (or concatenated Arrow schema) where k
|
||||
comes from the first dataset and v comes from the second.
|
||||
"""
|
||||
|
||||
blocks1 = self.get_blocks()
|
||||
blocks2 = other.get_blocks()
|
||||
|
||||
if len(blocks1) != len(blocks2):
|
||||
# TODO(ekl) consider supporting if num_rows are equal.
|
||||
raise ValueError(
|
||||
"Cannot zip dataset of different num blocks: {} vs {}".format(
|
||||
len(blocks1), len(blocks2)))
|
||||
|
||||
def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
|
||||
b1 = BlockAccessor.for_block(block1)
|
||||
result = b1.zip(block2)
|
||||
br = BlockAccessor.for_block(result)
|
||||
return result, br.get_metadata(input_files=[])
|
||||
|
||||
do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
|
||||
|
||||
blocks = []
|
||||
metadata = []
|
||||
for b1, b2 in zip(blocks1, blocks2):
|
||||
res, meta = do_zip_fn.remote(b1, b2)
|
||||
blocks.append(res)
|
||||
metadata.append(meta)
|
||||
|
||||
# TODO(ekl) it might be nice to have a progress bar here.
|
||||
metadata = ray.get(metadata)
|
||||
return Dataset(BlockList(blocks, metadata))
|
||||
|
||||
def limit(self, limit: int) -> "Dataset[T]":
|
||||
"""Limit the dataset to the first number of records specified.
|
||||
|
||||
|
|
|
@ -200,6 +200,30 @@ class ArrowBlockAccessor(BlockAccessor):
|
|||
def size_bytes(self) -> int:
|
||||
return self._table.nbytes
|
||||
|
||||
def zip(self, other: "Block[T]") -> "Block[T]":
|
||||
acc = BlockAccessor.for_block(other)
|
||||
if not isinstance(acc, ArrowBlockAccessor):
|
||||
raise ValueError("Cannot zip {} with block of type {}".format(
|
||||
type(self), type(other)))
|
||||
if acc.num_rows() != self.num_rows():
|
||||
raise ValueError(
|
||||
"Cannot zip self (length {}) with block of length {}".format(
|
||||
self.num_rows(), acc.num_rows()))
|
||||
r = self.to_arrow()
|
||||
s = acc.to_arrow()
|
||||
for col_name in s.column_names:
|
||||
col = s.column(col_name)
|
||||
# Ensure the column names are unique after zip.
|
||||
if col_name in r.column_names:
|
||||
i = 1
|
||||
new_name = col_name
|
||||
while new_name in r.column_names:
|
||||
new_name = "{}_{}".format(col_name, i)
|
||||
i += 1
|
||||
col_name = new_name
|
||||
r = r.append_column(col_name, col)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def builder() -> ArrowBlockBuilder[T]:
|
||||
return ArrowBlockBuilder()
|
||||
|
|
|
@ -74,6 +74,16 @@ class SimpleBlockAccessor(BlockAccessor):
|
|||
else:
|
||||
return None
|
||||
|
||||
def zip(self, other: "Block[T]") -> "Block[T]":
|
||||
if not isinstance(other, list):
|
||||
raise ValueError("Cannot zip {} with block of type {}".format(
|
||||
type(self), type(other)))
|
||||
if len(other) != len(self._items):
|
||||
raise ValueError(
|
||||
"Cannot zip self (length {}) with block of length {}".format(
|
||||
len(self), len(other)))
|
||||
return list(zip(self._items, other))
|
||||
|
||||
@staticmethod
|
||||
def builder() -> SimpleBlockBuilder[T]:
|
||||
return SimpleBlockBuilder()
|
||||
|
|
|
@ -153,6 +153,34 @@ def test_basic(ray_start_regular_shared, pipelined):
|
|||
assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_zip(ray_start_regular_shared):
|
||||
ds1 = ray.data.range(5)
|
||||
ds2 = ray.data.range(5).map(lambda x: x + 1)
|
||||
ds = ds1.zip(ds2)
|
||||
assert ds.schema() == tuple
|
||||
assert ds.take() == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
|
||||
with pytest.raises(ValueError):
|
||||
ds.zip(ray.data.range(3))
|
||||
|
||||
|
||||
def test_zip_arrow(ray_start_regular_shared):
|
||||
ds1 = ray.data.range_arrow(5).map(lambda r: {"id": r["value"]})
|
||||
ds2 = ray.data.range_arrow(5).map(
|
||||
lambda r: {"a": r["value"] + 1, "b": r["value"] + 2})
|
||||
ds = ds1.zip(ds2)
|
||||
assert "{id: int64, a: int64, b: int64}" in str(ds)
|
||||
assert ds.count() == 5
|
||||
result = [r.as_pydict() for r in ds.take()]
|
||||
assert result[0] == {"id": 0, "a": 1, "b": 2}
|
||||
|
||||
# Test duplicate column names.
|
||||
ds = ds1.zip(ds1).zip(ds1)
|
||||
assert ds.count() == 5
|
||||
assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)
|
||||
result = [r.as_pydict() for r in ds.take()]
|
||||
assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}
|
||||
|
||||
|
||||
def test_batch_tensors(ray_start_regular_shared):
|
||||
import torch
|
||||
ds = ray.data.from_items([torch.tensor([0, 0]) for _ in range(40)])
|
||||
|
|
Loading…
Add table
Reference in a new issue