[Datasets] Support tensor columns in to_tf and to_torch. (#24752)

This PR adds support for tensor columns in the to_tf() and to_torch() APIs.

For Torch, this involves an explicit extension array check and (zero-copy) conversion of the tensor column to a NumPy array before converting the column to a Torch tensor.

For TensorFlow, this involves bypassing df.values when converting tensor feature columns to NumPy arrays, instead manually creating a single NumPy array from the column Series.

In both cases, I think that the UX around heterogeneous feature columns and squeezing the column dimension could be improved, but I'm saving that for a future PR.
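In sketch form, the Torch-side conversion looks roughly like this (a minimal illustration, assuming the TensorArray extension type from ray.data.extensions; the names are only for this example):

import numpy as np
import pandas as pd
import torch
from ray.data.extensions import TensorArray

# A tensor column: 4 rows, each holding a 32x32 tensor element.
col = pd.Series(TensorArray(np.arange(4096).reshape((4, 32, 32))))

# Extension array -> NumPy ndarray (a zero-copy view of the underlying
# buffer) -> Torch tensor.
t = torch.as_tensor(col.to_numpy())
assert t.shape == (4, 32, 32)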
Clark Zinzow 2022-05-17 01:11:00 -07:00 committed by GitHub
parent ef870e936c
commit ea635aecd2
4 changed files with 332 additions and 50 deletions


@@ -107,7 +107,9 @@ If your serialized tensors don't fit the above constraints (e.g. they're stored
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>
- Please note that the ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.
+ .. note::
+   The ``tensor_column_schema`` and ``_block_udf`` parameters are both experimental developer APIs and may break in future versions.
Working with tensor column datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -143,6 +145,167 @@ This dataset can then be written to Parquet files. The tensor column schema will
# -> one: int64
# two: extension<arrow.py_extension_type<ArrowTensorType>>
Converting to a Torch/TensorFlow Dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This dataset can also be converted to a Torch or TensorFlow dataset via the standard
:meth:`ds.to_torch() <ray.data.Dataset.to_torch>` and
:meth:`ds.to_tf() <ray.data.Dataset.to_tf>` APIs for ingestion into those respective ML
training frameworks. The tensor column will be automatically converted to a
Torch/TensorFlow tensor without incurring any copies.
.. note::
When converting to a TensorFlow Dataset, you will need to give the full tensor spec
for the tensor columns, including the shape of each underlying tensor element in said
column.
.. tabbed:: Torch
Convert a ``Dataset`` containing a single tensor feature column to a Torch ``IterableDataset``.
.. code-block:: python
import ray
import numpy as np
import pandas as pd
import torch
from ray.data.extensions import TensorArray
df = pd.DataFrame({
"feature": TensorArray(np.arange(4096).reshape((4, 32, 32))),
"label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)
# Convert the dataset to a Torch IterableDataset.
torch_ds = ds.to_torch(
label_column="label",
batch_size=2,
unsqueeze_label_tensor=False,
unsqueeze_feature_tensors=False,
)
# A feature tensor and a label tensor are yielded per batch.
for X, y in torch_ds:
    # Train model(X, y).
    pass
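With ``unsqueeze_feature_tensors=False`` and a single tensor feature column, each batch here should yield a feature tensor of shape ``(2, 32, 32)`` and a label tensor of shape ``(2,)``.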
.. tabbed:: TensorFlow
Convert a ``Dataset`` containing a single tensor feature column to a TensorFlow ``tf.data.Dataset``.
.. code-block:: python
import ray
import numpy as np
import pandas as pd
import tensorflow as tf
from ray.data.extensions import TensorArray
tensor_element_shape = (32, 32)
df = pd.DataFrame({
"feature": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
"label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)
# Convert the dataset to a TensorFlow Dataset.
tf_ds = ds.to_tf(
label_column="label",
output_signature=(
tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
tf.TensorSpec(shape=(None,), dtype=tf.float32),
),
batch_size=2,
)
# A feature tensor and a label tensor are yielded per batch.
for X, y in tf_ds:
    # Train model(X, y).
    pass
If your columns have different types **OR** your (tensor) columns have different shapes,
these columns are incompatible and you will not be able to stack the column tensors
into a single tensor. Instead, you will need to group the columns by compatibility in
the ``feature_columns`` argument.
E.g., if columns ``"feature_1"`` and ``"feature_2"`` are incompatible, you should give
``to_torch()`` a ``feature_columns=[["feature_1"], ["feature_2"]]`` argument in order to
instruct it to return separate tensors for ``"feature_1"`` and ``"feature_2"``. For
``to_torch()``, when isolating single columns as in the ``"feature_1"``/``"feature_2"``
example above, you may also want to pass ``unsqueeze_feature_tensors=False`` in order to
strip the redundant column dimension from each single-column tensor.
.. tabbed:: Torch
Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a Torch ``IterableDataset``.
.. code-block:: python
import ray
import numpy as np
import pandas as pd
import torch
from ray.data.extensions import TensorArray
df = pd.DataFrame({
"feature_1": TensorArray(np.arange(4096).reshape((4, 32, 32))),
"feature_2": [5, 6, 7, 8],
"label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)
# Convert the dataset to a Torch IterableDataset.
torch_ds = ds.to_torch(
label_column="label",
feature_columns=[["feature_1"], ["feature_2"]],
batch_size=2,
unsqueeze_label_tensor=False,
unsqueeze_feature_tensors=False,
)
# Two feature tensors and one label tensor are yielded per batch.
for (feature_1, feature_2), y in torch_ds:
    # Train model((feature_1, feature_2), y).
    pass
.. tabbed:: TensorFlow
Convert a ``Dataset`` containing a tensor feature column and a scalar feature column
to a TensorFlow ``tf.data.Dataset``.
.. code-block:: python
import ray
import numpy as np
import pandas as pd
import tensorflow as tf
from ray.data.extensions import TensorArray
tensor_element_shape = (32, 32)
df = pd.DataFrame({
"feature_1": TensorArray(np.arange(4096).reshape((4,) + tensor_element_shape)),
"feature_2": [5, 6, 7, 8],
"label": [1, 2, 3, 4],
})
ds = ray.data.from_pandas(df)
# Convert the dataset to a TensorFlow Dataset.
tf_ds = ds.to_tf(
label_column="label",
feature_columns=[["feature_1"], ["feature_2"]],
output_signature=(
(
tf.TensorSpec(shape=(None, 1) + tensor_element_shape, dtype=tf.float32),
tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
),
tf.TensorSpec(shape=(None,), dtype=tf.float32),
),
batch_size=2,
)
# Two feature tensors and one label tensor are yielded per batch.
for (feature_1, feature_2), y in tf_ds:
    # Train model((feature_1, feature_2), y).
    pass
End-to-end workflow with our Pandas extension type
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -246,5 +409,3 @@ This feature currently comes with a few known limitations that we are either act
* All tensors in a tensor column currently must be the same shape. Please let us know if you require heterogeneous tensor shapes for your tensor column! Tracking issue is `here <https://github.com/ray-project/ray/issues/18316>`__.
* Automatic casting via specifying an override Arrow schema when reading Parquet is blocked by Arrow supporting custom ExtensionType casting kernels. See `issue <https://issues.apache.org/jira/browse/ARROW-5890>`__. An explicit ``tensor_column_schema`` parameter has been added for :func:`read_parquet() <ray.data.read_api.read_parquet>` as a stopgap solution.
- * Ingesting tables with tensor columns into pytorch via ``ds.to_torch()`` is blocked by pytorch supporting tensor creation from objects that implement the `__array__` interface. See `issue <https://github.com/pytorch/pytorch/issues/51156>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18314>`__.
- * Ingesting tables with tensor columns into TensorFlow via ``ds.to_tf()`` is blocked by a Pandas fix for properly interpreting extension arrays in ``DataFrame.values`` being released. See `PR <https://github.com/pandas-dev/pandas/pull/43160>`__. Workarounds are being `investigated <https://github.com/ray-project/ray/issues/18315>`__.
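As a point of reference for the ``tensor_column_schema`` stopgap mentioned above, here is a minimal sketch of reading serialized tensors with an explicit schema (the path and column name are hypothetical):

.. code-block:: python

import numpy as np
import ray

# Experimental developer API (see the note above): cast the serialized
# "two" column to a tensor column of 32x32 int64 elements on read.
ds = ray.data.read_parquet(
"/tmp/tensor_data",  # hypothetical path
tensor_column_schema={"two": (np.int64, (32, 32))},
)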


@@ -2446,6 +2446,27 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
if isinstance(output_signature, list):
output_signature = tuple(output_signature)
def get_df_values(df: "pandas.DataFrame") -> np.ndarray:
# TODO(Clark): Support unsqueezing column dimension API, similar to
# to_torch().
try:
values = df.values
except ValueError as e:
import pandas as pd
# Pandas DataFrame.values doesn't support extension arrays in all
# supported Pandas versions, so we check to see if this DataFrame
# contains any extension arrays and do a manual conversion if so.
# See https://github.com/pandas-dev/pandas/pull/43160.
if any(
isinstance(dtype, pd.api.extensions.ExtensionDtype)
for dtype in df.dtypes
):
values = np.stack([col.to_numpy() for _, col in df.items()], axis=1)
else:
raise e from None
return values
def make_generator():
for batch in self.iter_batches(
prefetch_blocks=prefetch_blocks,
@@ -2458,13 +2479,13 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
features = None
if feature_columns is None:
- features = batch.values
+ features = get_df_values(batch)
elif isinstance(feature_columns, list):
if all(isinstance(column, str) for column in feature_columns):
- features = batch[feature_columns].values
+ features = get_df_values(batch[feature_columns])
elif all(isinstance(columns, list) for columns in feature_columns):
features = tuple(
- batch[columns].values for columns in feature_columns
+ get_df_values(batch[columns]) for columns in feature_columns
)
else:
raise ValueError(
@@ -2473,7 +2494,7 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
)
elif isinstance(feature_columns, dict):
features = {
- key: batch[columns].values
+ key: get_df_values(batch[columns])
for key, columns in feature_columns.items()
}
else:
@@ -2482,8 +2503,6 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
f"but got a `{type(feature_columns).__name__}` instead."
)
- # TODO(Clark): Support batches containing our extension array
- # TensorArray.
if label_column:
yield features, targets
else:
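To illustrate the fallback in ``get_df_values()`` above, here is a minimal standalone sketch (column names are hypothetical) of the manual conversion it performs when a DataFrame holds extension-array columns:

import numpy as np
import pandas as pd
from ray.data.extensions import TensorArray

df = pd.DataFrame({
"a": TensorArray(np.arange(8).reshape((2, 2, 2))),
"b": TensorArray(np.arange(8, 16).reshape((2, 2, 2))),
})

# On affected Pandas versions, df.values raises a ValueError for
# extension-array columns, so stack each column's NumPy representation
# along a new column axis instead.
values = np.stack([col.to_numpy() for _, col in df.items()], axis=1)
assert values.shape == (2, 2, 2, 2)  # (num_rows, num_cols) + element shape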


@@ -962,49 +962,90 @@ def test_tensors_in_tables_parquet_bytes_with_schema(
np.testing.assert_equal(v, e)
- @pytest.mark.skip(
- reason=(
- "Waiting for pytorch to support tensor creation from objects that "
- "implement the __array__ interface. See "
- "https://github.com/pytorch/pytorch/issues/51156"
- )
- )
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
import torch
outer_dim = 3
inner_shape = (2, 2, 2)
shape = (outer_dim,) + inner_shape
num_items = np.prod(np.array(shape))
arr = np.arange(num_items).reshape(shape)
df1 = pd.DataFrame(
- {"one": [1, 2, 3], "two": TensorArray(arr), "label": [1.0, 2.0, 3.0]}
+ {"one": TensorArray(arr), "two": TensorArray(arr + 1), "label": [1.0, 2.0, 3.0]}
)
arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
df2 = pd.DataFrame(
- {"one": [4, 5, 6], "two": TensorArray(arr2), "label": [4.0, 5.0, 6.0]}
+ {
+ "one": TensorArray(arr2),
+ "two": TensorArray(arr2 + 1),
+ "label": [4.0, 5.0, 6.0],
+ }
)
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = maybe_pipeline(ds, pipelined)
- torchd = ds.to_torch(label_column="label", batch_size=2)
- num_epochs = 2
- for _ in range(num_epochs):
- iterations = []
- for batch in iter(torchd):
- iterations.append(torch.cat((*batch[0], batch[1]), axis=1).numpy())
- combined_iterations = np.concatenate(iterations)
- assert np.array_equal(np.sort(df.values), np.sort(combined_iterations))
- @pytest.mark.skip(
- reason=(
- "Waiting for Pandas DataFrame.values for extension arrays fix to be "
- "released. See https://github.com/pandas-dev/pandas/pull/43160"
- )
- )
+ torchd = ds.to_torch(
+ label_column="label", batch_size=2, unsqueeze_label_tensor=False
+ )
+ num_epochs = 1 if pipelined else 2
+ for _ in range(num_epochs):
+ features, labels = [], []
+ for batch in iter(torchd):
+ features.append(batch[0].numpy())
+ labels.append(batch[1].numpy())
+ features, labels = np.concatenate(features), np.concatenate(labels)
+ values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
+ np.testing.assert_array_equal(values, features)
+ np.testing.assert_array_equal(df["label"].to_numpy(), labels)
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_torch_mix(ray_start_regular_shared, pipelined):
outer_dim = 3
inner_shape = (2, 2, 2)
shape = (outer_dim,) + inner_shape
num_items = np.prod(np.array(shape))
arr = np.arange(num_items).reshape(shape)
df1 = pd.DataFrame(
{
"one": TensorArray(arr),
"two": [1, 2, 3],
"label": [1.0, 2.0, 3.0],
}
)
arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
df2 = pd.DataFrame(
{
"one": TensorArray(arr2),
"two": [4, 5, 6],
"label": [4.0, 5.0, 6.0],
}
)
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = maybe_pipeline(ds, pipelined)
torchd = ds.to_torch(
label_column="label",
feature_columns=[["one"], ["two"]],
batch_size=2,
unsqueeze_label_tensor=False,
unsqueeze_feature_tensors=False,
)
num_epochs = 1 if pipelined else 2
for _ in range(num_epochs):
col1, col2, labels = [], [], []
for batch in iter(torchd):
col1.append(batch[0][0].numpy())
col2.append(batch[0][1].numpy())
labels.append(batch[1].numpy())
col1, col2 = np.concatenate(col1), np.concatenate(col2)
labels = np.concatenate(labels)
np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
np.testing.assert_array_equal(labels, np.sort(df["label"].to_numpy()))
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
import tensorflow as tf
@@ -1014,21 +1055,19 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
shape = (outer_dim,) + inner_shape
num_items = np.prod(np.array(shape))
arr = np.arange(num_items).reshape(shape).astype(np.float)
- # TODO(Clark): Ensure that heterogeneous columns is properly supported
- # (tf.RaggedTensorSpec)
df1 = pd.DataFrame(
{
"one": TensorArray(arr),
- "two": TensorArray(arr),
- "label": TensorArray(arr),
+ "two": TensorArray(arr + 1),
+ "label": [1, 2, 3],
}
)
arr2 = np.arange(num_items, 2 * num_items).reshape(shape).astype(np.float)
df2 = pd.DataFrame(
{
"one": TensorArray(arr2),
- "two": TensorArray(arr2),
- "label": TensorArray(arr2),
+ "two": TensorArray(arr2 + 1),
+ "label": [4, 5, 6],
}
)
df = pd.concat([df1, df2])
@@ -1038,15 +1077,70 @@ def test_tensors_in_tables_to_tf(ray_start_regular_shared, pipelined):
label_column="label",
output_signature=(
- tf.TensorSpec(shape=(None, 2, 2, 2, 2), dtype=tf.float32),
+ tf.TensorSpec(shape=(None, 1, 2, 2, 2), dtype=tf.float32),
tf.TensorSpec(shape=(None,), dtype=tf.float32),
),
batch_size=2,
)
- iterations = []
+ features, labels = [], []
for batch in tfd.as_numpy_iterator():
- iterations.append(np.concatenate((batch[0], batch[1]), axis=1))
- combined_iterations = np.concatenate(iterations)
- arr = np.array([[np.asarray(v) for v in values] for values in df.to_numpy()])
- np.testing.assert_array_equal(arr, combined_iterations)
+ features.append(batch[0])
+ labels.append(batch[1])
+ features, labels = np.concatenate(features), np.concatenate(labels)
+ values = np.stack([df["one"].to_numpy(), df["two"].to_numpy()], axis=1)
+ np.testing.assert_array_equal(values, features)
+ np.testing.assert_array_equal(df["label"].to_numpy(), labels)
@pytest.mark.parametrize("pipelined", [False, True])
def test_tensors_in_tables_to_tf_mix(ray_start_regular_shared, pipelined):
import tensorflow as tf
outer_dim = 3
inner_shape = (2, 2, 2)
shape = (outer_dim,) + inner_shape
num_items = np.prod(np.array(shape))
arr = np.arange(num_items).reshape(shape).astype(np.float)
df1 = pd.DataFrame(
{
"one": TensorArray(arr),
"two": [1, 2, 3],
"label": [1.0, 2.0, 3.0],
}
)
arr2 = np.arange(num_items, 2 * num_items).reshape(shape).astype(np.float)
df2 = pd.DataFrame(
{
"one": TensorArray(arr2),
"two": [4, 5, 6],
"label": [4.0, 5.0, 6.0],
}
)
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
ds = maybe_pipeline(ds, pipelined)
tfd = ds.to_tf(
label_column="label",
feature_columns=[["one"], ["two"]],
output_signature=(
(
tf.TensorSpec(shape=(None, 1, 2, 2, 2), dtype=tf.float32),
tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
),
tf.TensorSpec(shape=(None,), dtype=tf.float32),
),
batch_size=2,
)
col1, col2, labels = [], [], []
for batch in tfd.as_numpy_iterator():
col1.append(batch[0][0])
col2.append(batch[0][1])
labels.append(batch[1])
col1 = np.squeeze(np.concatenate(col1), axis=1)
col2 = np.squeeze(np.concatenate(col2), axis=1)
labels = np.concatenate(labels)
np.testing.assert_array_equal(col1, np.sort(df["one"].to_numpy()))
np.testing.assert_array_equal(col2, np.sort(df["two"].to_numpy()))
np.testing.assert_array_equal(labels, np.sort(df["label"].to_numpy()))
def test_empty_shuffle(ray_start_regular_shared):


@@ -54,6 +54,12 @@ def convert_pandas_to_torch_tensor(
def tensorize(vals, dtype):
"""This recursive function converts pyarrow List dtypes to
multi-dimensional tensors."""
if isinstance(vals, pd.api.extensions.ExtensionArray):
# torch.as_tensor() does not yet support the __array__ protocol, so we need
# to convert extension arrays to ndarrays manually before converting to a
# Torch tensor.
# See https://github.com/pytorch/pytorch/issues/51156.
vals = vals.to_numpy()
try:
return torch.as_tensor(vals, dtype=dtype)
except TypeError:
@@ -79,8 +85,10 @@ def convert_pandas_to_torch_tensor(
feature_tensors.append(t)
if len(feature_tensors) > 1:
- return torch.cat(feature_tensors, dim=1)
- return feature_tensors[0]
+ feature_tensor = torch.cat(feature_tensors, dim=1)
+ else:
+ feature_tensor = feature_tensors[0]
+ return feature_tensor
if multi_input:
if type(column_dtypes) not in [list, tuple]:
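For illustration, here is a minimal standalone sketch of the extension-array branch added to ``tensorize()`` above (assuming the ``TensorArray`` extension type from ``ray.data.extensions``):

import numpy as np
import pandas as pd
import torch
from ray.data.extensions import TensorArray

vals = TensorArray(np.arange(8).reshape((2, 2, 2)))
assert isinstance(vals, pd.api.extensions.ExtensionArray)

# torch.as_tensor() can't consume the extension array directly (Torch
# doesn't go through __array__ here), so convert to an ndarray first.
t = torch.as_tensor(vals.to_numpy(), dtype=torch.float32)
assert t.shape == (2, 2, 2)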