Revert "Revert "Revert "[Datasets] [Tensor Story - 1/2] Automatically provide tensor views to UDFs and infer tensor blocks for pure-tensor datasets."" (#25031)" (#25057)

Reverts #25031

It still looks to be somewhat flaky.
mwtian 2022-05-25 19:43:22 -07:00 committed by GitHub
parent b2d41fc427
commit fb2933a78f
17 changed files with 166 additions and 516 deletions


@@ -15,57 +15,22 @@ Automatic conversion between the Pandas and Arrow extension types/arrays keeps t
 Single-column tensor datasets
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-The most basic case is when a dataset only has a single column, which is of tensor
-type. This kind of dataset can be:
-
-* created with :func:`range_tensor() <ray.data.range_tensor>`
-  or :func:`from_numpy() <ray.data.from_numpy>`,
-* transformed with NumPy UDFs via
-  :meth:`ds.map_batches() <ray.data.Dataset.map_batches>`,
-* consumed with :meth:`ds.iter_rows() <ray.data.Dataset.iter_rows>` and
-  :meth:`ds.iter_batches() <ray.data.Dataset.iter_batches>`, and
-* can be read from and written to ``.npy`` files.
-
-Here is an end-to-end example:
+The most basic case is when a dataset only has a single column, which is of tensor type. This kind of dataset can be created with ``.range_tensor()``, and can be read from and written to ``.npy`` files. Here are some examples:

 .. code-block:: python

-    # Create a synthetic pure-tensor Dataset.
-    ds = ray.data.range_tensor(10, shape=(3, 5))
-    # -> Dataset(num_blocks=10, num_rows=10,
-    #            schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})
-
-    # Create a pure-tensor Dataset from an existing NumPy ndarray.
-    arr = np.arange(10 * 3 * 5).reshape((10, 3, 5))
-    ds = ray.data.from_numpy(arr)
-    # -> Dataset(num_blocks=1, num_rows=10,
-    #            schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})
-
-    # Transform the tensors. Datasets will automatically unpack the single-column Arrow
-    # table into a NumPy ndarray, provide that ndarray to your UDF, and then repack it
-    # into a single-column Arrow table; this will be a zero-copy conversion in both
-    # cases.
-    ds = ds.map_batches(lambda arr: arr / arr.max())
-    # -> Dataset(num_blocks=1, num_rows=10,
-    #            schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=double>})
-
-    # Consume the tensor. This will yield the underlying (3, 5) ndarrays.
-    for arr in ds.iter_rows():
-        assert isinstance(arr, np.ndarray)
-        assert arr.shape == (3, 5)
-
-    # Consume the tensor in batches.
-    for arr in ds.iter_batches(batch_size=2):
-        assert isinstance(arr, np.ndarray)
-        assert arr.shape == (2, 3, 5)
-
-    # Save to storage. This will write out the blocks of the tensor column as NPY files.
-    ds.write_numpy("/tmp/tensor_out")
-
-    # Read back from storage.
+    # Create a Dataset of tensor-typed values.
+    ds = ray.data.range_tensor(10000, shape=(3, 5))
+    # -> Dataset(num_blocks=200, num_rows=10000,
+    #            schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})
+
+    # Save to storage.
+    ds.write_numpy("/tmp/tensor_out", column="value")
+
+    # Read from storage.
     ray.data.read_numpy("/tmp/tensor_out")
-    # -> Dataset(num_blocks=1, num_rows=?,
-    #            schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=double>})
+    # -> Dataset(num_blocks=200, num_rows=?,
+    #            schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})

 Reading existing serialized tensor columns
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
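Since the revert drops the NumPy-UDF transform examples from the docs hunk above, a minimal sketch of how a tensor dataset is transformed after this change may help. It is based on the updated test in this diff and assumes a working Ray installation of roughly this vintage; exact reprs may differ:

```python
import ray

# After the revert, tensor datasets are single-column tables whose column is
# named "value" (not "__value__"), and batch UDFs receive table-formatted
# batches (e.g. pandas DataFrames) rather than raw NumPy ndarrays.
ds = ray.data.range_tensor(10)
res = ds.map_batches(lambda df: df + 2, batch_format="pandas").take(2)
print(res)
# -> [{'value': array([2])}, {'value': array([3])}]
```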


@ -3,7 +3,6 @@ import time
from typing import ( from typing import (
TypeVar, TypeVar,
List, List,
Dict,
Generic, Generic,
Iterator, Iterator,
Tuple, Tuple,
@ -83,10 +82,6 @@ def _validate_key_fn(ds: "Dataset", key: KeyFn) -> None:
# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``. # ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
Block = Union[List[T], "pyarrow.Table", "pandas.DataFrame", bytes] Block = Union[List[T], "pyarrow.Table", "pandas.DataFrame", bytes]
# User-facing data batch type. This is the data type for data that is supplied to and
# returned from batch UDFs.
DataBatch = Union[Block, np.ndarray]
# A list of block references pending computation by a single task. For example, # A list of block references pending computation by a single task. For example,
# this may be the output of a task reading a file. # this may be the output of a task reading a file.
BlockPartition = List[Tuple[ObjectRef[Block], "BlockMetadata"]] BlockPartition = List[Tuple[ObjectRef[Block], "BlockMetadata"]]
@ -215,13 +210,11 @@ class BlockAccessor(Generic[T]):
"""Convert this block into a Pandas dataframe.""" """Convert this block into a Pandas dataframe."""
raise NotImplementedError raise NotImplementedError
def to_numpy( def to_numpy(self, column: str = None) -> np.ndarray:
self, columns: Optional[Union[str, List[str]]] = None """Convert this block (or column of block) into a NumPy ndarray.
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""Convert this block (or columns of block) into a NumPy ndarray.
Args: Args:
columns: Name of columns to convert, or None if converting all columns. column: Name of column to convert, or None.
""" """
raise NotImplementedError raise NotImplementedError
@ -233,10 +226,6 @@ class BlockAccessor(Generic[T]):
"""Return the base block that this accessor wraps.""" """Return the base block that this accessor wraps."""
raise NotImplementedError raise NotImplementedError
def to_native(self) -> Block:
"""Return the native data format for this accessor."""
return self.to_block()
def size_bytes(self) -> int: def size_bytes(self) -> int:
"""Return the approximate size in bytes of this block.""" """Return the approximate size in bytes of this block."""
raise NotImplementedError raise NotImplementedError
@ -266,15 +255,6 @@ class BlockAccessor(Generic[T]):
"""Create a builder for this block type.""" """Create a builder for this block type."""
raise NotImplementedError raise NotImplementedError
@staticmethod
def batch_to_block(batch: DataBatch) -> Block:
"""Create a block from user-facing data formats."""
if isinstance(batch, np.ndarray):
from ray.data.impl.arrow_block import ArrowBlockAccessor
return ArrowBlockAccessor.numpy_to_block(batch)
return batch
@staticmethod @staticmethod
def for_block(block: Block) -> "BlockAccessor[T]": def for_block(block: Block) -> "BlockAccessor[T]":
"""Create a block accessor for the given block.""" """Create a block accessor for the given block."""


@ -68,7 +68,6 @@ from ray.data.datasource.file_based_datasource import (
from ray.data.row import TableRow from ray.data.row import TableRow
from ray.data.aggregate import AggregateFn, Sum, Max, Min, Mean, Std from ray.data.aggregate import AggregateFn, Sum, Max, Min, Mean, Std
from ray.data.random_access_dataset import RandomAccessDataset from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.impl.table_block import VALUE_COL_NAME
from ray.data.impl.remote_fn import cached_remote_fn from ray.data.impl.remote_fn import cached_remote_fn
from ray.data.impl.block_batching import batch_blocks, BatchType from ray.data.impl.block_batching import batch_blocks, BatchType
from ray.data.impl.plan import ExecutionPlan, OneToOneStage, AllToAllStage from ray.data.impl.plan import ExecutionPlan, OneToOneStage, AllToAllStage
@ -235,8 +234,8 @@ class Dataset(Generic[T]):
def transform(block: Block) -> Iterable[Block]: def transform(block: Block) -> Iterable[Block]:
DatasetContext._set_current(context) DatasetContext._set_current(context)
output_buffer = BlockOutputBuffer(None, context.target_max_block_size)
block = BlockAccessor.for_block(block) block = BlockAccessor.for_block(block)
output_buffer = BlockOutputBuffer(None, context.target_max_block_size)
for row in block.iter_rows(): for row in block.iter_rows():
output_buffer.add(fn(row)) output_buffer.add(fn(row))
if output_buffer.has_next(): if output_buffer.has_next():
@ -261,9 +260,6 @@ class Dataset(Generic[T]):
) -> "Dataset[Any]": ) -> "Dataset[Any]":
"""Apply the given function to batches of records of this dataset. """Apply the given function to batches of records of this dataset.
The format of the data batch provided to ``fn`` can be controlled via the
``batch_format`` argument, and the output of the UDF can be any batch type.
This is a blocking operation. This is a blocking operation.
Examples: Examples:
@ -310,9 +306,10 @@ class Dataset(Generic[T]):
blocks as batches. Defaults to a system-chosen batch size. blocks as batches. Defaults to a system-chosen batch size.
compute: The compute strategy, either "tasks" (default) to use Ray compute: The compute strategy, either "tasks" (default) to use Ray
tasks, or ActorPoolStrategy(min, max) to use an autoscaling actor pool. tasks, or ActorPoolStrategy(min, max) to use an autoscaling actor pool.
batch_format: Specify "native" to use the native block format (promotes batch_format: Specify "native" to use the native block format
tables to Pandas and tensors to NumPy), "pandas" to select (promotes Arrow to pandas), "pandas" to select
``pandas.DataFrame``, or "pyarrow" to select `pyarrow.Table``. ``pandas.DataFrame`` as the batch format,
or "pyarrow" to select ``pyarrow.Table``.
ray_remote_args: Additional resource requirements to request from ray_remote_args: Additional resource requirements to request from
ray (e.g., num_gpus=1 to request GPUs for the map tasks). ray (e.g., num_gpus=1 to request GPUs for the map tasks).
""" """
@ -341,7 +338,9 @@ class Dataset(Generic[T]):
# bug where we include the entire base view on serialization. # bug where we include the entire base view on serialization.
view = block.slice(start, end, copy=batch_size is not None) view = block.slice(start, end, copy=batch_size is not None)
if batch_format == "native": if batch_format == "native":
view = BlockAccessor.for_block(view).to_native() # Always promote Arrow blocks to pandas for consistency.
if isinstance(view, pa.Table) or isinstance(view, bytes):
view = BlockAccessor.for_block(view).to_pandas()
elif batch_format == "pandas": elif batch_format == "pandas":
view = BlockAccessor.for_block(view).to_pandas() view = BlockAccessor.for_block(view).to_pandas()
elif batch_format == "pyarrow": elif batch_format == "pyarrow":
@ -356,7 +355,6 @@ class Dataset(Generic[T]):
if not ( if not (
isinstance(applied, list) isinstance(applied, list)
or isinstance(applied, pa.Table) or isinstance(applied, pa.Table)
or isinstance(applied, np.ndarray)
or isinstance(applied, pd.core.frame.DataFrame) or isinstance(applied, pd.core.frame.DataFrame)
): ):
raise ValueError( raise ValueError(
@ -366,7 +364,7 @@ class Dataset(Generic[T]):
"The return type must be either list, " "The return type must be either list, "
"pandas.DataFrame, or pyarrow.Table" "pandas.DataFrame, or pyarrow.Table"
) )
output_buffer.add_batch(applied) output_buffer.add_block(applied)
if output_buffer.has_next(): if output_buffer.has_next():
yield output_buffer.next() yield output_buffer.next()
@ -703,8 +701,6 @@ class Dataset(Generic[T]):
) )
if isinstance(batch, pd.DataFrame): if isinstance(batch, pd.DataFrame):
return batch.sample(frac=fraction) return batch.sample(frac=fraction)
if isinstance(batch, np.ndarray):
return np.array([row for row in batch if random.random() <= fraction])
raise ValueError(f"Unsupported batch type: {type(batch)}") raise ValueError(f"Unsupported batch type: {type(batch)}")
return self.map_batches(process_batch) return self.map_batches(process_batch)
@ -2075,7 +2071,7 @@ class Dataset(Generic[T]):
self, self,
path: str, path: str,
*, *,
column: str = VALUE_COL_NAME, column: str = "value",
filesystem: Optional["pyarrow.fs.FileSystem"] = None, filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True, try_create_dir: bool = True,
arrow_open_stream_args: Optional[Dict[str, Any]] = None, arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@ -2103,8 +2099,7 @@ class Dataset(Generic[T]):
path: The path to the destination root directory, where npy path: The path to the destination root directory, where npy
files will be written to. files will be written to.
column: The name of the table column that contains the tensor to column: The name of the table column that contains the tensor to
be written. The default is ``"__value__"``, the column name that be written. This defaults to "value".
Datasets uses for storing tensors in single-column tables.
filesystem: The filesystem implementation to write to. filesystem: The filesystem implementation to write to.
try_create_dir: Try to create all directories in destination path try_create_dir: Try to create all directories in destination path
if True. Does nothing if all directories already exist. if True. Does nothing if all directories already exist.
@ -2251,10 +2246,10 @@ class Dataset(Generic[T]):
current block during the scan. current block during the scan.
batch_size: Record batch size, or None to let the system pick. batch_size: Record batch size, or None to let the system pick.
batch_format: The format in which to return each batch. batch_format: The format in which to return each batch.
Specify "native" to use the native block format (promoting Specify "native" to use the current block format (promoting
tables to Pandas and tensors to NumPy), "pandas" to select Arrow to pandas automatically), "pandas" to
``pandas.DataFrame``, or "pyarrow" to select ``pyarrow.Table``. Default select ``pandas.DataFrame`` or "pyarrow" to select
is "native". ``pyarrow.Table``. Default is "native".
drop_last: Whether to drop the last batch if it's incomplete. drop_last: Whether to drop the last batch if it's incomplete.
Returns: Returns:
@ -2776,9 +2771,8 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
Time complexity: O(dataset size / parallelism) Time complexity: O(dataset size / parallelism)
Args: Args:
column: The name of the column to convert to numpy, or None to specify the column: The name of the column to convert to numpy, or None to
entire row. If not specified for Arrow or Pandas blocks, each returned specify the entire row. Required for Arrow tables.
future will represent a dict of column ndarrays.
Returns: Returns:
A list of remote NumPy ndarrays created from this dataset. A list of remote NumPy ndarrays created from this dataset.
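The docstring changes above restore the pre-tensor-story meaning of ``batch_format``: "native" promotes Arrow blocks to pandas, and there is no NumPy batch format. A small sketch of what iterators and batch UDFs see after the revert (a sketch only; assumes Ray, pandas, and pyarrow of this vintage are installed):

```python
import ray
import pandas as pd
import pyarrow as pa

ds = ray.data.range_table(8)

# "native" promotes Arrow blocks to pandas DataFrames for consistency.
for batch in ds.iter_batches(batch_size=4, batch_format="native"):
    assert isinstance(batch, pd.DataFrame)

# "pyarrow" hands back the underlying Arrow tables.
for batch in ds.iter_batches(batch_size=4, batch_format="pyarrow"):
    assert isinstance(batch, pa.Table)

# Batch UDFs must return a list, pandas.DataFrame, or pyarrow.Table;
# returning a bare ndarray is no longer accepted.
ds2 = ds.map_batches(lambda df: df * 2, batch_format="pandas")
```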


@ -193,11 +193,14 @@ class RangeDatasource(Datasource[Union[ArrowRow, int]]):
elif block_format == "tensor": elif block_format == "tensor":
import pyarrow as pa import pyarrow as pa
tensor = np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( tensor = TensorArray(
np.arange(start, start + count), np.ones(tensor_shape, dtype=np.int64)
tuple(range(1, 1 + len(tensor_shape))), * np.expand_dims(
np.arange(start, start + count),
tuple(range(1, 1 + len(tensor_shape))),
)
) )
return BlockAccessor.batch_to_block(tensor) return pa.Table.from_pydict({"value": tensor})
else: else:
return list(builtins.range(start, start + count)) return list(builtins.range(start, start + count))
@ -211,12 +214,16 @@ class RangeDatasource(Datasource[Union[ArrowRow, int]]):
schema = pa.Table.from_pydict({"value": [0]}).schema schema = pa.Table.from_pydict({"value": [0]}).schema
elif block_format == "tensor": elif block_format == "tensor":
_check_pyarrow_version() _check_pyarrow_version()
from ray.data.extensions import TensorArray
import pyarrow as pa import pyarrow as pa
tensor = np.ones(tensor_shape, dtype=np.int64) * np.expand_dims( tensor = TensorArray(
np.arange(0, 10), tuple(range(1, 1 + len(tensor_shape))) np.ones(tensor_shape, dtype=np.int64)
* np.expand_dims(
np.arange(0, 10), tuple(range(1, 1 + len(tensor_shape)))
)
) )
schema = BlockAccessor.batch_to_block(tensor).schema schema = pa.Table.from_pydict({"value": tensor}).schema
elif block_format == "list": elif block_format == "list":
schema = int schema = int
else: else:
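For readability, here is the restored tensor-block construction from the `RangeDatasource` hunk above, reassembled from the flattened diff into a standalone snippet (uses Ray's `TensorArray` pandas extension and pyarrow, as shown in the diff; treat it as a sketch rather than the exact file contents):

```python
import numpy as np
import pyarrow as pa
from ray.data.extensions import TensorArray

start, count, tensor_shape = 0, 4, (3, 5)

# Broadcast the range values across the tensor shape and wrap the result in a
# TensorArray, stored under the "value" column of an Arrow table.
tensor = TensorArray(
    np.ones(tensor_shape, dtype=np.int64)
    * np.expand_dims(
        np.arange(start, start + count),
        tuple(range(1, 1 + len(tensor_shape))),
    )
)
block = pa.Table.from_pydict({"value": tensor})
# -> a single "value" column backed by the ArrowTensorType extension type.
print(block.schema)
```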


@ -26,13 +26,18 @@ class NumpyDatasource(FileBasedDatasource):
""" """
def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args): def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
from ray.data.extensions import TensorArray
import pyarrow as pa
# TODO(ekl) Ideally numpy can read directly from the file, but it # TODO(ekl) Ideally numpy can read directly from the file, but it
# seems like it requires the file to be seekable. # seems like it requires the file to be seekable.
buf = BytesIO() buf = BytesIO()
data = f.readall() data = f.readall()
buf.write(data) buf.write(data)
buf.seek(0) buf.seek(0)
return BlockAccessor.batch_to_block(np.load(buf, allow_pickle=True)) return pa.Table.from_pydict(
{"value": TensorArray(np.load(buf, allow_pickle=True))}
)
def _write_block( def _write_block(
self, self,


@ -6,7 +6,6 @@ from typing import (
Dict, Dict,
List, List,
Tuple, Tuple,
Union,
Iterator, Iterator,
Any, Any,
TypeVar, TypeVar,
@ -31,11 +30,7 @@ from ray.data.block import (
KeyType, KeyType,
) )
from ray.data.row import TableRow from ray.data.row import TableRow
from ray.data.impl.table_block import ( from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
TableBlockAccessor,
TableBlockBuilder,
VALUE_COL_NAME,
)
from ray.data.aggregate import AggregateFn from ray.data.aggregate import AggregateFn
if TYPE_CHECKING: if TYPE_CHECKING:
@ -78,13 +73,6 @@ class ArrowBlockBuilder(TableBlockBuilder[T]):
super().__init__(pyarrow.Table) super().__init__(pyarrow.Table)
def _table_from_pydict(self, columns: Dict[str, List[Any]]) -> Block: def _table_from_pydict(self, columns: Dict[str, List[Any]]) -> Block:
for col_name, col in columns.items():
if col_name == VALUE_COL_NAME or isinstance(
next(iter(col), None), np.ndarray
):
from ray.data.extensions.tensor_extension import ArrowTensorArray
columns[col_name] = ArrowTensorArray.from_numpy(col)
return pyarrow.Table.from_pydict(columns) return pyarrow.Table.from_pydict(columns)
def _concat_tables(self, tables: List[Block]) -> Block: def _concat_tables(self, tables: List[Block]) -> Block:
@ -96,35 +84,19 @@ class ArrowBlockBuilder(TableBlockBuilder[T]):
class ArrowBlockAccessor(TableBlockAccessor): class ArrowBlockAccessor(TableBlockAccessor):
ROW_TYPE = ArrowRow
def __init__(self, table: "pyarrow.Table"): def __init__(self, table: "pyarrow.Table"):
if pyarrow is None: if pyarrow is None:
raise ImportError("Run `pip install pyarrow` for Arrow support") raise ImportError("Run `pip install pyarrow` for Arrow support")
super().__init__(table) super().__init__(table)
def column_names(self) -> List[str]: def _create_table_row(self, row: "pyarrow.Table") -> ArrowRow:
return self._table.column_names return ArrowRow(row)
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor": def from_bytes(cls, data: bytes):
reader = pyarrow.ipc.open_stream(data) reader = pyarrow.ipc.open_stream(data)
return cls(reader.read_all()) return cls(reader.read_all())
@staticmethod
def numpy_to_block(batch: np.ndarray) -> "pyarrow.Table":
import pyarrow as pa
from ray.data.extensions.tensor_extension import ArrowTensorArray
return pa.Table.from_pydict(
{VALUE_COL_NAME: ArrowTensorArray.from_numpy(batch)}
)
@staticmethod
def _build_tensor_row(row: ArrowRow) -> np.ndarray:
# Getting an item in a tensor column automatically does a NumPy conversion.
return row[VALUE_COL_NAME][0]
def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table": def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":
view = self._table.slice(start, end - start) view = self._table.slice(start, end - start)
if copy: if copy:
@ -133,7 +105,7 @@ class ArrowBlockAccessor(TableBlockAccessor):
def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table": def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
random = np.random.RandomState(random_seed) random = np.random.RandomState(random_seed)
return self.take(random.permutation(self.num_rows())) return self._table.take(random.permutation(self.num_rows()))
def schema(self) -> "pyarrow.lib.Schema": def schema(self) -> "pyarrow.lib.Schema":
return self._table.schema return self._table.schema
@ -141,34 +113,26 @@ class ArrowBlockAccessor(TableBlockAccessor):
def to_pandas(self) -> "pandas.DataFrame": def to_pandas(self) -> "pandas.DataFrame":
return self._table.to_pandas() return self._table.to_pandas()
def to_numpy( def to_numpy(self, column: str = None) -> np.ndarray:
self, columns: Optional[Union[str, List[str]]] = None if column is None:
) -> Union[np.ndarray, Dict[str, np.ndarray]]: raise ValueError(
if columns is None: "`column` must be specified when calling .to_numpy() "
columns = self._table.column_names "on Arrow blocks."
if not isinstance(columns, list): )
columns = [columns] if column not in self._table.column_names:
for column in columns: raise ValueError(
if column not in self._table.column_names: f"Cannot find column {column}, available columns: "
raise ValueError( f"{self._table.column_names}"
f"Cannot find column {column}, available columns: " )
f"{self._table.column_names}" array = self._table[column]
) if array.num_chunks > 1:
arrays = [] # TODO(ekl) combine fails since we can't concat
for column in columns: # ArrowTensorType?
array = self._table[column] array = array.combine_chunks()
if array.num_chunks == 0:
array = pyarrow.array([], type=array.type)
elif _is_column_extension_type(array):
array = _concatenate_extension_column(array)
else:
array = array.combine_chunks()
arrays.append(array.to_numpy(zero_copy_only=False))
if len(arrays) == 1:
arrays = arrays[0]
else: else:
arrays = dict(zip(columns, arrays)) assert array.num_chunks == 1, array
return arrays array = array.chunk(0)
return array.to_numpy(zero_copy_only=False)
def to_arrow(self) -> "pyarrow.Table": def to_arrow(self) -> "pyarrow.Table":
return self._table return self._table
@ -205,45 +169,9 @@ class ArrowBlockAccessor(TableBlockAccessor):
def _empty_table() -> "pyarrow.Table": def _empty_table() -> "pyarrow.Table":
return ArrowBlockBuilder._empty_table() return ArrowBlockBuilder._empty_table()
@staticmethod
def take_table(
table: "pyarrow.Table",
indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
) -> "pyarrow.Table":
"""Select rows from the table.
This method is an alternative to pyarrow.Table.take(), which breaks for
extension arrays. This is exposed as a static method for easier use on
intermediate tables, not underlying an ArrowBlockAccessor.
"""
if any(_is_column_extension_type(col) for col in table.columns):
new_cols = []
for col in table.columns:
if _is_column_extension_type(col):
# .take() will concatenate internally, which currently breaks for
# extension arrays.
col = _concatenate_extension_column(col)
new_cols.append(col.take(indices))
table = pyarrow.Table.from_arrays(new_cols, schema=table.schema)
else:
table = table.take(indices)
return table
def take(
self,
indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
) -> "pyarrow.Table":
"""Select rows from the underlying table.
This method is an alternative to pyarrow.Table.take(), which breaks for
extension arrays.
"""
return self.take_table(self._table, indices)
def _sample(self, n_samples: int, key: "SortKeyT") -> "pyarrow.Table": def _sample(self, n_samples: int, key: "SortKeyT") -> "pyarrow.Table":
indices = random.sample(range(self._table.num_rows), n_samples) indices = random.sample(range(self._table.num_rows), n_samples)
table = self._table.select([k[0] for k in key]) return self._table.select([k[0] for k in key]).take(indices)
return self.take_table(table, indices)
def count(self, on: KeyFn) -> Optional[U]: def count(self, on: KeyFn) -> Optional[U]:
"""Count the number of non-null values in the provided column.""" """Count the number of non-null values in the provided column."""
@ -340,7 +268,7 @@ class ArrowBlockAccessor(TableBlockAccessor):
import pyarrow.compute as pac import pyarrow.compute as pac
indices = pac.sort_indices(self._table, sort_keys=key) indices = pac.sort_indices(self._table, sort_keys=key)
table = self.take(indices) table = self._table.take(indices)
if len(boundaries) == 0: if len(boundaries) == 0:
return [table] return [table]
@ -465,7 +393,7 @@ class ArrowBlockAccessor(TableBlockAccessor):
else: else:
ret = pyarrow.concat_tables(blocks, promote=True) ret = pyarrow.concat_tables(blocks, promote=True)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key) indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
ret = ArrowBlockAccessor.take_table(ret, indices) ret = ret.take(indices)
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build()) return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())
@staticmethod @staticmethod
@ -561,33 +489,6 @@ class ArrowBlockAccessor(TableBlockAccessor):
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build()) return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())
def _is_column_extension_type(ca: "pyarrow.ChunkedArray") -> bool:
"""Whether the provided Arrow Table column is an extension array, using an Arrow
extension type.
"""
return isinstance(ca.type, pyarrow.ExtensionType)
def _concatenate_extension_column(ca: "pyarrow.ChunkedArray") -> "pyarrow.Array":
"""Concatenate chunks of an extension column into a contiguous array.
This concatenation is required for creating copies and for .take() to work on
extension arrays.
See https://issues.apache.org/jira/browse/ARROW-16503.
"""
if not _is_column_extension_type(ca):
raise ValueError("Chunked array isn't an extension array: {ca}")
if ca.num_chunks == 0:
# No-op for no-chunk chunked arrays, since there's nothing to concatenate.
return ca
chunk = ca.chunk(0)
return type(chunk).from_storage(
chunk.type, pyarrow.concat_arrays([c.storage for c in ca.chunks])
)
def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table": def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
"""Copy the provided Arrow table.""" """Copy the provided Arrow table."""
import pyarrow as pa import pyarrow as pa
@ -597,10 +498,14 @@ def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
cols = table.columns cols = table.columns
new_cols = [] new_cols = []
for col in cols: for col in cols:
if _is_column_extension_type(col): if col.num_chunks > 0 and isinstance(col.chunk(0), pa.ExtensionArray):
# Extension arrays don't support concatenation. # If an extension array, we copy the underlying storage arrays.
arr = _concatenate_extension_column(col) chunk = col.chunk(0)
arr = type(chunk).from_storage(
chunk.type, pa.concat_arrays([c.storage for c in col.chunks])
)
else: else:
# Otherwise, we copy the top-level chunk arrays.
arr = col.combine_chunks() arr = col.combine_chunks()
new_cols.append(arr) new_cols.append(arr)
return pa.Table.from_arrays(new_cols, schema=table.schema) return pa.Table.from_arrays(new_cols, schema=table.schema)
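The `_copy_table` hunk above is hard to read in flattened form; the restored (post-revert) logic is roughly the following, which copies extension-typed columns via their underlying storage arrays. This is a sketch assembled from the diff using plain pyarrow:

```python
import pyarrow as pa

def _copy_table(table: pa.Table) -> pa.Table:
    """Copy the provided Arrow table, handling extension columns specially."""
    new_cols = []
    for col in table.columns:
        if col.num_chunks > 0 and isinstance(col.chunk(0), pa.ExtensionArray):
            # If an extension array, copy the underlying storage arrays.
            chunk = col.chunk(0)
            arr = type(chunk).from_storage(
                chunk.type, pa.concat_arrays([c.storage for c in col.chunks])
            )
        else:
            # Otherwise, copy the top-level chunk arrays.
            arr = col.combine_chunks()
        new_cols.append(arr)
    return pa.Table.from_arrays(new_cols, schema=table.schema)

copied = _copy_table(pa.table({"value": [1, 2, 3]}))
```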


@ -93,20 +93,26 @@ def batch_blocks(
def _format_batch(batch: Block, batch_format: str) -> BatchType: def _format_batch(batch: Block, batch_format: str) -> BatchType:
import pyarrow as pa
if batch_format == "native": if batch_format == "native":
batch = BlockAccessor.for_block(batch).to_native() # Always promote Arrow blocks to pandas for consistency, since
# we lazily convert pandas->Arrow internally for efficiency.
if isinstance(batch, pa.Table) or isinstance(batch, bytes):
batch = BlockAccessor.for_block(batch)
batch = batch.to_pandas()
return batch
elif batch_format == "pandas": elif batch_format == "pandas":
batch = BlockAccessor.for_block(batch).to_pandas() batch = BlockAccessor.for_block(batch)
return batch.to_pandas()
elif batch_format == "pyarrow": elif batch_format == "pyarrow":
batch = BlockAccessor.for_block(batch).to_arrow() batch = BlockAccessor.for_block(batch)
elif batch_format == "numpy": return batch.to_arrow()
batch = BlockAccessor.for_block(batch).to_numpy()
else: else:
raise ValueError( raise ValueError(
f"The given batch format: {batch_format} " f"The given batch format: {batch_format} "
f"is invalid. Supported batch type: {BatchType}" f"is invalid. Supported batch type: {BatchType}"
) )
return batch
def _sliding_window(iterable: Iterable, n: int): def _sliding_window(iterable: Iterable, n: int):


@ -1,8 +1,6 @@
from typing import Any from typing import Any
import numpy as np from ray.data.block import Block, T, BlockAccessor
from ray.data.block import Block, DataBatch, T, BlockAccessor
from ray.data.impl.block_builder import BlockBuilder from ray.data.impl.block_builder import BlockBuilder
from ray.data.impl.simple_block import SimpleBlockBuilder from ray.data.impl.simple_block import SimpleBlockBuilder
from ray.data.impl.arrow_block import ArrowRow, ArrowBlockBuilder from ray.data.impl.arrow_block import ArrowRow, ArrowBlockBuilder
@ -15,6 +13,7 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
self._empty_block = None self._empty_block = None
def add(self, item: Any) -> None: def add(self, item: Any) -> None:
if self._builder is None: if self._builder is None:
# TODO (kfstorm): Maybe we can use Pandas block format for dict. # TODO (kfstorm): Maybe we can use Pandas block format for dict.
if isinstance(item, dict) or isinstance(item, ArrowRow): if isinstance(item, dict) or isinstance(item, ArrowRow):
@ -27,24 +26,13 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
self._builder = ArrowBlockBuilder() self._builder = ArrowBlockBuilder()
except (TypeError, pyarrow.lib.ArrowInvalid): except (TypeError, pyarrow.lib.ArrowInvalid):
self._builder = SimpleBlockBuilder() self._builder = SimpleBlockBuilder()
elif isinstance(item, np.ndarray):
self._builder = ArrowBlockBuilder()
elif isinstance(item, PandasRow): elif isinstance(item, PandasRow):
self._builder = PandasBlockBuilder() self._builder = PandasBlockBuilder()
else: else:
self._builder = SimpleBlockBuilder() self._builder = SimpleBlockBuilder()
self._builder.add(item) self._builder.add(item)
def add_batch(self, batch: DataBatch): def add_block(self, block: Block) -> None:
"""Add a user-facing data batch to the builder.
This data batch will be converted to an internal block and then added to the
underlying builder.
"""
block = BlockAccessor.batch_to_block(batch)
return self.add_block(block)
def add_block(self, block: Block):
accessor = BlockAccessor.for_block(block) accessor = BlockAccessor.for_block(block)
if accessor.num_rows() == 0: if accessor.num_rows() == 0:
# Don't infer types of empty lists. Store the block and use it if no # Don't infer types of empty lists. Store the block and use it if no


@ -1,6 +1,6 @@
from typing import Callable, Any, Optional from typing import Callable, Any, Optional
from ray.data.block import Block, DataBatch, BlockAccessor from ray.data.block import Block, BlockAccessor
from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder
@ -44,11 +44,6 @@ class BlockOutputBuffer(object):
assert not self._finalized assert not self._finalized
self._buffer.add(item) self._buffer.add(item)
def add_batch(self, batch: DataBatch) -> None:
"""Add a data batch to this output buffer."""
assert not self._finalized
self._buffer.add_batch(batch)
def add_block(self, block: Block) -> None: def add_block(self, block: Block) -> None:
"""Add a data block to this output buffer.""" """Add a data block to this output buffer."""
assert not self._finalized assert not self._finalized


@ -3,7 +3,6 @@ from typing import (
Dict, Dict,
List, List,
Tuple, Tuple,
Union,
Iterator, Iterator,
Any, Any,
TypeVar, TypeVar,
@ -16,11 +15,7 @@ import numpy as np
from ray.data.block import BlockAccessor, BlockMetadata, KeyFn, U from ray.data.block import BlockAccessor, BlockMetadata, KeyFn, U
from ray.data.row import TableRow from ray.data.row import TableRow
from ray.data.impl.table_block import ( from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
TableBlockAccessor,
TableBlockBuilder,
VALUE_COL_NAME,
)
from ray.data.impl.arrow_block import ArrowBlockAccessor from ray.data.impl.arrow_block import ArrowBlockAccessor
from ray.data.aggregate import AggregateFn from ray.data.aggregate import AggregateFn
@ -76,13 +71,6 @@ class PandasBlockBuilder(TableBlockBuilder[T]):
def _table_from_pydict(self, columns: Dict[str, List[Any]]) -> "pandas.DataFrame": def _table_from_pydict(self, columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
pandas = lazy_import_pandas() pandas = lazy_import_pandas()
for key, value in columns.items():
if key == VALUE_COL_NAME or isinstance(next(iter(value), None), np.ndarray):
from ray.data.extensions.tensor_extension import TensorArray
if len(value) == 1:
value = value[0]
columns[key] = TensorArray(value)
return pandas.DataFrame(columns) return pandas.DataFrame(columns)
def _concat_tables(self, tables: List["pandas.DataFrame"]) -> "pandas.DataFrame": def _concat_tables(self, tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
@ -101,19 +89,11 @@ PandasBlockSchema = collections.namedtuple("PandasBlockSchema", ["names", "types
class PandasBlockAccessor(TableBlockAccessor): class PandasBlockAccessor(TableBlockAccessor):
ROW_TYPE = PandasRow
def __init__(self, table: "pandas.DataFrame"): def __init__(self, table: "pandas.DataFrame"):
super().__init__(table) super().__init__(table)
def column_names(self) -> List[str]: def _create_table_row(self, row: "pandas.DataFrame") -> PandasRow:
return self._table.columns.tolist() return PandasRow(row)
@staticmethod
def _build_tensor_row(row: PandasRow) -> np.ndarray:
# Getting an item in a Pandas tensor column returns a TensorArrayElement, which
# we have to convert to an ndarray.
return row[VALUE_COL_NAME].iloc[0].to_numpy()
def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame": def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
view = self._table[start:end] view = self._table[start:end]
@ -142,27 +122,19 @@ class PandasBlockAccessor(TableBlockAccessor):
def to_pandas(self) -> "pandas.DataFrame": def to_pandas(self) -> "pandas.DataFrame":
return self._table return self._table
def to_numpy( def to_numpy(self, column: str = None) -> np.ndarray:
self, columns: Optional[Union[str, List[str]]] = None if not column:
) -> Union[np.ndarray, Dict[str, np.ndarray]]: raise ValueError(
if columns is None: "`column` must be specified when calling .to_numpy() "
columns = self._table.columns.tolist() "on Pandas blocks."
if not isinstance(columns, list): )
columns = [columns] if column not in self._table.columns:
for column in columns: raise ValueError(
if column not in self._table.columns: "Cannot find column {}, available columns: {}".format(
raise ValueError( column, self._table.columns.tolist()
f"Cannot find column {column}, available columns: "
f"{self._table.columns.tolist()}"
) )
arrays = [] )
for column in columns: return self._table[column].to_numpy()
arrays.append(self._table[column].to_numpy())
if len(arrays) == 1:
arrays = arrays[0]
else:
arrays = dict(zip(columns, arrays))
return arrays
def to_arrow(self) -> "pyarrow.Table": def to_arrow(self) -> "pyarrow.Table":
import pyarrow import pyarrow


@ -1,7 +1,7 @@
import random import random
import sys import sys
import heapq import heapq
from typing import Union, Callable, Iterator, List, Tuple, Any, Optional, TYPE_CHECKING from typing import Callable, Iterator, List, Tuple, Any, Optional, TYPE_CHECKING
import numpy as np import numpy as np
@ -84,9 +84,9 @@ class SimpleBlockAccessor(BlockAccessor):
return pandas.DataFrame({"value": self._items}) return pandas.DataFrame({"value": self._items})
def to_numpy(self, columns: Optional[Union[str, List[str]]] = None) -> np.ndarray: def to_numpy(self, column: str = None) -> np.ndarray:
if columns: if column:
raise ValueError("`columns` arg is not supported for list block.") raise ValueError("`column` arg not supported for list block")
return np.array(self._items) return np.array(self._items)
def to_arrow(self) -> "pyarrow.Table": def to_arrow(self) -> "pyarrow.Table":


@ -1,7 +1,6 @@
import collections import collections
from typing import Dict, Iterator, List, Union, Any, TypeVar, TYPE_CHECKING
import numpy as np from typing import Dict, Iterator, List, Union, Any, TypeVar, TYPE_CHECKING
from ray.data.block import Block, BlockAccessor from ray.data.block import Block, BlockAccessor
from ray.data.row import TableRow from ray.data.row import TableRow
@ -11,11 +10,6 @@ from ray.data.impl.size_estimator import SizeEstimator
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.data.impl.sort import SortKeyT from ray.data.impl.sort import SortKeyT
# The internal column name used for pure-tensor datasets, represented as
# single-tensor-column tables.
VALUE_COL_NAME = "__value__"
T = TypeVar("T") T = TypeVar("T")
# The max size of Python tuples to buffer before compacting them into a # The max size of Python tuples to buffer before compacting them into a
@ -36,11 +30,9 @@ class TableBlockBuilder(BlockBuilder[T]):
self._num_compactions = 0 self._num_compactions = 0
self._block_type = block_type self._block_type = block_type
def add(self, item: Union[dict, TableRow, np.ndarray]) -> None: def add(self, item: Union[dict, TableRow]) -> None:
if isinstance(item, TableRow): if isinstance(item, TableRow):
item = item.as_pydict() item = item.as_pydict()
elif isinstance(item, np.ndarray):
item = {VALUE_COL_NAME: item}
if not isinstance(item, dict): if not isinstance(item, dict):
raise ValueError( raise ValueError(
"Returned elements of an TableBlock must be of type `dict`, " "Returned elements of an TableBlock must be of type `dict`, "
@ -108,42 +100,16 @@ class TableBlockBuilder(BlockBuilder[T]):
class TableBlockAccessor(BlockAccessor): class TableBlockAccessor(BlockAccessor):
ROW_TYPE: TableRow = TableRow
def __init__(self, table: Any): def __init__(self, table: Any):
self._table = table self._table = table
def _get_row(self, index: int, copy: bool = False) -> Union[TableRow, np.ndarray]: def _create_table_row(self, row: Any) -> TableRow:
row = self.slice(index, index + 1, copy=copy)
if self.is_tensor_wrapper():
row = self._build_tensor_row(row)
else:
row = self.ROW_TYPE(row)
return row
@staticmethod
def _build_tensor_row(row: TableRow) -> np.ndarray:
raise NotImplementedError
def to_native(self) -> Block:
if self.is_tensor_wrapper():
native = self.to_numpy()
else:
# Always promote Arrow blocks to pandas for consistency, since
# we lazily convert pandas->Arrow internally for efficiency.
native = self.to_pandas()
return native
def column_names(self) -> List[str]:
raise NotImplementedError raise NotImplementedError
def to_block(self) -> Block: def to_block(self) -> Block:
return self._table return self._table
def is_tensor_wrapper(self) -> bool: def iter_rows(self) -> Iterator[TableRow]:
return self.column_names() == [VALUE_COL_NAME]
def iter_rows(self) -> Iterator[Union[TableRow, np.ndarray]]:
outer = self outer = self
class Iter: class Iter:
@ -156,7 +122,10 @@ class TableBlockAccessor(BlockAccessor):
def __next__(self): def __next__(self):
self._cur += 1 self._cur += 1
if self._cur < outer.num_rows(): if self._cur < outer.num_rows():
return outer._get_row(self._cur) row = outer._create_table_row(
outer.slice(self._cur, self._cur + 1, copy=False)
)
return row
raise StopIteration raise StopIteration
return Iter() return Iter()
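With `_get_row`/`_build_tensor_row` removed in the hunk above, iterating a tensor dataset once again yields table rows keyed by ``"value"`` rather than bare ndarrays. A quick sketch grounded in the updated tests (assumes Ray is installed):

```python
import ray

ds = ray.data.range_tensor(4)
# Rows are table rows (dict-like) with a "value" entry, not raw ndarrays.
print(ds.take(2))
# -> [{'value': array([0])}, {'value': array([1])}]
```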


@ -225,7 +225,9 @@ class _RandomAccessWorker:
col = block[self.key_field] col = block[self.key_field]
indices = np.searchsorted(col, keys) indices = np.searchsorted(col, keys)
acc = BlockAccessor.for_block(block) acc = BlockAccessor.for_block(block)
result = [acc._get_row(i, copy=True) for i in indices] result = [
acc._create_table_row(acc.slice(i, i + 1, copy=True)) for i in indices
]
# assert result == [self._get(i, k) for i, k in zip(block_indices, keys)] # assert result == [self._get(i, k) for i, k in zip(block_indices, keys)]
else: else:
result = [self._get(i, k) for i, k in zip(block_indices, keys)] result = [self._get(i, k) for i, k in zip(block_indices, keys)]
@ -254,7 +256,7 @@ class _RandomAccessWorker:
if i is None: if i is None:
return None return None
acc = BlockAccessor.for_block(block) acc = BlockAccessor.for_block(block)
return acc._get_row(i, copy=True) return acc._create_table_row(acc.slice(i, i + 1, copy=True))
def _binary_search_find(column, x): def _binary_search_find(column, x):


@ -174,10 +174,10 @@ def range_tensor(
>>> import ray >>> import ray
>>> ds = ray.data.range_tensor(1000, shape=(3, 10)) # doctest: +SKIP >>> ds = ray.data.range_tensor(1000, shape=(3, 10)) # doctest: +SKIP
>>> ds.map_batches( # doctest: +SKIP >>> ds.map_batches( # doctest: +SKIP
... lambda arr: arr * 2).show() ... lambda arr: arr * 2, batch_format="pandas").show()
This is similar to range_table(), but uses the ArrowTensorArray extension This is similar to range_table(), but uses the ArrowTensorArray extension
type. The dataset elements take the form {VALUE_COL_NAME: array(N, shape=shape)}. type. The dataset elements take the form {"value": array(N, shape=shape)}.
Args: Args:
n: The upper bound of the range of integer records. n: The upper bound of the range of integer records.
@ -1020,11 +1020,16 @@ def _df_to_block(df: "pandas.DataFrame") -> Block[ArrowRow]:
def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]: def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]:
stats = BlockExecStats.builder() stats = BlockExecStats.builder()
block = BlockAccessor.batch_to_block(ndarray) import pyarrow as pa
metadata = BlockAccessor.for_block(block).get_metadata( from ray.data.extensions import TensorArray
input_files=None, exec_stats=stats.build()
table = pa.Table.from_pydict({"value": TensorArray(ndarray)})
return (
table,
BlockAccessor.for_block(table).get_metadata(
input_files=None, exec_stats=stats.build()
),
) )
return block, metadata
def _get_metadata(table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata: def _get_metadata(table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
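The `_ndarray_to_block` change above means `from_numpy()` also stores its data under a ``"value"`` column again. A minimal sketch mirroring the updated `test_from_numpy` in this diff (assumes Ray and NumPy are installed; the array shape here is illustrative):

```python
import numpy as np
import ray

arr = np.arange(12).reshape((4, 3))
ds = ray.data.from_numpy(arr)
# Each row carries its tensor under the "value" key after the revert.
values = np.stack([row["value"] for row in ds.take(4)])
np.testing.assert_array_equal(values, arr)
```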


@ -443,138 +443,21 @@ def test_range_table(ray_start_regular_shared):
assert ds.take() == [{"value": i} for i in range(10)] assert ds.take() == [{"value": i} for i in range(10)]
def test_tensors_basic(ray_start_regular_shared): def test_tensors(ray_start_regular_shared):
# Create directly. # Create directly.
tensor_shape = (3, 5) ds = ray.data.range_tensor(5, shape=(3, 5))
ds = ray.data.range_tensor(6, shape=tensor_shape)
assert str(ds) == ( assert str(ds) == (
"Dataset(num_blocks=6, num_rows=6, " "Dataset(num_blocks=5, num_rows=5, "
"schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})" "schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})"
) )
# Test row iterator yields tensors.
for tensor in ds.iter_rows():
assert isinstance(tensor, np.ndarray)
assert tensor.shape == tensor_shape
# Test batch iterator yields tensors.
for tensor in ds.iter_batches(batch_size=2):
assert isinstance(tensor, np.ndarray)
assert tensor.shape == (2,) + tensor_shape
# Native format.
def np_mapper(arr):
assert isinstance(arr, np.ndarray)
return arr + 1
res = ray.data.range_tensor(2, shape=(2, 2)).map(np_mapper).take()
np.testing.assert_equal(res, [np.ones((2, 2)), 2 * np.ones((2, 2))])
# Pandas conversion. # Pandas conversion.
def pd_mapper(df): res = (
assert isinstance(df, pd.DataFrame) ray.data.range_tensor(10)
return df + 2 .map_batches(lambda t: t + 2, batch_format="pandas")
.take(2)
res = ray.data.range_tensor(2).map_batches(pd_mapper, batch_format="pandas").take()
np.testing.assert_equal(res, [np.array([2]), np.array([3])])
def test_tensors_shuffle(ray_start_regular_shared):
# Test Arrow table representation.
tensor_shape = (3, 5)
ds = ray.data.range_tensor(6, shape=tensor_shape)
shuffled_ds = ds.random_shuffle()
shuffled = shuffled_ds.take()
base = ds.take()
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
shuffled,
base,
)
np.testing.assert_equal(
sorted(shuffled, key=lambda arr: arr.min()),
sorted(base, key=lambda arr: arr.min()),
)
# Test Pandas table representation.
tensor_shape = (3, 5)
ds = ray.data.range_tensor(6, shape=tensor_shape)
ds = ds.map_batches(lambda df: df, batch_format="pandas")
shuffled_ds = ds.random_shuffle()
shuffled = shuffled_ds.take()
base = ds.take()
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
shuffled,
base,
)
np.testing.assert_equal(
sorted(shuffled, key=lambda arr: arr.min()),
sorted(base, key=lambda arr: arr.min()),
)
def test_tensors_sort(ray_start_regular_shared):
# Test Arrow table representation.
t = pa.table({"a": TensorArray(np.arange(32).reshape((2, 4, 4))), "b": [1, 2]})
ds = ray.data.from_arrow(t)
sorted_ds = ds.sort(key="b", descending=True)
sorted_arrs = [row["a"] for row in sorted_ds.take()]
base = [row["a"] for row in ds.take()]
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
sorted_arrs,
base,
)
np.testing.assert_equal(
sorted_arrs,
sorted(base, key=lambda arr: -arr.min()),
)
# Test Pandas table representation.
df = pd.DataFrame({"a": TensorArray(np.arange(32).reshape((2, 4, 4))), "b": [1, 2]})
ds = ray.data.from_pandas(df)
sorted_ds = ds.sort(key="b", descending=True)
sorted_arrs = [np.asarray(row["a"]) for row in sorted_ds.take()]
base = [np.asarray(row["a"]) for row in ds.take()]
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
sorted_arrs,
base,
)
np.testing.assert_equal(
sorted_arrs,
sorted(base, key=lambda arr: -arr.min()),
)
def test_tensors_inferred_from_map(ray_start_regular_shared):
# Test map.
ds = ray.data.range(10).map(lambda _: np.ones((4, 4)))
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=10, "
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
)
# Test map_batches.
ds = ray.data.range(16, parallelism=4).map_batches(
lambda _: np.ones((3, 4, 4)), batch_size=2
)
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=24, "
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
)
# Test flat_map.
ds = ray.data.range(10).flat_map(lambda _: [np.ones((4, 4)), np.ones((4, 4))])
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=20, "
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
) )
assert str(res) == "[{'value': array([2])}, {'value': array([3])}]"
def test_tensor_array_ops(ray_start_regular_shared): def test_tensor_array_ops(ray_start_regular_shared):


@ -121,7 +121,7 @@ def test_from_numpy(ray_start_regular_shared, from_ref):
ds = ray.data.from_numpy_refs([ray.put(arr) for arr in arrs]) ds = ray.data.from_numpy_refs([ray.put(arr) for arr in arrs])
else: else:
ds = ray.data.from_numpy(arrs) ds = ray.data.from_numpy(arrs)
values = np.stack(ds.take(8)) values = np.stack([x["value"] for x in ds.take(8)])
np.testing.assert_array_equal(values, np.concatenate((arr1, arr2))) np.testing.assert_array_equal(values, np.concatenate((arr1, arr2)))
# Test from single NumPy ndarray. # Test from single NumPy ndarray.
@ -129,7 +129,7 @@ def test_from_numpy(ray_start_regular_shared, from_ref):
ds = ray.data.from_numpy_refs(ray.put(arr1)) ds = ray.data.from_numpy_refs(ray.put(arr1))
else: else:
ds = ray.data.from_numpy(arr1) ds = ray.data.from_numpy(arr1)
values = np.stack(ds.take(4)) values = np.stack([x["value"] for x in ds.take(4)])
np.testing.assert_array_equal(values, arr1) np.testing.assert_array_equal(values, arr1)
@ -197,28 +197,14 @@ def test_to_numpy_refs(ray_start_regular_shared):
# Tensor Dataset # Tensor Dataset
ds = ray.data.range_tensor(10, parallelism=2) ds = ray.data.range_tensor(10, parallelism=2)
arr = np.concatenate(ray.get(ds.to_numpy_refs())) arr = np.concatenate(ray.get(ds.to_numpy_refs(column="value")))
np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1)) np.testing.assert_equal(arr, np.expand_dims(np.arange(0, 10), 1))
# Table Dataset # Table Dataset
ds = ray.data.range_table(10) ds = ray.data.range_table(10)
arr = np.concatenate(ray.get(ds.to_numpy_refs())) arr = np.concatenate(ray.get(ds.to_numpy_refs(column="value")))
np.testing.assert_equal(arr, np.arange(0, 10)) np.testing.assert_equal(arr, np.arange(0, 10))
# Test multi-column Arrow dataset.
ds = ray.data.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
arrs = ray.get(ds.to_numpy_refs())
np.testing.assert_equal(
arrs, [{"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}]
)
# Test multi-column Pandas dataset.
ds = ray.data.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
arrs = ray.get(ds.to_numpy_refs())
np.testing.assert_equal(
arrs, [{"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}]
)
def test_to_arrow_refs(ray_start_regular_shared): def test_to_arrow_refs(ray_start_regular_shared):
n = 5 n = 5
@ -1012,9 +998,9 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path):
ds = ray.data.read_numpy(data_path, filesystem=fs) ds = ray.data.read_numpy(data_path, filesystem=fs)
assert str(ds) == ( assert str(ds) == (
"Dataset(num_blocks=2, num_rows=None, " "Dataset(num_blocks=2, num_rows=None, "
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})" "schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})"
) )
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])]) assert str(ds.take(2)) == "[{'value': array([0])}, {'value': array([1])}]"
def test_numpy_read(ray_start_regular_shared, tmp_path): def test_numpy_read(ray_start_regular_shared, tmp_path):
@ -1024,9 +1010,9 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
ds = ray.data.read_numpy(path) ds = ray.data.read_numpy(path)
assert str(ds) == ( assert str(ds) == (
"Dataset(num_blocks=1, num_rows=10, " "Dataset(num_blocks=1, num_rows=10, "
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})" "schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})"
) )
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])]) assert str(ds.take(2)) == "[{'value': array([0])}, {'value': array([1])}]"
def test_numpy_read_meta_provider(ray_start_regular_shared, tmp_path): def test_numpy_read_meta_provider(ray_start_regular_shared, tmp_path):
@ -1037,9 +1023,9 @@ def test_numpy_read_meta_provider(ray_start_regular_shared, tmp_path):
ds = ray.data.read_numpy(path, meta_provider=FastFileMetadataProvider()) ds = ray.data.read_numpy(path, meta_provider=FastFileMetadataProvider())
assert str(ds) == ( assert str(ds) == (
"Dataset(num_blocks=1, num_rows=10, " "Dataset(num_blocks=1, num_rows=10, "
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})" "schema={value: <ArrowTensorType: shape=(1,), dtype=int64>})"
) )
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])]) assert str(ds.take(2)) == "[{'value': array([0])}, {'value': array([1])}]"
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
ray.data.read_binary_files( ray.data.read_binary_files(
@ -1091,10 +1077,10 @@ def test_numpy_read_partitioned_with_filter(
ds = ray.data.read_numpy(base_dir, partition_filter=partition_path_filter) ds = ray.data.read_numpy(base_dir, partition_filter=partition_path_filter)
vals = [[1, 0], [1, 1], [1, 2], [3, 3], [3, 4], [3, 5]] vals = [[1, 0], [1, 1], [1, 2], [3, 3], [3, 4], [3, 5]]
val_str = "".join(f"array({v}, dtype=int8), " for v in vals)[:-2] val_str = "".join([f"{{'value': array({v}, dtype=int8)}}, " for v in vals])[:-2]
assert_base_partitioned_ds( assert_base_partitioned_ds(
ds, ds,
schema="{__value__: <ArrowTensorType: shape=(2,), dtype=int8>}", schema="{value: <ArrowTensorType: shape=(2,), dtype=int8>}",
sorted_values=f"[[{val_str}]]", sorted_values=f"[[{val_str}]]",
ds_take_transform_fn=lambda taken: [taken], ds_take_transform_fn=lambda taken: [taken],
sorted_values_transform_fn=lambda sorted_values: str(sorted_values), sorted_values_transform_fn=lambda sorted_values: str(sorted_values),
@ -1133,7 +1119,7 @@ def test_numpy_write(ray_start_regular_shared, fs, data_path, endpoint_url):
assert len(arr2) == 5 assert len(arr2) == 5
assert arr1.sum() == 10 assert arr1.sum() == 10
assert arr2.sum() == 35 assert arr2.sum() == 35
np.testing.assert_equal(ds.take(1), [np.array([0])]) assert str(ds.take(1)) == "[{'value': array([0])}]"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -1172,7 +1158,7 @@ def test_numpy_write_block_path_provider(
assert len(arr2) == 5 assert len(arr2) == 5
assert arr1.sum() == 10 assert arr1.sum() == 10
assert arr2.sum() == 35 assert arr2.sum() == 35
np.testing.assert_equal(ds.take(1), [np.array([0])]) assert str(ds.take(1)) == "[{'value': array([0])}]"
def test_read_text(ray_start_regular_shared, tmp_path): def test_read_text(ray_start_regular_shared, tmp_path):


@ -3,7 +3,6 @@
import time import time
import numpy as np import numpy as np
from typing import Optional from typing import Optional
import sys
import ray import ray
from ray import train from ray import train
@ -60,8 +59,6 @@ class DummyTrainer(DataParallelTrainer):
"""Make a debug train loop that runs for the given amount of runtime.""" """Make a debug train loop that runs for the given amount of runtime."""
def train_loop_per_worker(): def train_loop_per_worker():
import pandas as pd
rank = train.world_rank() rank = train.world_rank()
data_shard = train.get_dataset_shard("train") data_shard = train.get_dataset_shard("train")
start = time.perf_counter() start = time.perf_counter()
@ -78,16 +75,7 @@ class DummyTrainer(DataParallelTrainer):
batch_delay = time.perf_counter() - batch_start batch_delay = time.perf_counter() - batch_start
batch_delays.append(batch_delay) batch_delays.append(batch_delay)
num_batches += 1 num_batches += 1
if isinstance(batch, pd.DataFrame): num_bytes += int(batch.memory_usage(index=True, deep=True).sum())
num_bytes += int(
batch.memory_usage(index=True, deep=True).sum()
)
elif isinstance(batch, np.ndarray):
num_bytes += batch.nbytes
else:
# NOTE: This isn't recursive and will just return the size of
# the object pointers if list of non-primitive types.
num_bytes += sys.getsizeof(batch)
train.report( train.report(
bytes_read=num_bytes, bytes_read=num_bytes,
num_batches=num_batches, num_batches=num_batches,