[Datasets] Adds tensor column support (tensors-in-tables) via Pandas/Arrow extension types/arrays. (#18301)
This commit is contained in:
parent e427e4a467
commit b30c41759d
13 changed files with 2190 additions and 67 deletions
doc/source/data/dataset-tensor-support.rst (new file, 282 lines)

@@ -0,0 +1,282 @@
.. _datasets_tensor_support:

Datasets Tensor Support
=======================

Tensor-typed values
-------------------

Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples:

.. code-block:: python

    # Create a Dataset of tensor-typed values.
    ds = ray.data.range_tensor(10000, shape=(3, 5))
    # -> Dataset(num_blocks=200, num_rows=10000,
    #            schema=<Tensor: shape=(None, 3, 5), dtype=int64>)

    ds.map_batches(lambda t: t + 2).show(2)
    # -> [[2 2 2 2 2]
    #     [2 2 2 2 2]
    #     [2 2 2 2 2]]
    #    [[3 3 3 3 3]
    #     [3 3 3 3 3]
    #     [3 3 3 3 3]]

    # Save to storage.
    ds.write_numpy("/tmp/tensor_out")

    # Read from storage.
    ray.data.read_numpy("/tmp/tensor_out")
    # -> Dataset(num_blocks=200, num_rows=?,
    #            schema=<Tensor: shape=(None, 3, 5), dtype=int64>)

Tensor datasets are also created whenever an array type is returned from a map function:

.. code-block:: python

    # Create a dataset of Python integers.
    ds = ray.data.range(10)
    # -> Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)

    # It is now converted into a tensor dataset.
    ds = ds.map_batches(lambda x: np.array(x))
    # -> Dataset(num_blocks=10, num_rows=10,
    #            schema=<Tensor: shape=(None,), dtype=int64>)

Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store:

.. code-block:: python

    import numpy as np

    # Create a Dataset from a list of NumPy ndarray objects.
    arr1 = np.arange(0, 10)
    arr2 = np.arange(10, 20)
    ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)])

Tables with tensor columns
--------------------------

In addition to tensor datasets, Datasets also supports tables with fixed-shape tensor columns, where each element in the column is a tensor (n-dimensional array) with the same shape. For example, this allows you to use both Pandas and Ray Datasets to read, write, and manipulate a table with a column of images (2D arrays), with all conversions between Pandas, Arrow, and Parquet, and all aggregations and operations on the underlying image ndarrays, handled by Ray Datasets.

With our Pandas extension type, :class:`TensorDtype <ray.data.extensions.tensor_extension.TensorDtype>`, and extension array, :class:`TensorArray <ray.data.extensions.tensor_extension.TensorArray>`, you can perform familiar aggregations and arithmetic, comparison, and logical operations on a DataFrame containing a tensor column, and those operations will be applied to the underlying tensors as expected. With our Arrow extension type, :class:`ArrowTensorType <ray.data.extensions.tensor_extension.ArrowTensorType>`, and extension array, :class:`ArrowTensorArray <ray.data.extensions.tensor_extension.ArrowTensorArray>`, you can import that DataFrame into Ray Datasets and read/write the data from/to the Parquet format.

Automatic conversion between the Pandas and Arrow extension types/arrays keeps the details under the hood, so you only have to worry about casting the column to a tensor column using our Pandas extension type when first ingesting the table into a ``Dataset``, whether from storage or in-memory. All table operations downstream from that cast should work automatically.
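In practice, that single ``.astype(TensorDtype())`` cast is the only explicit step. The snippet below is a condensed sketch of the end-to-end example later on this page, shown here just to make the one required cast concrete:

.. code-block:: python

    import numpy as np
    import pandas as pd
    import ray
    from ray.data.extensions import TensorDtype

    # A DataFrame whose "two" column holds 2x2x2 ndarrays (opaque object dtype).
    df = pd.DataFrame({
        "one": [1, 2, 3],
        "two": list(np.arange(24).reshape((3, 2, 2, 2)))})

    # The one explicit step: cast the ndarray column to our Pandas tensor
    # extension type.
    df["two"] = df["two"].astype(TensorDtype())

    # Downstream conversions (Pandas <-> Arrow <-> Parquet) are handled
    # automatically from here on.
    ds = ray.data.from_pandas([ray.put(df)])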
Reading existing serialized tensor columns
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you already have a Parquet dataset with columns containing serialized tensors, you can have these tensor columns cast to our tensor extension type at read time by giving a simple schema for the tensor columns. Note that these tensors must have been serialized as their raw NumPy ndarray bytes in C-contiguous order (e.g. serialized via ``ndarray.tobytes()``).

.. code-block:: python

    import ray
    import numpy as np
    import pandas as pd

    path = "/tmp/some_path"

    # Create a DataFrame with a list of serialized ndarrays as a column.
    # Note that we do not cast it to a tensor array, so each element in the
    # column is an opaque blob of bytes.
    arr = np.arange(24).reshape((3, 2, 2, 2))
    df = pd.DataFrame({
        "one": [1, 2, 3],
        "two": [tensor.tobytes() for tensor in arr]})

    # Write the dataset to Parquet. The tensor column will be written as an
    # array of opaque byte blobs.
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(path)

    # Read the Parquet files into a new Dataset, with the serialized tensors
    # automatically cast to our tensor column extension type.
    ds = ray.data.read_parquet(
        path, _tensor_column_schema={"two": (np.int64, (2, 2, 2))})

    # Internally, this column is represented with our Arrow tensor extension
    # type.
    print(ds.schema())
    # -> one: int64
    #    two: extension<arrow.py_extension_type<ArrowTensorType>>

If your serialized tensors don't fit the above constraints (e.g. they're stored in Fortran-contiguous order, or they're pickled), you can manually cast this tensor column to our tensor extension type via a read-time user-defined function. This UDF will be pushed down to Ray Datasets' IO layer and executed on each block in parallel, as it's read from storage.

.. code-block:: python

    import pickle
    import pyarrow as pa
    from ray.data.extensions import TensorArray

    # Create a DataFrame with a list of pickled ndarrays as a column.
    arr = np.arange(24).reshape((3, 2, 2, 2))
    df = pd.DataFrame({
        "one": [1, 2, 3],
        "two": [pickle.dumps(tensor) for tensor in arr]})

    # Write the dataset to Parquet. The tensor column will be written as an
    # array of opaque byte blobs.
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(path)

    # Manually deserialize the tensor pickle bytes and cast to our tensor
    # extension type. For the sake of efficiency, we directly construct a
    # TensorArray rather than .astype() casting on the mutated column with
    # TensorDtype.
    def cast_udf(block: pa.Table) -> pa.Table:
        block = block.to_pandas()
        block["two"] = TensorArray([pickle.loads(a) for a in block["two"]])
        return pa.Table.from_pandas(block)

    # Read the Parquet files into a new Dataset, applying the casting UDF
    # on-the-fly within the underlying read tasks.
    ds = ray.data.read_parquet(path, _block_udf=cast_udf)

    # Internally, this column is represented with our Arrow tensor extension
    # type.
    print(ds.schema())
    # -> one: int64
    #    two: extension<arrow.py_extension_type<ArrowTensorType>>

Please note that the ``_tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.

Working with tensor column datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now that the tensor column is properly typed and in a ``Dataset``, we can perform operations on the dataset as if it were a normal table:

.. code-block:: python

    # Arrow and Pandas are now aware of this tensor column, so we can do the
    # typical DataFrame operations on it.
    ds = ds.map_batches(lambda x: 2 * (x + 1), batch_format="pandas")
    # -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s]
    print(ds)
    # -> Dataset(
    #        num_blocks=1, num_rows=3,
    #        schema=<class 'int',
    #                class ray.data.extensions.tensor_extension.ArrowTensorType>)
    print([row["two"] for row in ds.take(5)])
    # -> [2, 4, 6, 8, 10]

Writing and reading tensor columns
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This dataset can then be written to Parquet files. The tensor column schema will be preserved via the Pandas and Arrow extension types and associated metadata, allowing us to later read the Parquet files into a Dataset without needing to specify a column casting schema. This Pandas --> Arrow --> Parquet --> Arrow --> Pandas conversion support makes working with tensor columns extremely easy when using Ray Datasets to both write and read data.

.. code-block:: python

    # You can write the dataset to Parquet.
    ds.write_parquet("/some/path")
    # And you can read it back.
    read_ds = ray.data.read_parquet("/some/path")
    print(read_ds.schema())
    # -> one: int64
    #    two: extension<arrow.py_extension_type<ArrowTensorType>>

End-to-end workflow with our Pandas extension type
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you're working with in-memory Pandas DataFrames that you want to analyze, manipulate, store, and eventually read back, the Pandas/Arrow extension types/arrays make it easy to extend this end-to-end workflow to tensor columns.

.. code-block:: python

    from ray.data.extensions import TensorDtype

    # Create a DataFrame with a list of ndarrays as a column.
    df = pd.DataFrame({
        "one": [1, 2, 3],
        "two": list(np.arange(24).reshape((3, 2, 2, 2)))})
    # Note the opaque np.object dtype for this column.
    print(df.dtypes)
    # -> one     int64
    #    two    object
    #    dtype: object

    # Cast column to our TensorDtype Pandas extension type.
    df["two"] = df["two"].astype(TensorDtype())

    # Note that the column dtype is now TensorDtype instead of
    # np.object.
    print(df.dtypes)
    # -> one          int64
    #    two    TensorDtype
    #    dtype: object

    # Pandas is now aware of this tensor column, and we can do the
    # typical DataFrame operations on this column.
    col = 2 * df["two"]
    # The ndarrays underlying the tensor column will be manipulated,
    # but the column itself will continue to be a Pandas type.
    print(type(col))
    # -> pandas.core.series.Series
    print(col)
    # -> 0   [[[ 2  4]
    #          [ 6  8]]
    #         [[10 12]
    #          [14 16]]]
    #    1   [[[18 20]
    #          [22 24]]
    #         [[26 28]
    #          [30 32]]]
    #    2   [[[34 36]
    #          [38 40]]
    #         [[42 44]
    #          [46 48]]]
    #    Name: two, dtype: TensorDtype

    # Once you do an aggregation on that column that returns a single
    # row's value, you get back our TensorArrayElement type.
    tensor = col.mean()
    print(type(tensor))
    # -> ray.data.extensions.tensor_extension.TensorArrayElement
    print(tensor)
    # -> array([[[18., 20.],
    #            [22., 24.]],
    #           [[26., 28.],
    #            [30., 32.]]])

    # This is a light wrapper around a NumPy ndarray, and can easily
    # be converted to an ndarray.
    type(tensor.to_numpy())
    # -> numpy.ndarray

    # In addition to doing Pandas operations on the tensor column,
    # you can now put the DataFrame directly into a Dataset.
    ds = ray.data.from_pandas([ray.put(df)])
    # Internally, this column is represented with the corresponding
    # Arrow tensor extension type.
    print(ds.schema())
    # -> one: int64
    #    two: extension<arrow.py_extension_type<ArrowTensorType>>

    # You can write the dataset to Parquet.
    ds.write_parquet("/some/path")
    # And you can read it back.
    read_ds = ray.data.read_parquet("/some/path")
    print(read_ds.schema())
    # -> one: int64
    #    two: extension<arrow.py_extension_type<ArrowTensorType>>

    read_df = ray.get(read_ds.to_pandas())[0]
    print(read_df.dtypes)
    # -> one          int64
    #    two    TensorDtype
    #    dtype: object

    # The tensor extension type is preserved along the
    # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
    # conversion chain.
    print(read_df.equals(df))
    # -> True

Limitations
~~~~~~~~~~~

This feature currently comes with a few known limitations that we are either actively working on addressing or have already implemented workarounds for.

* All tensors in a tensor column currently must have the same shape. Please let us know if you require heterogeneous tensor shapes for your tensor column! The tracking issue is `here <https://github.com/ray-project/ray/issues/18316>`__.
* Automatic casting via specifying an override Arrow schema when reading Parquet is blocked on Arrow supporting custom ExtensionType casting kernels. See this `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``_tensor_column_schema`` parameter has been added to :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
* Ingesting tables with tensor columns into PyTorch via ``ds.to_torch()`` is blocked on PyTorch supporting tensor creation from objects that implement the ``__array__`` interface. See this `issue <https://github.com/pytorch/pytorch/issues/51156>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18314>`__; an illustrative interim sketch is shown after this list.
* Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked on the release of a Pandas fix for properly interpreting extension arrays in ``DataFrame.values``. See this `PR <https://github.com/pandas-dev/pandas/pull/43160>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18315>`__.
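For the ``ds.to_torch()`` gap in particular, the following is a hypothetical workaround sketch (not an API added by this commit, and its use of ``Dataset.iter_batches()`` is an assumption): iterate batches as Pandas DataFrames and build torch tensors from explicit ndarrays yourself, sidestepping direct tensor creation from the ``__array__``-implementing tensor elements.

.. code-block:: python

    import numpy as np
    import torch

    # Hypothetical interim sketch: materialize each batch's tensor column as a
    # plain ndarray before handing it to torch, instead of relying on torch
    # consuming TensorArrayElement objects directly.
    for batch in ds.iter_batches(batch_format="pandas", batch_size=2):
        # Stack the per-row tensors into one contiguous ndarray.
        features = torch.as_tensor(
            np.stack([np.asarray(t) for t in batch["two"]]))
        labels = torch.as_tensor(batch["one"].to_numpy())
        # ... feed (features, labels) into the training loop here ...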
@@ -189,7 +189,7 @@ Datasets can be created from files on local disk or remote datasources such as S

    # Read multiple directories.
    ds = ray.data.read_csv(["s3://bucket/path1", "s3://bucket/path2"])

Finally, you can create a Dataset from existing data in the Ray object store or Ray compatible distributed DataFrames:
Finally, you can create a ``Dataset`` from existing data in the Ray object store or Ray-compatible distributed DataFrames:

.. code-block:: python

@@ -219,6 +219,13 @@ Datasets can be written to local or remote storage using ``.write_csv()``, ``.wr

    ray.data.range(10000).repartition(1).write_csv("/tmp/output2")
    # -> /tmp/output2/data0.csv

You can also convert a ``Dataset`` to Ray-compatible distributed DataFrames:

.. code-block:: python

    # Convert a Ray Dataset into a Dask-on-Ray DataFrame.
    dask_df = ds.to_dask()

Transforming Datasets
---------------------

@@ -325,60 +332,6 @@ Datasets can be split up into disjoint sub-datasets. Locality-aware splitting is

    ray.get([w.train.remote(s) for s in shards])
    # -> [650, 650, ...]

Tensor-typed values
-------------------

Datasets support tensor-typed values, which are represented in-memory as Arrow tensors (i.e., np.ndarray format). Tensor datasets can be read from and written to ``.npy`` files. Here are some examples:

.. code-block:: python

    # Create a Dataset of tensor-typed values.
    ds = ray.data.range_tensor(10000, shape=(3, 5))
    # -> Dataset(num_blocks=200, num_rows=10000,
    #            schema=<Tensor: shape=(None, 3, 5), dtype=int64>)

    ds.map_batches(lambda t: t + 2).show(2)
    # -> [[2 2 2 2 2]
    #     [2 2 2 2 2]
    #     [2 2 2 2 2]]
    #    [[3 3 3 3 3]
    #     [3 3 3 3 3]
    #     [3 3 3 3 3]]

    # Save to storage.
    ds.write_numpy("/tmp/tensor_out")

    # Read from storage.
    ray.data.read_numpy("/tmp/tensor_out")
    # -> Dataset(num_blocks=200, num_rows=?,
    #            schema=<Tensor: shape=(None, 3, 5), dtype=int64>)

Tensor datasets are also created whenever an array type is returned from a map function:

.. code-block:: python

    # Create a dataset of Python integers.
    ds = ray.data.range(10)
    # -> Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)

    # It is now converted into a tensor dataset.
    ds = ds.map_batches(lambda x: np.array(x))
    # -> Dataset(num_blocks=10, num_rows=10,
    #            schema=<Tensor: shape=(None,), dtype=int64>)

Tensor datasets can also be created from NumPy ndarrays that are already stored in the Ray object store:

.. code-block:: python

    import numpy as np

    # Create a Dataset from a list of NumPy ndarray objects.
    arr1 = np.arange(0, 10)
    arr2 = np.arange(10, 20)
    ds = ray.data.from_numpy([ray.put(arr1), ray.put(arr2)])

Limitations: currently tensor-typed values cannot be nested in tabular records (e.g., as in TFRecord / Petastorm format). This is planned for development.

Custom datasources
------------------
@@ -34,6 +34,21 @@ DatasetPipeline API

.. autoclass:: ray.data.dataset_pipeline.DatasetPipeline
    :members:

Tensor Column Extension API
---------------------------

.. autoclass:: ray.data.extensions.tensor_extension.TensorDtype
    :members:

.. autoclass:: ray.data.extensions.tensor_extension.TensorArray
    :members:

.. autoclass:: ray.data.extensions.tensor_extension.ArrowTensorType
    :members:

.. autoclass:: ray.data.extensions.tensor_extension.ArrowTensorArray
    :members:

Custom Datasource API
---------------------

@@ -277,6 +277,7 @@ Papers

    :caption: Ray Data

    data/dataset.rst
    data/dataset-tensor-support.rst
    data/dataset-pipeline.rst
    data/package-ref.rst
    data/dask-on-ray.rst
@@ -1183,6 +1183,8 @@ class Dataset(Generic[T]):

                target_col = batch.pop(label_column)
                if feature_columns:
                    batch = batch[feature_columns]
                # TODO(Clark): Support batches containing our extension array
                # TensorArray.
                yield batch.values, target_col.values

        return tf.data.Dataset.from_generator(
@@ -1,6 +1,6 @@

import logging
import os
from typing import Optional, List, Tuple, Union, Any, TYPE_CHECKING
from typing import Callable, Optional, List, Tuple, Union, Any, TYPE_CHECKING
from urllib.parse import urlparse

if TYPE_CHECKING:

@@ -36,6 +36,7 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):

            paths: Union[str, List[str]],
            filesystem: Optional["pyarrow.fs.FileSystem"] = None,
            schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
            _block_udf: Optional[Callable[[Block], Block]] = None,
            **reader_args) -> List[ReadTask]:
        """Creates and returns read tasks for a file-based datasource.
        """

@@ -66,7 +67,10 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):

                builder.add_block(data)
            else:
                builder.add(data)
            return builder.build()
            block = builder.build()
            if _block_udf is not None:
                block = _block_udf(block)
            return block

        read_tasks = []
        for read_paths, file_sizes in zip(

@@ -111,6 +115,7 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):

            path: str,
            dataset_uuid: str,
            filesystem: Optional["pyarrow.fs.FileSystem"] = None,
            _block_udf: Optional[Callable[[Block], Block]] = None,
            **write_args) -> List[ObjectRef[WriteResult]]:
        """Creates and returns write tasks for a file-based datasource."""
        path, filesystem = _resolve_paths_and_filesystem(path, filesystem)

@@ -124,6 +129,8 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):

            fs = filesystem
            if isinstance(fs, _S3FileSystemWrapper):
                fs = fs.unwrap()
            if _block_udf is not None:
                block = _block_udf(block)
            with fs.open_output_stream(write_path) as f:
                _write_block_to_file(f, BlockAccessor.for_block(block))
@@ -1,10 +1,10 @@

import logging
from typing import Optional, List, Union, TYPE_CHECKING
from typing import Callable, Optional, List, Union, TYPE_CHECKING

if TYPE_CHECKING:
    import pyarrow

from ray.data.block import BlockAccessor
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import (
    FileBasedDatasource, _resolve_paths_and_filesystem)

@@ -30,6 +30,7 @@ class ParquetDatasource(FileBasedDatasource):

            filesystem: Optional["pyarrow.fs.FileSystem"] = None,
            columns: Optional[List[str]] = None,
            schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
            _block_udf: Optional[Callable[[Block], Block]] = None,
            **reader_args) -> List[ReadTask]:
        """Creates and returns read tasks for a Parquet file-based datasource.
        """

@@ -58,7 +59,6 @@ class ParquetDatasource(FileBasedDatasource):

        if columns:
            schema = pa.schema([schema.field(column) for column in columns],
                               schema.metadata)
        pieces = pq_ds.pieces

        def read_pieces(serialized_pieces: List[str]):
            # Implicitly trigger S3 subsystem initialization by importing

@@ -97,18 +97,36 @@ class ParquetDatasource(FileBasedDatasource):

                table = pa.concat_tables(tables, promote=True)
            elif len(tables) == 1:
                table = tables[0]
            if _block_udf is not None:
                table = _block_udf(table)
            # If len(tables) == 0, all fragments were empty, and we return the
            # empty table from the last fragment.
            return table

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                inferred_schema = _block_udf(dummy_table).schema
                inferred_schema = inferred_schema.with_metadata(
                    schema.metadata)
            except Exception:
                logger.info(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True)
                inferred_schema = schema
        else:
            inferred_schema = schema
        read_tasks = []
        for pieces_ in np.array_split(pieces, parallelism):
            if len(pieces_) == 0:
        for pieces in np.array_split(pq_ds.pieces, parallelism):
            if len(pieces) == 0:
                continue
            metadata = _get_metadata(pieces_, schema)
            pieces_ = [cloudpickle.dumps(p) for p in pieces_]
            metadata = _get_metadata(pieces, inferred_schema)
            pieces = [cloudpickle.dumps(p) for p in pieces]
            read_tasks.append(
                ReadTask(lambda pieces=pieces_: read_pieces(pieces), metadata))
                ReadTask(
                    lambda pieces_=pieces: read_pieces(pieces_), metadata))

        return read_tasks
python/ray/data/extensions/__init__.py (new file, 10 lines)

@@ -0,0 +1,10 @@

from ray.data.extensions.tensor_extension import (
    TensorDtype, TensorArray, ArrowTensorType, ArrowTensorArray)

__all__ = [
    # Tensor array extension.
    "TensorDtype",
    "TensorArray",
    "ArrowTensorType",
    "ArrowTensorArray",
]

python/ray/data/extensions/tensor_extension.py (new file, 1305 lines)

File diff suppressed because it is too large.
@@ -42,7 +42,17 @@ class ArrowRow:

        return self.as_pydict().items()

    def __getitem__(self, key: str) -> Any:
        return self._row[key][0].as_py()
        col = self._row[key]
        if len(col) == 0:
            return None
        item = col[0]
        try:
            # Try to interpret this as a pyarrow.Scalar value.
            return item.as_py()
        except AttributeError:
            # Assume that this row is an element of an extension array, and
            # that it is bypassing pyarrow's scalar model.
            return item

    def __eq__(self, other: Any) -> bool:
        return self.as_pydict() == other
@ -13,7 +13,8 @@ def _check_pyarrow_version():
|
|||
try:
|
||||
version_info = pkg_resources.require("pyarrow")
|
||||
version_str = version_info[0].version
|
||||
version = tuple(int(n) for n in version_str.split("."))
|
||||
version = tuple(
|
||||
int(n) for n in version_str.split(".") if "dev" not in n)
|
||||
if version < MIN_PYARROW_VERSION:
|
||||
raise ImportError(
|
||||
"Datasets requires pyarrow >= "
|
||||
|
|
|
@@ -189,6 +189,8 @@ def read_parquet(paths: Union[str, List[str]],

                 columns: Optional[List[str]] = None,
                 parallelism: int = 200,
                 ray_remote_args: Dict[str, Any] = None,
                 _tensor_column_schema: Optional[Dict[str, Tuple[
                     np.dtype, Tuple[int, ...]]]] = None,
                 **arrow_parquet_args) -> Dataset[ArrowRow]:
    """Create an Arrow dataset from parquet files.

@@ -205,11 +207,43 @@ def read_parquet(paths: Union[str, List[str]],

        columns: A list of column names to read.
        parallelism: The amount of parallelism to use for the dataset.
        ray_remote_args: kwargs passed to ray.remote in the read tasks.
        _tensor_column_schema: A dict of column name --> tensor dtype and shape
            mappings for converting a Parquet column containing serialized
            tensors (ndarrays) as their elements to our tensor column extension
            type. This assumes that the tensors were serialized in the raw
            NumPy array format in C-contiguous order (e.g. via
            `arr.tobytes()`).
        arrow_parquet_args: Other parquet read options to pass to pyarrow.

    Returns:
        Dataset holding Arrow records read from the specified paths.
    """
    if _tensor_column_schema is not None:
        existing_block_udf = arrow_parquet_args.pop("_block_udf", None)

        def _block_udf(block: "pyarrow.Table") -> "pyarrow.Table":
            from ray.data.extensions import ArrowTensorArray

            for tensor_col_name, (dtype,
                                  shape) in _tensor_column_schema.items():
                # NOTE(Clark): We use NumPy to consolidate these potentially
                # non-contiguous buffers, and to do buffer bookkeeping in
                # general.
                np_col = np.array([
                    np.ndarray(shape, buffer=buf.as_buffer(), dtype=dtype)
                    for buf in block.column(tensor_col_name)
                ])

                block = block.set_column(
                    block._ensure_integer_index(tensor_col_name),
                    tensor_col_name, ArrowTensorArray.from_numpy(np_col))
            if existing_block_udf is not None:
                # Apply UDF after casting the tensor columns.
                block = existing_block_udf(block)
            return block

        arrow_parquet_args["_block_udf"] = _block_udf

    return read_datasource(
        ParquetDatasource(),
        parallelism=parallelism,
@@ -21,6 +21,8 @@ from ray.data.datasource import DummyOutputDatasource

from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.block import BlockAccessor
from ray.data.datasource.file_based_datasource import _unwrap_protocol
from ray.data.extensions.tensor_extension import (
    TensorArray, TensorDtype, ArrowTensorType, ArrowTensorArray)
import ray.data.tests.util as util
from ray.data.tests.conftest import *  # noqa
@@ -187,6 +189,416 @@ def test_tensors(ray_start_regular_shared):

        "schema=<Tensor: shape=(None, 2, 2, 2), dtype=float64>)"), ds


def test_tensor_array_ops(ray_start_regular_shared):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)

    df = pd.DataFrame({"one": [1, 2, 3], "two": TensorArray(arr)})

    def apply_arithmetic_ops(arr):
        return 2 * (arr + 1) / 3

    def apply_comparison_ops(arr):
        return arr % 2 == 0

    def apply_logical_ops(arr):
        return arr & (3 * arr) | (5 * arr)

    # Op tests, using NumPy as the groundtruth.
    np.testing.assert_equal(
        apply_arithmetic_ops(arr), apply_arithmetic_ops(df["two"]))

    np.testing.assert_equal(
        apply_comparison_ops(arr), apply_comparison_ops(df["two"]))

    np.testing.assert_equal(
        apply_logical_ops(arr), apply_logical_ops(df["two"]))


def test_tensor_array_reductions(ray_start_regular_shared):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)

    df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})

    # Reduction tests, using NumPy as the groundtruth.
    for name, reducer in TensorArray.SUPPORTED_REDUCERS.items():
        np_kwargs = {}
        if name in ("std", "var"):
            # Pandas uses a ddof default of 1 while NumPy uses 0.
            # Give NumPy a ddof kwarg of 1 in order to ensure equivalent
            # standard deviation calculations.
            np_kwargs["ddof"] = 1
        np.testing.assert_equal(df["two"].agg(name),
                                reducer(arr, axis=0, **np_kwargs))


def test_arrow_tensor_array_getitem(ray_start_regular_shared):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)

    t_arr = ArrowTensorArray.from_numpy(arr)

    for idx in range(outer_dim):
        np.testing.assert_array_equal(t_arr[idx], arr[idx])

    # Test __iter__.
    for t_subarr, subarr in zip(t_arr, arr):
        np.testing.assert_array_equal(t_subarr, subarr)

    # Test to_pylist.
    np.testing.assert_array_equal(t_arr.to_pylist(), list(arr))

    # Test slicing and indexing.
    t_arr2 = t_arr[1:]

    np.testing.assert_array_equal(t_arr2.to_numpy(), arr[1:])

    for idx in range(1, outer_dim):
        np.testing.assert_array_equal(t_arr2[idx - 1], arr[idx])


def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({"one": list(range(outer_dim)), "two": list(arr)})
    # Cast column to tensor extension dtype.
    df["two"] = df["two"].astype(TensorDtype())
    ds = ray.data.from_pandas([ray.put(df)])
    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
    ds = ray.data.from_pandas([ray.put(df)])
    ds_df = ray.get(ds.to_pandas())[0]
    assert ds_df.equals(df)


def test_tensors_in_tables_parquet_roundtrip(ray_start_regular_shared,
                                             tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))
    ds = ray.data.read_parquet(str(tmp_path))
    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


def test_tensors_in_tables_parquet_with_schema(ray_start_regular_shared,
                                               tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))
    schema = pa.schema([
        ("one", pa.int32()),
        ("two", ArrowTensorType(inner_shape, pa.from_numpy_dtype(arr.dtype))),
    ])
    ds = ray.data.read_parquet(str(tmp_path), schema=schema)
    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)

def test_tensors_in_tables_parquet_pickle_manual_serde(
        ray_start_regular_shared, tmp_path):
    import pickle

    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({
        "one": list(range(outer_dim)),
        "two": [pickle.dumps(a) for a in arr]
    })
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))
    ds = ray.data.read_parquet(str(tmp_path))

    # Manually deserialize the tensor pickle bytes and cast to our tensor
    # extension type.
    def deser_mapper(batch: pd.DataFrame):
        batch["two"] = [pickle.loads(a) for a in batch["two"]]
        batch["two"] = batch["two"].astype(TensorDtype())
        return batch

    casted_ds = ds.map_batches(deser_mapper, batch_format="pandas")

    values = [[s["one"], s["two"]] for s in casted_ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)

    # Manually deserialize the pickle tensor bytes and directly cast it to a
    # TensorArray.
    def deser_mapper_direct(batch: pd.DataFrame):
        batch["two"] = TensorArray([pickle.loads(a) for a in batch["two"]])
        return batch

    casted_ds = ds.map_batches(deser_mapper_direct, batch_format="pandas")

    values = [[s["one"], s["two"]] for s in casted_ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


def test_tensors_in_tables_parquet_bytes_manual_serde(ray_start_regular_shared,
                                                      tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({
        "one": list(range(outer_dim)),
        "two": [a.tobytes() for a in arr]
    })
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))
    ds = ray.data.read_parquet(str(tmp_path))

    tensor_col_name = "two"

    # Manually deserialize the tensor bytes and cast to a TensorArray.
    def np_deser_mapper(batch: pa.Table):
        # NOTE(Clark): We use NumPy to consolidate these potentially
        # non-contiguous buffers, and to do buffer bookkeeping in general.
        np_col = np.array([
            np.ndarray(inner_shape, buffer=buf.as_buffer(), dtype=arr.dtype)
            for buf in batch.column(tensor_col_name)
        ])

        return batch.set_column(
            batch._ensure_integer_index(tensor_col_name), tensor_col_name,
            ArrowTensorArray.from_numpy(np_col))

    ds = ds.map_batches(np_deser_mapper, batch_format="pyarrow")

    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


def test_tensors_in_tables_parquet_bytes_manual_serde_udf(
        ray_start_regular_shared, tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    tensor_col_name = "two"
    df = pd.DataFrame({
        "one": list(range(outer_dim)),
        tensor_col_name: [a.tobytes() for a in arr]
    })
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))

    # Manually deserialize the tensor bytes and cast to a TensorArray.
    def np_deser_udf(block: pa.Table):
        # NOTE(Clark): We use NumPy to consolidate these potentially
        # non-contiguous buffers, and to do buffer bookkeeping in general.
        np_col = np.array([
            np.ndarray(inner_shape, buffer=buf.as_buffer(), dtype=arr.dtype)
            for buf in block.column(tensor_col_name)
        ])

        return block.set_column(
            block._ensure_integer_index(tensor_col_name), tensor_col_name,
            ArrowTensorArray.from_numpy(np_col))

    ds = ray.data.read_parquet(str(tmp_path), _block_udf=np_deser_udf)

    assert isinstance(ds.schema().field_by_name(tensor_col_name).type,
                      ArrowTensorType)

    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)

def test_tensors_in_tables_parquet_bytes_manual_serde_col_schema(
        ray_start_regular_shared, tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    tensor_col_name = "two"
    df = pd.DataFrame({
        "one": list(range(outer_dim)),
        tensor_col_name: [a.tobytes() for a in arr]
    })
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))

    def _block_udf(block: pa.Table):
        df = block.to_pandas()
        df[tensor_col_name] += 1
        return pa.Table.from_pandas(df)

    ds = ray.data.read_parquet(
        str(tmp_path),
        _block_udf=_block_udf,
        _tensor_column_schema={tensor_col_name: (arr.dtype, inner_shape)})

    assert isinstance(ds.schema().field_by_name(tensor_col_name).type,
                      ArrowTensorType)

    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr + 1))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


@pytest.mark.skip(
    reason=("Waiting for Arrow to support registering custom ExtensionType "
            "casting kernels. See "
            "https://issues.apache.org/jira/browse/ARROW-5890#"))
def test_tensors_in_tables_parquet_bytes_with_schema(ray_start_regular_shared,
                                                     tmp_path):
    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df = pd.DataFrame({
        "one": list(range(outer_dim)),
        "two": [a.tobytes() for a in arr]
    })
    ds = ray.data.from_pandas([ray.put(df)])
    ds.write_parquet(str(tmp_path))
    schema = pa.schema([
        ("one", pa.int32()),
        ("two", ArrowTensorType(inner_shape, pa.from_numpy_dtype(arr.dtype))),
    ])
    ds = ray.data.read_parquet(str(tmp_path), schema=schema)
    values = [[s["one"], s["two"]] for s in ds.take()]
    expected = list(zip(list(range(outer_dim)), arr))
    for v, e in zip(sorted(values), expected):
        np.testing.assert_equal(v, e)


@pytest.mark.skip(
    reason=("Waiting for pytorch to support tensor creation from objects that "
            "implement the __array__ interface. See "
            "https://github.com/pytorch/pytorch/issues/51156"))
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
    import torch

    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)
    df1 = pd.DataFrame({
        "one": [1, 2, 3],
        "two": TensorArray(arr),
        "label": [1.0, 2.0, 3.0]
    })
    arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
    df2 = pd.DataFrame({
        "one": [4, 5, 6],
        "two": TensorArray(arr2),
        "label": [4.0, 5.0, 6.0]
    })
    df = pd.concat([df1, df2])
    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
    ds = maybe_pipeline(ds, pipelined)
    torchd = ds.to_torch(label_column="label", batch_size=2)

    num_epochs = 2
    for _ in range(num_epochs):
        iterations = []
        for batch in iter(torchd):
            iterations.append(torch.cat((*batch[0], batch[1]), axis=1).numpy())
        combined_iterations = np.concatenate(iterations)
        assert np.array_equal(np.sort(df.values), np.sort(combined_iterations))


@pytest.mark.skip(
    reason=(
        "Waiting for Pandas DataFrame.values for extension arrays fix to be "
        "released. See https://github.com/pandas-dev/pandas/pull/43160"))
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
    import tensorflow as tf

    outer_dim = 3
    inner_shape = (2, 2, 2)
    shape = (outer_dim, ) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape).astype(np.float)
    # TODO(Clark): Ensure that heterogeneous columns is properly supported
    # (tf.RaggedTensorSpec)
    df1 = pd.DataFrame({
        "one": TensorArray(arr),
        "two": TensorArray(arr),
        "label": TensorArray(arr),
    })
    arr2 = np.arange(num_items, 2 * num_items).reshape(shape).astype(np.float)
    df2 = pd.DataFrame({
        "one": TensorArray(arr2),
        "two": TensorArray(arr2),
        "label": TensorArray(arr2),
    })
    df = pd.concat([df1, df2])
    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])
    ds = maybe_pipeline(ds, pipelined)
    tfd = ds.to_tf(
        label_column="label",
        output_signature=(tf.TensorSpec(
            shape=(None, 2, 2, 2, 2), dtype=tf.float32),
                          tf.TensorSpec(
                              shape=(None, 1, 2, 2, 2), dtype=tf.float32)))
    iterations = []
    for batch in tfd.as_numpy_iterator():
        iterations.append(np.concatenate((batch[0], batch[1]), axis=1))
    combined_iterations = np.concatenate(iterations)
    arr = np.array(
        [[np.asarray(v) for v in values] for values in df.to_numpy()])
    np.testing.assert_array_equal(arr, combined_iterations)


@pytest.mark.parametrize(
    "fs,data_path", [(None, lazy_fixture("local_path")),
                     (lazy_fixture("local_fs"), lazy_fixture("local_path")),
@@ -639,6 +1051,55 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared,

    assert sorted(values) == [[1, "a"], [1, "a"]]


def test_parquet_read_with_udf(ray_start_regular_shared, tmp_path):
    one_data = list(range(6))
    df = pd.DataFrame({
        "one": one_data,
        "two": 2 * ["a"] + 2 * ["b"] + 2 * ["c"]
    })
    table = pa.Table.from_pandas(df)
    pq.write_to_dataset(
        table,
        root_path=str(tmp_path),
        partition_cols=["one"],
        use_legacy_dataset=False)

    def _block_udf(block: pa.Table):
        df = block.to_pandas()
        df["one"] += 1
        return pa.Table.from_pandas(df)

    # 1 block/read task

    ds = ray.data.read_parquet(
        str(tmp_path), parallelism=1, _block_udf=_block_udf)

    ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
    assert len(ds._blocks._blocks) == 1
    np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1)

    # 2 blocks/read tasks

    ds = ray.data.read_parquet(
        str(tmp_path), parallelism=2, _block_udf=_block_udf)

    ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
    assert len(ds._blocks._blocks) == 2
    np.testing.assert_array_equal(sorted(ones), np.array(one_data) + 1)

    # 2 blocks/read tasks, 1 empty block

    ds = ray.data.read_parquet(
        str(tmp_path),
        parallelism=2,
        filter=(pa.dataset.field("two") == "a"),
        _block_udf=_block_udf)

    ones, twos = zip(*[[s["one"], s["two"]] for s in ds.take()])
    assert len(ds._blocks._blocks) == 2
    np.testing.assert_array_equal(sorted(ones), np.array(one_data[:2]) + 1)


@pytest.mark.parametrize("fs,data_path,endpoint_url", [
    (None, lazy_fixture("local_path"), None),
    (lazy_fixture("local_fs"), lazy_fixture("local_path"), None),
@@ -673,6 +1134,30 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url):

        fs.delete_dir(_unwrap_protocol(path))


def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path):
    data_path = str(tmp_path)
    one_data = list(range(6))
    df1 = pd.DataFrame({"one": one_data[:3], "two": ["a", "b", "c"]})
    df2 = pd.DataFrame({"one": one_data[3:], "two": ["e", "f", "g"]})
    df = pd.concat([df1, df2])
    ds = ray.data.from_pandas([ray.put(df1), ray.put(df2)])

    def _block_udf(block: pa.Table):
        df = block.to_pandas()
        df["one"] += 1
        return pa.Table.from_pandas(df)

    # 2 write tasks
    ds._set_uuid("data")
    ds.write_parquet(data_path, _block_udf=_block_udf)
    path1 = os.path.join(data_path, "data_000000.parquet")
    path2 = os.path.join(data_path, "data_000001.parquet")
    dfds = pd.concat([pd.read_parquet(path1), pd.read_parquet(path2)])
    expected_df = df
    expected_df["one"] += 1
    assert expected_df.equals(dfds)


@pytest.mark.parametrize(
    "fs,data_path", [(None, lazy_fixture("local_path")),
                     (lazy_fixture("local_fs"), lazy_fixture("local_path")),