Implement zip() function for dataset (#18833)
Parent: a96dbd885b
Commit: 2c15215833
5 changed files with 117 additions and 0 deletions
@@ -109,6 +109,10 @@ class BlockAccessor(Generic[T]):
             schema=self.schema(),
             input_files=input_files)
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        """Zip this block with another block of the same type and size."""
+        raise NotImplementedError
+
     @staticmethod
     def builder() -> "BlockBuilder[T]":
         """Create a builder for this block type."""
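Aside: the stub above defines the contract that the concrete accessors later in this diff fill in. A self-contained sketch of the override pattern (both class names here are hypothetical stand-ins for illustration, not Ray classes):

from typing import Generic, List, Tuple, TypeVar

T = TypeVar("T")

class BlockAccessorSketch(Generic[T]):
    def zip(self, other):
        """Zip this block with another block of the same type and size."""
        raise NotImplementedError  # concrete accessors must override this

class ListBlockAccessorSketch(BlockAccessorSketch[T]):
    def __init__(self, items: List[T]):
        self._items = items

    def zip(self, other: List[T]) -> List[Tuple[T, T]]:
        # Mirrors the simple-block implementation later in this diff.
        return list(zip(self._items, other))

print(ListBlockAccessorSketch([1, 2]).zip([3, 4]))  # [(1, 3), (2, 4)]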
@@ -655,6 +655,57 @@ class Dataset(Generic[T]):
         """
         return Dataset(sort_impl(self._blocks, key, descending))
 
+    def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
+        """Zip this dataset with the elements of another.
+
+        The datasets must have identical num rows, block types, and block sizes
+        (e.g., one was produced from a ``.map()`` of another). For Arrow
+        blocks, the schema will be concatenated, and any duplicate column
+        names disambiguated with _1, _2, etc. suffixes.
+
+        Time complexity: O(dataset size / parallelism)
+
+        Args:
+            other: The dataset to zip with on the right hand side.
+
+        Examples:
+            >>> ds = ray.data.range(5)
+            >>> ds.zip(ds).take()
+            [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
+
+        Returns:
+            A Dataset with (k, v) pairs (or concatenated Arrow schema) where k
+            comes from the first dataset and v comes from the second.
+        """
+
+        blocks1 = self.get_blocks()
+        blocks2 = other.get_blocks()
+
+        if len(blocks1) != len(blocks2):
+            # TODO(ekl) consider supporting if num_rows are equal.
+            raise ValueError(
+                "Cannot zip dataset of different num blocks: {} vs {}".format(
+                    len(blocks1), len(blocks2)))
+
+        def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
+            b1 = BlockAccessor.for_block(block1)
+            result = b1.zip(block2)
+            br = BlockAccessor.for_block(result)
+            return result, br.get_metadata(input_files=[])
+
+        do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
+
+        blocks = []
+        metadata = []
+        for b1, b2 in zip(blocks1, blocks2):
+            res, meta = do_zip_fn.remote(b1, b2)
+            blocks.append(res)
+            metadata.append(meta)
+
+        # TODO(ekl) it might be nice to have a progress bar here.
+        metadata = ray.get(metadata)
+        return Dataset(BlockList(blocks, metadata))
+
     def limit(self, limit: int) -> "Dataset[T]":
         """Limit the dataset to the first number of records specified.
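Aside: the driver loop above fans each block pair out as a remote task with two return values, so the metadata refs can be fetched eagerly with ray.get() while the block refs stay lazy. A standalone toy sketch of that pattern (do_pair is a hypothetical stand-in for do_zip; ray.remote(num_returns=2) is real Ray API):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote(num_returns=2)
def do_pair(x):
    # Stand-in for do_zip(): returns (result_block, metadata).
    return x * 2, {"num_rows": 1}

blocks, metadata = [], []
for x in range(3):
    res, meta = do_pair.remote(x)  # two ObjectRefs per call
    blocks.append(res)
    metadata.append(meta)

# Metadata is fetched eagerly; the result blocks remain in the object store.
print(ray.get(metadata))  # [{'num_rows': 1}, {'num_rows': 1}, {'num_rows': 1}]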
@@ -200,6 +200,30 @@ class ArrowBlockAccessor(BlockAccessor):
     def size_bytes(self) -> int:
         return self._table.nbytes
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        acc = BlockAccessor.for_block(other)
+        if not isinstance(acc, ArrowBlockAccessor):
+            raise ValueError("Cannot zip {} with block of type {}".format(
+                type(self), type(other)))
+        if acc.num_rows() != self.num_rows():
+            raise ValueError(
+                "Cannot zip self (length {}) with block of length {}".format(
+                    self.num_rows(), acc.num_rows()))
+        r = self.to_arrow()
+        s = acc.to_arrow()
+        for col_name in s.column_names:
+            col = s.column(col_name)
+            # Ensure the column names are unique after zip.
+            if col_name in r.column_names:
+                i = 1
+                new_name = col_name
+                while new_name in r.column_names:
+                    new_name = "{}_{}".format(col_name, i)
+                    i += 1
+                col_name = new_name
+            r = r.append_column(col_name, col)
+        return r
+
     @staticmethod
     def builder() -> ArrowBlockBuilder[T]:
         return ArrowBlockBuilder()
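Aside: the rename loop can be exercised in isolation with plain pyarrow tables, no Ray required. A minimal sketch of the _1, _2, ... suffix disambiguation (r and s mirror the names above):

import pyarrow as pa

r = pa.table({"id": [0, 1]})
s = pa.table({"id": [10, 11], "a": [20, 21]})
for col_name in s.column_names:
    col = s.column(col_name)
    # Same disambiguation loop as ArrowBlockAccessor.zip(): pick the first
    # free _1, _2, ... suffix for any column name already taken.
    if col_name in r.column_names:
        i = 1
        new_name = col_name
        while new_name in r.column_names:
            new_name = "{}_{}".format(col_name, i)
            i += 1
        col_name = new_name
    r = r.append_column(col_name, col)

print(r.column_names)  # ['id', 'id_1', 'a']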
@@ -74,6 +74,16 @@ class SimpleBlockAccessor(BlockAccessor):
         else:
             return None
 
+    def zip(self, other: "Block[T]") -> "Block[T]":
+        if not isinstance(other, list):
+            raise ValueError("Cannot zip {} with block of type {}".format(
+                type(self), type(other)))
+        if len(other) != len(self._items):
+            raise ValueError(
+                "Cannot zip self (length {}) with block of length {}".format(
+                    len(self), len(other)))
+        return list(zip(self._items, other))
+
     @staticmethod
     def builder() -> SimpleBlockBuilder[T]:
         return SimpleBlockBuilder()
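Aside: the simple-block path is just Python's built-in zip over two lists. A quick plain-Python illustration, including why the explicit length check matters (built-in zip silently truncates rather than raising):

items = [0, 1, 2]

# Same length: elements pair up positionally, as in the accessor above.
assert list(zip(items, ["a", "b", "c"])) == [(0, "a"), (1, "b"), (2, "c")]

# Mismatched length: zip() would silently drop the extra row, which is why
# the accessor raises ValueError up front instead of relying on zip() alone.
assert list(zip(items, ["a", "b"])) == [(0, "a"), (1, "b")]  # truncated!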
@@ -153,6 +153,34 @@ def test_basic(ray_start_regular_shared, pipelined):
     assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
 
 
+def test_zip(ray_start_regular_shared):
+    ds1 = ray.data.range(5)
+    ds2 = ray.data.range(5).map(lambda x: x + 1)
+    ds = ds1.zip(ds2)
+    assert ds.schema() == tuple
+    assert ds.take() == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
+    with pytest.raises(ValueError):
+        ds.zip(ray.data.range(3))
+
+
+def test_zip_arrow(ray_start_regular_shared):
+    ds1 = ray.data.range_arrow(5).map(lambda r: {"id": r["value"]})
+    ds2 = ray.data.range_arrow(5).map(
+        lambda r: {"a": r["value"] + 1, "b": r["value"] + 2})
+    ds = ds1.zip(ds2)
+    assert "{id: int64, a: int64, b: int64}" in str(ds)
+    assert ds.count() == 5
+    result = [r.as_pydict() for r in ds.take()]
+    assert result[0] == {"id": 0, "a": 1, "b": 2}
+
+    # Test duplicate column names.
+    ds = ds1.zip(ds1).zip(ds1)
+    assert ds.count() == 5
+    assert "{id: int64, id_1: int64, id_2: int64}" in str(ds)
+    result = [r.as_pydict() for r in ds.take()]
+    assert result[0] == {"id": 0, "id_1": 0, "id_2": 0}
+
+
 def test_batch_tensors(ray_start_regular_shared):
     import torch
     ds = ray.data.from_items([torch.tensor([0, 0]) for _ in range(40)])