mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[AIR - Datasets] Hide tensor extension from UDFs. (#27019)
We previously added automatic tensor extension casting on Datasets transformation outputs to allow the user to not have to worry about tensor column casting; however, this current state creates several issues: 1. Not all tensors are supported, which means that we’ll need to have an opaque object dtype (i.e. ndarray of ndarray pointers) fallback for the Pandas-only case. Known unsupported tensor use cases: a. Heterogeneous-shaped (i.e. ragged) tensors b. Struct arrays 2. UDFs will expect a NumPy column and won’t know what to do with our TensorArray type. E.g., torchvision transforms don’t respect the array protocol (which they should), and instead only support Torch tensors and NumPy ndarrays; passing a TensorArray column or a TensorArrayElement (a single item in the TensorArray column) fails. Implicit casting with object dtype fallback on UDF outputs can make the input type to downstream UDFs nondeterministic, where the user won’t know if they’ll get a TensorArray column or an object dtype column. 3. The tensor extension cast fallback warning spams the logs. This PR: 1. Adds automatic casting of tensor extension columns to NumPy ndarray columns for Datasets UDF inputs, meaning the UDFs will never have to see tensor extensions and that the UDF input column types will be consistent and deterministic; this fixes both (2) and (3). 2. No longer implicitly falls back to an opaque object dtype when TensorArray casting fails (e.g. for ragged tensors), and instead raises an error; this fixes (4) but removes our support for (1). 3. Adds a global enable_tensor_extension_casting config flag, which is True by default, that controls whether we perform this automatic casting. Turning off the implicit casting provides a path for (1), where the tensor extension can be avoided if working with ragged tensors in Pandas land. Turning off this flag also allows the user to explicitly control their tensor extension casting, if they want to work with it in their UDFs in order to reap the benefits of less data copies, more efficient slicing, stronger column typing, etc.
This commit is contained in:
parent
510a0e038c
commit
df124d0ad5
24 changed files with 413 additions and 185 deletions
|
@ -1228,7 +1228,7 @@
|
|||
],
|
||||
"source": [
|
||||
"predicted_classes = results.map_batches(\n",
|
||||
" lambda batch: [classes[pred.to_numpy().argmax(0)] for pred in batch[\"predictions\"]], \n",
|
||||
" lambda batch: [classes[pred.argmax(0)] for pred in batch[\"predictions\"]], \n",
|
||||
" batch_format=\"pandas\")"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -4,7 +4,6 @@ from torchvision import transforms
|
|||
from torchvision.models import resnet18
|
||||
|
||||
import ray
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
from ray.train.torch import TorchCheckpoint, TorchPredictor
|
||||
from ray.train.batch_predictor import BatchPredictor
|
||||
from ray.data.preprocessors import BatchMapper
|
||||
|
@ -24,7 +23,7 @@ def preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
|||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
df["image"] = TensorArray([preprocess(x.to_numpy()) for x in df["image"]])
|
||||
df["image"] = [preprocess(x).numpy() for x in df["image"]]
|
||||
return df
|
||||
|
||||
|
||||
|
|
|
@ -515,7 +515,6 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"from ray.data.preprocessors import BatchMapper\n",
|
||||
"from ray.data.extensions import TensorArray\n",
|
||||
"\n",
|
||||
"from torchvision import transforms\n",
|
||||
"\n",
|
||||
|
@ -526,10 +525,9 @@
|
|||
" [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" df.loc[:, \"image\"] = TensorArray([\n",
|
||||
" torchvision_transforms(np.asarray(image)).numpy()\n",
|
||||
" for image in df[\"image\"]\n",
|
||||
" ])\n",
|
||||
" df.loc[:, \"image\"] = [\n",
|
||||
" torchvision_transforms(image).numpy() for image in df[\"image\"]\n",
|
||||
" ]\n",
|
||||
" return df\n",
|
||||
"\n",
|
||||
"mnist_normalize_preprocessor = BatchMapper(fn=preprocess_images)"
|
||||
|
|
|
@ -5,6 +5,8 @@ import pandas as pd
|
|||
import tensorflow as tf
|
||||
from pandas.api.types import is_object_dtype
|
||||
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
|
||||
|
||||
def convert_pandas_to_tf_tensor(
|
||||
df: pd.DataFrame, dtype: Optional[tf.dtypes.DType] = None
|
||||
|
@ -83,6 +85,23 @@ def convert_pandas_to_tf_tensor(
|
|||
return concatenated_tensor
|
||||
|
||||
|
||||
def convert_ndarray_to_tf_tensor(
|
||||
ndarray: np.ndarray,
|
||||
dtype: Optional[tf.dtypes.DType] = None,
|
||||
) -> tf.Tensor:
|
||||
"""Convert a NumPy ndarray to a TensorFlow Tensor.
|
||||
|
||||
Args:
|
||||
ndarray: A NumPy ndarray that we wish to convert to a TensorFlow Tensor.
|
||||
dtype: A TensorFlow dtype for the created tensor; if None, the dtype will be
|
||||
inferred from the NumPy ndarray data.
|
||||
|
||||
Returns: A TensorFlow Tensor.
|
||||
"""
|
||||
ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)
|
||||
return tf.convert_to_tensor(ndarray, dtype=dtype)
|
||||
|
||||
|
||||
def convert_ndarray_batch_to_tf_tensor_batch(
|
||||
ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
dtypes: Optional[Union[tf.dtypes.DType, Dict[str, tf.dtypes.DType]]] = None,
|
||||
|
@ -106,11 +125,11 @@ def convert_ndarray_batch_to_tf_tensor_batch(
|
|||
f"should be given, instead got: {dtypes}"
|
||||
)
|
||||
dtypes = next(iter(dtypes.values()))
|
||||
batch = tf.convert_to_tensor(ndarrays, dtype=dtypes)
|
||||
batch = convert_ndarray_to_tf_tensor(ndarrays, dtypes)
|
||||
else:
|
||||
# Multi-tensor case.
|
||||
batch = {
|
||||
col_name: tf.convert_to_tensor(
|
||||
col_name: convert_ndarray_to_tf_tensor(
|
||||
col_ndarray,
|
||||
dtype=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
|
||||
)
|
||||
|
|
|
@ -4,6 +4,8 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
|
||||
|
||||
def convert_pandas_to_torch_tensor(
|
||||
data_batch: pd.DataFrame,
|
||||
|
@ -102,6 +104,24 @@ def convert_pandas_to_torch_tensor(
|
|||
return get_tensor_for_columns(columns=columns, dtype=column_dtypes)
|
||||
|
||||
|
||||
def convert_ndarray_to_torch_tensor(
|
||||
ndarray: np.ndarray,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Convert a NumPy ndarray to a Torch Tensor.
|
||||
|
||||
Args:
|
||||
ndarray: A NumPy ndarray that we wish to convert to a Torch Tensor.
|
||||
dtype: A Torch dtype for the created tensor; if None, the dtype will be
|
||||
inferred from the NumPy ndarray data.
|
||||
|
||||
Returns: A Torch Tensor.
|
||||
"""
|
||||
ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)
|
||||
return torch.as_tensor(ndarray, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def convert_ndarray_batch_to_torch_tensor_batch(
|
||||
ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
|
||||
|
@ -127,11 +147,11 @@ def convert_ndarray_batch_to_torch_tensor_batch(
|
|||
f"should be given, instead got: {dtypes}"
|
||||
)
|
||||
dtypes = next(iter(dtypes.values()))
|
||||
batch = torch.as_tensor(ndarrays, dtype=dtypes, device=device)
|
||||
batch = convert_ndarray_to_torch_tensor(ndarrays, dtype=dtypes, device=device)
|
||||
else:
|
||||
# Multi-tensor case.
|
||||
batch = {
|
||||
col_name: torch.as_tensor(
|
||||
col_name: convert_ndarray_to_torch_tensor(
|
||||
col_ndarray,
|
||||
dtype=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
|
||||
device=device,
|
||||
|
|
|
@ -16,40 +16,72 @@ def test_pandas_pandas():
|
|||
input_data = pd.DataFrame({"x": [1, 2, 3]})
|
||||
expected_output = input_data
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
assert convert_pandas_to_batch_type(actual_output, type=DataType.PANDAS).equals(
|
||||
input_data
|
||||
actual_output = convert_pandas_to_batch_type(actual_output, type=DataType.PANDAS)
|
||||
pd.testing.assert_frame_equal(actual_output, input_data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tensor_extension_for_input", [True, False])
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_pandas_multi_dim_pandas(cast_tensor_columns, use_tensor_extension_for_input):
|
||||
input_tensor = np.arange(12).reshape((3, 2, 2))
|
||||
input_data = pd.DataFrame(
|
||||
{
|
||||
"x": TensorArray(input_tensor)
|
||||
if use_tensor_extension_for_input
|
||||
else list(input_tensor)
|
||||
}
|
||||
)
|
||||
expected_output = pd.DataFrame(
|
||||
{
|
||||
"x": (
|
||||
list(input_tensor)
|
||||
if cast_tensor_columns or not use_tensor_extension_for_input
|
||||
else TensorArray(input_tensor)
|
||||
)
|
||||
}
|
||||
)
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
actual_output = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.PANDAS, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
pd.testing.assert_frame_equal(actual_output, input_data)
|
||||
|
||||
|
||||
def test_numpy_pandas():
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_numpy_pandas(cast_tensor_columns):
|
||||
input_data = np.array([1, 2, 3])
|
||||
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: TensorArray([1, 2, 3])})
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: input_data})
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
|
||||
output_array = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.NUMPY, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
np.testing.assert_equal(output_array, input_data)
|
||||
|
||||
|
||||
def test_numpy_multi_dim_pandas():
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_numpy_multi_dim_pandas(cast_tensor_columns):
|
||||
input_data = np.arange(12).reshape((3, 2, 2))
|
||||
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: TensorArray(input_data)})
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: list(input_data)})
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
|
||||
output_array = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.NUMPY, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
np.testing.assert_array_equal(np.array(list(output_array)), input_data)
|
||||
|
||||
|
||||
def test_numpy_object_pandas():
|
||||
input_data = np.array([[1, 2, 3], [1]], dtype=object)
|
||||
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: input_data})
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
|
||||
|
@ -62,34 +94,43 @@ def test_dict_fail():
|
|||
convert_batch_type_to_pandas(input_data)
|
||||
|
||||
|
||||
def test_dict_pandas():
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_dict_pandas(cast_tensor_columns):
|
||||
input_data = {"x": np.array([1, 2, 3])}
|
||||
expected_output = pd.DataFrame({"x": TensorArray(input_data["x"])})
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
expected_output = pd.DataFrame({"x": input_data["x"]})
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
output_array = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
|
||||
output_array = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.NUMPY, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
np.testing.assert_array_equal(output_array, input_data["x"])
|
||||
|
||||
|
||||
def test_dict_multi_dim_to_pandas():
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_dict_multi_dim_to_pandas(cast_tensor_columns):
|
||||
tensor = np.arange(12).reshape((3, 2, 2))
|
||||
input_data = {"x": tensor}
|
||||
expected_output = pd.DataFrame({"x": TensorArray(tensor)})
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
expected_output = pd.DataFrame({"x": list(tensor)})
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
output_array = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
|
||||
np.testing.assert_array_equal(output_array, input_data["x"])
|
||||
output_array = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.NUMPY, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
np.testing.assert_array_equal(np.array(list(output_array)), input_data["x"])
|
||||
|
||||
|
||||
def test_dict_pandas_multi_column():
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_dict_pandas_multi_column(cast_tensor_columns):
|
||||
array_dict = {"x": np.array([1, 2, 3]), "y": np.array([4, 5, 6])}
|
||||
expected_output = pd.DataFrame({k: TensorArray(v) for k, v in array_dict.items()})
|
||||
actual_output = convert_batch_type_to_pandas(array_dict)
|
||||
assert expected_output.equals(actual_output)
|
||||
expected_output = pd.DataFrame(array_dict)
|
||||
actual_output = convert_batch_type_to_pandas(array_dict, cast_tensor_columns)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
output_dict = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
|
||||
output_dict = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.NUMPY, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
for k, v in output_dict.items():
|
||||
np.testing.assert_array_equal(v, array_dict[k])
|
||||
|
||||
|
@ -99,26 +140,30 @@ def test_arrow_pandas():
|
|||
input_data = pa.Table.from_pandas(df)
|
||||
expected_output = df
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
assert convert_pandas_to_batch_type(actual_output, type=DataType.ARROW).equals(
|
||||
input_data
|
||||
)
|
||||
|
||||
|
||||
def test_arrow_tensor_pandas():
|
||||
np_array = np.array([1, 2, 3])
|
||||
df = pd.DataFrame({"x": TensorArray(np_array)})
|
||||
@pytest.mark.parametrize("cast_tensor_columns", [True, False])
|
||||
def test_arrow_tensor_pandas(cast_tensor_columns):
|
||||
np_array = np.arange(12).reshape((3, 2, 2))
|
||||
input_data = pa.Table.from_arrays(
|
||||
[ArrowTensorArray.from_numpy(np_array)], names=["x"]
|
||||
)
|
||||
expected_output = df
|
||||
actual_output = convert_batch_type_to_pandas(input_data)
|
||||
assert expected_output.equals(actual_output)
|
||||
|
||||
assert convert_pandas_to_batch_type(actual_output, type=DataType.ARROW).equals(
|
||||
input_data
|
||||
actual_output = convert_batch_type_to_pandas(input_data, cast_tensor_columns)
|
||||
expected_output = pd.DataFrame({"x": list(np_array)})
|
||||
expected_output = pd.DataFrame(
|
||||
{"x": (list(np_array) if cast_tensor_columns else TensorArray(np_array))}
|
||||
)
|
||||
pd.testing.assert_frame_equal(expected_output, actual_output)
|
||||
|
||||
arrow_output = convert_pandas_to_batch_type(
|
||||
actual_output, type=DataType.ARROW, cast_tensor_columns=cast_tensor_columns
|
||||
)
|
||||
assert arrow_output.equals(input_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from enum import Enum, auto
|
||||
import logging
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import ray
|
||||
from ray.air.data_batch_type import DataBatchType
|
||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
@ -14,8 +13,6 @@ try:
|
|||
except ImportError:
|
||||
pyarrow = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class DataType(Enum):
|
||||
|
@ -25,78 +22,64 @@ class DataType(Enum):
|
|||
|
||||
|
||||
@DeveloperAPI
|
||||
def convert_batch_type_to_pandas(data: DataBatchType) -> pd.DataFrame:
|
||||
def convert_batch_type_to_pandas(
|
||||
data: DataBatchType,
|
||||
cast_tensor_columns: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""Convert the provided data to a Pandas DataFrame.
|
||||
|
||||
Args:
|
||||
data: Data of type DataBatchType
|
||||
cast_tensor_columns: Whether tensor columns should be cast to NumPy ndarrays.
|
||||
|
||||
Returns:
|
||||
A pandas Dataframe representation of the input data.
|
||||
|
||||
"""
|
||||
global _tensor_cast_failed_warned
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
|
||||
if isinstance(data, pd.DataFrame):
|
||||
return data
|
||||
|
||||
elif isinstance(data, np.ndarray):
|
||||
try:
|
||||
# Try to convert numpy arrays to TensorArrays.
|
||||
data = TensorArray(data)
|
||||
except TypeError as e:
|
||||
# Fall back to existing NumPy array.
|
||||
if ray.util.log_once("datasets_tensor_array_cast_warning"):
|
||||
logger.warning(
|
||||
"Tried to transparently convert ndarray batch to a TensorArray "
|
||||
f"but the conversion failed, leaving ndarray batch as-is: {e}"
|
||||
)
|
||||
return pd.DataFrame({TENSOR_COLUMN_NAME: data})
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
data = pd.DataFrame({TENSOR_COLUMN_NAME: _ndarray_to_column(data)})
|
||||
elif isinstance(data, dict):
|
||||
tensor_dict = {}
|
||||
for k, v in data.items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
for col_name, col in data.items():
|
||||
if not isinstance(col, np.ndarray):
|
||||
raise ValueError(
|
||||
"All values in the provided dict must be of type "
|
||||
f"np.ndarray. Found type {type(v)} for key {k} "
|
||||
f"np.ndarray. Found type {type(col)} for key {col_name} "
|
||||
f"instead."
|
||||
)
|
||||
try:
|
||||
# Try to convert numpy arrays to TensorArrays.
|
||||
v = TensorArray(v)
|
||||
except TypeError as e:
|
||||
# Fall back to existing NumPy array.
|
||||
if ray.util.log_once("datasets_tensor_array_cast_warning"):
|
||||
logger.warning(
|
||||
f"Tried to transparently convert column ndarray {k} of batch "
|
||||
"to a TensorArray but the conversion failed, leaving column "
|
||||
f"as-is: {e}"
|
||||
)
|
||||
tensor_dict[k] = v
|
||||
return pd.DataFrame(tensor_dict)
|
||||
|
||||
tensor_dict[col_name] = _ndarray_to_column(col)
|
||||
data = pd.DataFrame(tensor_dict)
|
||||
elif pyarrow is not None and isinstance(data, pyarrow.Table):
|
||||
return data.to_pandas()
|
||||
else:
|
||||
data = data.to_pandas()
|
||||
elif not isinstance(data, pd.DataFrame):
|
||||
raise ValueError(
|
||||
f"Received data of type: {type(data)}, but expected it to be one "
|
||||
f"of {DataBatchType}"
|
||||
)
|
||||
if cast_tensor_columns:
|
||||
data = _cast_tensor_columns_to_ndarrays(data)
|
||||
return data
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def convert_pandas_to_batch_type(data: pd.DataFrame, type: DataType) -> DataBatchType:
|
||||
def convert_pandas_to_batch_type(
|
||||
data: pd.DataFrame,
|
||||
type: DataType,
|
||||
cast_tensor_columns: bool = False,
|
||||
) -> DataBatchType:
|
||||
"""Convert the provided Pandas dataframe to the provided ``type``.
|
||||
|
||||
Args:
|
||||
data: A Pandas DataFrame
|
||||
type: The specific ``DataBatchType`` to convert to.
|
||||
cast_tensor_columns: Whether tensor columns should be cast to our tensor
|
||||
extension type.
|
||||
|
||||
Returns:
|
||||
The input data represented with the provided type.
|
||||
"""
|
||||
if cast_tensor_columns:
|
||||
data = _cast_ndarray_columns_to_tensor_extension(data)
|
||||
if type == DataType.PANDAS:
|
||||
return data
|
||||
|
||||
|
@ -124,3 +107,72 @@ def convert_pandas_to_batch_type(data: pd.DataFrame, type: DataType) -> DataBatc
|
|||
raise ValueError(
|
||||
f"Received type {type}, but expected it to be one of {DataType}"
|
||||
)
|
||||
|
||||
|
||||
def _ndarray_to_column(arr: np.ndarray) -> Union[pd.Series, List[np.ndarray]]:
|
||||
"""Convert a NumPy ndarray into an appropriate column format for insertion into a
|
||||
pandas DataFrame.
|
||||
|
||||
If conversion to a pandas Series fails (e.g. if the ndarray is multi-dimensional),
|
||||
fall back to a list of NumPy ndarrays.
|
||||
"""
|
||||
try:
|
||||
# Try to convert to Series, falling back to a list conversion if this fails
|
||||
# (e.g. if the ndarray is multi-dimensional).
|
||||
return pd.Series(arr)
|
||||
except ValueError:
|
||||
return list(arr)
|
||||
|
||||
|
||||
def _unwrap_ndarray_object_type_if_needed(arr: np.ndarray) -> np.ndarray:
|
||||
"""Unwrap an object-dtyped NumPy ndarray containing ndarray pointers into a single
|
||||
contiguous ndarray, if needed/possible.
|
||||
"""
|
||||
if arr.dtype.type is np.object_:
|
||||
try:
|
||||
# Try to convert the NumPy ndarray to a non-object dtype.
|
||||
arr = np.array([np.asarray(v) for v in arr])
|
||||
except Exception:
|
||||
# This may fail if the subndarrays are of heterogeneous shape
|
||||
pass
|
||||
return arr
|
||||
|
||||
|
||||
def _cast_ndarray_columns_to_tensor_extension(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Cast all NumPy ndarray columns in df to our tensor extension type, TensorArray.
|
||||
"""
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
|
||||
# Try to convert any ndarray columns to TensorArray columns.
|
||||
# TODO(Clark): Once Pandas supports registering extension types for type
|
||||
# inference on construction, implement as much for NumPy ndarrays and remove
|
||||
# this. See https://github.com/pandas-dev/pandas/issues/41848
|
||||
for col_name, col in df.items():
|
||||
if (
|
||||
col.dtype.type is np.object_
|
||||
and not col.empty
|
||||
and isinstance(col.iloc[0], np.ndarray)
|
||||
):
|
||||
try:
|
||||
df.loc[:, col_name] = TensorArray(col)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Tried to cast column {col_name} to the TensorArray tensor "
|
||||
"extension type but the conversion failed. To disable automatic "
|
||||
"casting to this tensor extension, set "
|
||||
"ctx = DatasetContext.get_current(); "
|
||||
"ctx.enable_tensor_extension_casting = False."
|
||||
) from e
|
||||
return df
|
||||
|
||||
|
||||
def _cast_tensor_columns_to_ndarrays(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Cast all tensor extension columns in df to NumPy ndarrays."""
|
||||
from ray.air.util.tensor_extensions.pandas import TensorDtype
|
||||
|
||||
# Try to convert any tensor extension columns to ndarray columns.
|
||||
for col_name, col in df.items():
|
||||
if isinstance(col.dtype, TensorDtype):
|
||||
df.loc[:, col_name] = pd.Series(list(col.to_numpy()))
|
||||
return df
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
# - Added support for logical operators to TensorArray(Element).
|
||||
# - Miscellaneous small bug fixes and optimizations.
|
||||
|
||||
import itertools
|
||||
import numbers
|
||||
import os
|
||||
from distutils.version import LooseVersion
|
||||
|
@ -712,13 +713,15 @@ class TensorArray(
|
|||
# ndarrays.
|
||||
self._tensor = np.array([np.asarray(v) for v in values])
|
||||
if self._tensor.dtype.type is np.object_:
|
||||
subndarray_types = [v.dtype for v in self._tensor]
|
||||
subndarray_types = [
|
||||
v.dtype for v in itertools.islice(self._tensor, 5)
|
||||
]
|
||||
raise TypeError(
|
||||
"Tried to convert an ndarray of ndarray pointers (object "
|
||||
"dtype) to a well-typed ndarray but this failed; convert "
|
||||
"the ndarray to a well-typed ndarray before casting it as "
|
||||
"a TensorArray, and note that ragged tensors are NOT "
|
||||
"supported by TensorArray. subndarray types: "
|
||||
"supported by TensorArray. First 5 subndarray types: "
|
||||
f"{subndarray_types}"
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
|
@ -16,7 +15,6 @@ import collections
|
|||
import heapq
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.data.block import (
|
||||
Block,
|
||||
BlockAccessor,
|
||||
|
@ -26,6 +24,7 @@ from ray.data.block import (
|
|||
KeyType,
|
||||
U,
|
||||
)
|
||||
from ray.data.context import DatasetContext
|
||||
from ray.data.row import TableRow
|
||||
from ray.data._internal.table_block import (
|
||||
TableBlockAccessor,
|
||||
|
@ -42,7 +41,6 @@ if TYPE_CHECKING:
|
|||
T = TypeVar("T")
|
||||
|
||||
_pandas = None
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def lazy_import_pandas():
|
||||
|
@ -98,31 +96,18 @@ class PandasBlockBuilder(TableBlockBuilder[T]):
|
|||
|
||||
def _concat_tables(self, tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
|
||||
pandas = lazy_import_pandas()
|
||||
from ray.data.extensions.tensor_extension import TensorArray
|
||||
from ray.air.util.data_batch_conversion import (
|
||||
_cast_ndarray_columns_to_tensor_extension,
|
||||
)
|
||||
|
||||
if len(tables) > 1:
|
||||
df = pandas.concat(tables, ignore_index=True)
|
||||
else:
|
||||
df = tables[0]
|
||||
# Try to convert any ndarray columns to TensorArray columns.
|
||||
# TODO(Clark): Once Pandas supports registering extension types for type
|
||||
# inference on construction, implement as much for NumPy ndarrays and remove
|
||||
# this. See https://github.com/pandas-dev/pandas/issues/41848
|
||||
for col_name, col in df.items():
|
||||
if (
|
||||
col.dtype.type is np.object_
|
||||
and not col.empty
|
||||
and isinstance(col.iloc[0], np.ndarray)
|
||||
):
|
||||
try:
|
||||
df.loc[:, col_name] = TensorArray(col)
|
||||
except Exception as e:
|
||||
if ray.util.log_once("datasets_tensor_array_cast_warning"):
|
||||
logger.warning(
|
||||
f"Tried to transparently convert column {col_name} to a "
|
||||
"TensorArray but the conversion failed, leaving column "
|
||||
f"as-is: {e}"
|
||||
)
|
||||
df.reset_index(drop=True, inplace=True)
|
||||
ctx = DatasetContext.get_current()
|
||||
if ctx.enable_tensor_extension_casting:
|
||||
df = _cast_ndarray_columns_to_tensor_extension(df)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
|
@ -147,21 +132,31 @@ class PandasBlockAccessor(TableBlockAccessor):
|
|||
|
||||
@staticmethod
|
||||
def _build_tensor_row(row: PandasRow) -> np.ndarray:
|
||||
# Getting an item in a Pandas tensor column returns a TensorArrayElement, which
|
||||
# we have to convert to an ndarray.
|
||||
return row[VALUE_COL_NAME].iloc[0].to_numpy()
|
||||
from ray.data.extensions import TensorArrayElement
|
||||
|
||||
tensor = row[VALUE_COL_NAME].iloc[0]
|
||||
if isinstance(tensor, TensorArrayElement):
|
||||
# Getting an item in a Pandas tensor column may return a TensorArrayElement,
|
||||
# which we have to convert to an ndarray.
|
||||
tensor = tensor.to_numpy()
|
||||
return tensor
|
||||
|
||||
def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
|
||||
view = self._table[start:end]
|
||||
view.reset_index(drop=True, inplace=True)
|
||||
if copy:
|
||||
view = view.copy(deep=True)
|
||||
return view
|
||||
|
||||
def take(self, indices: List[int]) -> "pandas.DataFrame":
|
||||
return self._table.take(indices)
|
||||
table = self._table.take(indices)
|
||||
table.reset_index(drop=True, inplace=True)
|
||||
return table
|
||||
|
||||
def random_shuffle(self, random_seed: Optional[int]) -> "pandas.DataFrame":
|
||||
return self._table.sample(frac=1, random_state=random_seed)
|
||||
table = self._table.sample(frac=1, random_state=random_seed)
|
||||
table.reset_index(drop=True, inplace=True)
|
||||
return table
|
||||
|
||||
def schema(self) -> PandasBlockSchema:
|
||||
dtypes = self._table.dtypes
|
||||
|
@ -179,7 +174,13 @@ class PandasBlockAccessor(TableBlockAccessor):
|
|||
return schema
|
||||
|
||||
def to_pandas(self) -> "pandas.DataFrame":
|
||||
return self._table
|
||||
from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
|
||||
|
||||
ctx = DatasetContext.get_current()
|
||||
table = self._table
|
||||
if ctx.enable_tensor_extension_casting:
|
||||
table = _cast_tensor_columns_to_ndarrays(table)
|
||||
return table
|
||||
|
||||
def to_numpy(
|
||||
self, columns: Optional[Union[str, List[str]]] = None
|
||||
|
|
|
@ -68,6 +68,10 @@ DEFAULT_USE_POLARS = False
|
|||
# Whether to estimate in-memory decoding data size for data source.
|
||||
DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED = False
|
||||
|
||||
# Whether to automatically cast NumPy ndarray columns in Pandas DataFrames to tensor
|
||||
# extension columns.
|
||||
DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING = True
|
||||
|
||||
# Use this to prefix important warning messages for the user.
|
||||
WARN_PREFIX = "⚠️ "
|
||||
|
||||
|
@ -102,6 +106,7 @@ class DatasetContext:
|
|||
use_polars: bool,
|
||||
decoding_size_estimation: bool,
|
||||
min_parallelism: bool,
|
||||
enable_tensor_extension_casting: bool,
|
||||
):
|
||||
"""Private constructor (use get_current() instead)."""
|
||||
self.block_owner = block_owner
|
||||
|
@ -123,6 +128,7 @@ class DatasetContext:
|
|||
self.use_polars = use_polars
|
||||
self.decoding_size_estimation = decoding_size_estimation
|
||||
self.min_parallelism = min_parallelism
|
||||
self.enable_tensor_extension_casting = enable_tensor_extension_casting
|
||||
|
||||
@staticmethod
|
||||
def get_current() -> "DatasetContext":
|
||||
|
@ -157,6 +163,9 @@ class DatasetContext:
|
|||
use_polars=DEFAULT_USE_POLARS,
|
||||
decoding_size_estimation=DEFAULT_DECODING_SIZE_ESTIMATION_ENABLED,
|
||||
min_parallelism=DEFAULT_MIN_PARALLELISM,
|
||||
enable_tensor_extension_casting=(
|
||||
DEFAULT_ENABLE_TENSOR_EXTENSION_CASTING
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import logging
|
||||
import pathlib
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.data._internal.util import _check_import
|
||||
from ray.data.datasource.binary_datasource import BinaryDatasource
|
||||
from ray.data.datasource.datasource import Reader
|
||||
|
@ -18,7 +16,6 @@ if TYPE_CHECKING:
|
|||
import pyarrow
|
||||
from ray.data.block import T
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "tiff", "bmp", "gif"]
|
||||
|
||||
|
||||
|
@ -116,7 +113,6 @@ class ImageFolderDatasource(BinaryDatasource):
|
|||
):
|
||||
import imageio as iio
|
||||
import pandas as pd
|
||||
from ray.data.extensions import TensorArray
|
||||
import skimage
|
||||
|
||||
records = super()._read_file(f, path, include_paths=True)
|
||||
|
@ -127,23 +123,11 @@ class ImageFolderDatasource(BinaryDatasource):
|
|||
image = skimage.transform.resize(image, size)
|
||||
image = skimage.util.img_as_ubyte(image)
|
||||
|
||||
try:
|
||||
# Try to convert image `ndarray` to `TensorArray`s.
|
||||
image = TensorArray([np.array(image)])
|
||||
except TypeError as e:
|
||||
# Fall back to existing NumPy array.
|
||||
if ray.util.log_once("datasets_tensor_array_cast_warning"):
|
||||
logger.warning(
|
||||
"Tried to transparently convert image `ndarray` to a "
|
||||
"`TensorArray`, but the conversion failed. Left image ndarray "
|
||||
f" as-is: {e}"
|
||||
)
|
||||
|
||||
label = _get_class_from_path(path, root)
|
||||
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"image": image,
|
||||
"image": [np.array(image)],
|
||||
"label": [label],
|
||||
}
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@ from typing import List, Optional
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ray.data.extensions import TensorArray
|
||||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
|
@ -90,11 +89,7 @@ class Concatenator(Preprocessor):
|
|||
columns_to_concat = list(included_columns - set(self.exclude))
|
||||
concatenated = df[columns_to_concat].to_numpy(dtype=self.dtype)
|
||||
df = df.drop(columns=columns_to_concat)
|
||||
try:
|
||||
concatenated = TensorArray(concatenated)
|
||||
except TypeError:
|
||||
pass
|
||||
df[self.output_column_name] = concatenated
|
||||
df.loc[:, self.output_column_name] = list(concatenated)
|
||||
return df
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -253,3 +253,12 @@ def use_push_based_shuffle(request):
|
|||
ctx.use_push_based_shuffle = request.param
|
||||
yield request.param
|
||||
ctx.use_push_based_shuffle = original
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def enable_automatic_tensor_extension_cast(request):
|
||||
ctx = ray.data.context.DatasetContext.get_current()
|
||||
original = ctx.enable_tensor_extension_casting
|
||||
ctx.enable_tensor_extension_casting = request.param
|
||||
yield request.param
|
||||
ctx.enable_tensor_extension_casting = original
|
||||
|
|
|
@ -848,11 +848,26 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
|
|||
"schema={a: TensorDtype(shape=(4, 4), dtype=float64)})"
|
||||
)
|
||||
|
||||
# Test map_batches ragged ndarray column falls back to opaque object-typed column.
|
||||
ds = ray.data.range(16, parallelism=4).map_batches(
|
||||
lambda _: pd.DataFrame({"a": [np.ones((2, 2)), np.ones((3, 3))]}), batch_size=2
|
||||
)
|
||||
assert str(ds) == ("Dataset(num_blocks=4, num_rows=16, schema={a: object})")
|
||||
# Test map_batches ragged ndarray column fails by default.
|
||||
with pytest.raises(ValueError):
|
||||
ds = ray.data.range(16, parallelism=4).map_batches(
|
||||
lambda _: pd.DataFrame({"a": [np.ones((2, 2)), np.ones((3, 3))]}),
|
||||
batch_size=2,
|
||||
)
|
||||
|
||||
# Test map_batches ragged ndarray column uses opaque object-typed column if
|
||||
# automatic tensor extension type casting is disabled.
|
||||
ctx = DatasetContext.get_current()
|
||||
old_config = ctx.enable_tensor_extension_casting
|
||||
ctx.enable_tensor_extension_casting = False
|
||||
try:
|
||||
ds = ray.data.range(16, parallelism=4).map_batches(
|
||||
lambda _: pd.DataFrame({"a": [np.ones((2, 2)), np.ones((3, 3))]}),
|
||||
batch_size=2,
|
||||
)
|
||||
assert str(ds) == ("Dataset(num_blocks=4, num_rows=16, schema={a: object})")
|
||||
finally:
|
||||
ctx.enable_tensor_extension_casting = old_config
|
||||
|
||||
|
||||
def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
|
||||
|
@ -871,7 +886,10 @@ def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
|
|||
np.testing.assert_equal(v, e)
|
||||
|
||||
|
||||
def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared):
|
||||
def test_tensors_in_tables_pandas_roundtrip(
|
||||
ray_start_regular_shared,
|
||||
enable_automatic_tensor_extension_cast,
|
||||
):
|
||||
outer_dim = 3
|
||||
inner_shape = (2, 2, 2)
|
||||
shape = (outer_dim,) + inner_shape
|
||||
|
@ -880,7 +898,10 @@ def test_tensors_in_tables_pandas_roundtrip(ray_start_regular_shared):
|
|||
df = pd.DataFrame({"one": list(range(outer_dim)), "two": TensorArray(arr)})
|
||||
ds = ray.data.from_pandas([df])
|
||||
ds_df = ds.to_pandas()
|
||||
assert ds_df.equals(df)
|
||||
expected_df = df
|
||||
if enable_automatic_tensor_extension_cast:
|
||||
expected_df.loc[:, "two"] = list(expected_df["two"].to_numpy())
|
||||
pd.testing.assert_frame_equal(ds_df, expected_df)
|
||||
|
||||
|
||||
def test_tensors_in_tables_parquet_roundtrip(ray_start_regular_shared, tmp_path):
|
||||
|
@ -1117,6 +1138,40 @@ def test_tensors_in_tables_parquet_bytes_with_schema(
|
|||
np.testing.assert_equal(v, e)
|
||||
|
||||
|
||||
def test_tensors_in_tables_iter_batches(
|
||||
ray_start_regular_shared,
|
||||
enable_automatic_tensor_extension_cast,
|
||||
):
|
||||
outer_dim = 3
|
||||
inner_shape = (2, 2, 2)
|
||||
shape = (outer_dim,) + inner_shape
|
||||
num_items = np.prod(np.array(shape))
|
||||
arr = np.arange(num_items).reshape(shape)
|
||||
df1 = pd.DataFrame(
|
||||
{"one": TensorArray(arr), "two": TensorArray(arr + 1), "label": [1.0, 2.0, 3.0]}
|
||||
)
|
||||
arr2 = np.arange(num_items, 2 * num_items).reshape(shape)
|
||||
df2 = pd.DataFrame(
|
||||
{
|
||||
"one": TensorArray(arr2),
|
||||
"two": TensorArray(arr2 + 1),
|
||||
"label": [4.0, 5.0, 6.0],
|
||||
}
|
||||
)
|
||||
df = pd.concat([df1, df2], ignore_index=True)
|
||||
if enable_automatic_tensor_extension_cast:
|
||||
df.loc[:, "one"] = list(df["one"].to_numpy())
|
||||
df.loc[:, "two"] = list(df["two"].to_numpy())
|
||||
ds = ray.data.from_pandas([df1, df2])
|
||||
batches = list(ds.iter_batches(batch_size=2))
|
||||
assert len(batches) == 3
|
||||
expected_batches = [df.iloc[:2], df.iloc[2:4], df.iloc[4:]]
|
||||
for batch, expected_batch in zip(batches, expected_batches):
|
||||
batch = batch.reset_index(drop=True)
|
||||
expected_batch = expected_batch.reset_index(drop=True)
|
||||
pd.testing.assert_frame_equal(batch, expected_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipelined", [False, True])
|
||||
def test_tensors_in_tables_to_torch(ray_start_regular_shared, pipelined):
|
||||
outer_dim = 3
|
||||
|
|
|
@ -43,7 +43,6 @@ from ray.data.datasource.parquet_datasource import (
|
|||
_SerializedPiece,
|
||||
_deserialize_pieces_with_retry,
|
||||
)
|
||||
from ray.data.extensions import TensorDtype
|
||||
from ray.data.preprocessors import BatchMapper
|
||||
from ray.data.tests.conftest import * # noqa
|
||||
from ray.data.tests.mock_http_server import * # noqa
|
||||
|
@ -2823,7 +2822,9 @@ def test_torch_datasource_value_error(ray_start_regular_shared, local_path):
|
|||
)
|
||||
|
||||
|
||||
def test_image_folder_datasource(ray_start_regular_shared):
|
||||
def test_image_folder_datasource(
|
||||
ray_start_regular_shared, enable_automatic_tensor_extension_cast
|
||||
):
|
||||
root = os.path.join(os.path.dirname(__file__), "image-folder")
|
||||
ds = ray.data.read_datasource(ImageFolderDatasource(), root=root, size=(64, 64))
|
||||
|
||||
|
@ -2831,8 +2832,9 @@ def test_image_folder_datasource(ray_start_regular_shared):
|
|||
|
||||
df = ds.to_pandas()
|
||||
assert sorted(df["label"]) == ["cat", "cat", "dog"]
|
||||
assert type(df["image"].dtype) is TensorDtype
|
||||
assert all(tensor.to_numpy().shape == (64, 64, 3) for tensor in df["image"])
|
||||
assert df["image"].dtype.type is np.object_
|
||||
tensors = df["image"]
|
||||
assert all(tensor.shape == (64, 64, 3) for tensor in tensors)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size", [(-32, 32), (32, -32), (-32, -32)])
|
||||
|
@ -2843,7 +2845,6 @@ def test_image_folder_datasource_value_error(ray_start_regular_shared, size):
|
|||
|
||||
|
||||
def test_image_folder_datasource_e2e(ray_start_regular_shared):
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
from ray.train.torch import TorchCheckpoint, TorchPredictor
|
||||
from ray.train.batch_predictor import BatchPredictor
|
||||
|
||||
|
@ -2856,17 +2857,8 @@ def test_image_folder_datasource_e2e(ray_start_regular_shared):
|
|||
)
|
||||
|
||||
def preprocess(df):
|
||||
# We convert the `TensorArrayElement` to a NumPy array because `ToTensor`
|
||||
# expects a NumPy array or PIL image. `ToTensor` is necessary because Torch
|
||||
# expects images to have shape (C, H, W), and `ToTensor` changes the shape of
|
||||
# the data from (H, W, C) to (C, H, W).
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
lambda ray_tensor: ray_tensor.to_numpy(),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
df["image"] = TensorArray([preprocess(image) for image in df["image"]])
|
||||
preprocess = transforms.Compose([transforms.ToTensor()])
|
||||
df.loc[:, "image"] = [preprocess(image).numpy() for image in df["image"]]
|
||||
return df
|
||||
|
||||
preprocessor = BatchMapper(preprocess)
|
||||
|
|
|
@ -9,6 +9,7 @@ import pyarrow
|
|||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.data.context import DatasetContext
|
||||
from ray.data.preprocessor import Preprocessor, PreprocessorNotFittedException
|
||||
from ray.data.preprocessors import (
|
||||
BatchMapper,
|
||||
|
@ -1229,7 +1230,7 @@ def test_concatenator():
|
|||
prep = Concatenator(output_column_name="c")
|
||||
new_ds = prep.transform(ds)
|
||||
for i, row in enumerate(new_ds.take()):
|
||||
assert np.array_equal(row["c"].to_numpy(), np.array([i + 1, i + 1]))
|
||||
assert np.array_equal(row["c"], np.array([i + 1, i + 1]))
|
||||
|
||||
df = pd.DataFrame({"a": [1, 2, 3, 4]})
|
||||
ds = ray.data.from_pandas(df)
|
||||
|
@ -1258,12 +1259,23 @@ def test_concatenator():
|
|||
for i, row in enumerate(new_ds.take()):
|
||||
assert set(row) == {"concat_out", "b", "c"}
|
||||
|
||||
# check it works with string types
|
||||
# check it fails with string types by default
|
||||
df = pd.DataFrame({"a": ["string", "string2", "string3"]})
|
||||
ds = ray.data.from_pandas(df)
|
||||
prep = Concatenator(output_column_name="huh")
|
||||
new_ds = prep.transform(ds)
|
||||
assert "huh" in set(new_ds.schema().names)
|
||||
with pytest.raises(ValueError):
|
||||
new_ds = prep.transform(ds)
|
||||
|
||||
# check it works with string types if automatic tensor extension casting is
|
||||
# disabled
|
||||
ctx = DatasetContext.get_current()
|
||||
old_config = ctx.enable_tensor_extension_casting
|
||||
ctx.enable_tensor_extension_casting = False
|
||||
try:
|
||||
new_ds = prep.transform(ds)
|
||||
assert "huh" in set(new_ds.schema().names)
|
||||
finally:
|
||||
ctx.enable_tensor_extension_casting = old_config
|
||||
|
||||
|
||||
def test_tokenizer():
|
||||
|
|
|
@ -64,7 +64,11 @@ class DLPredictor(Predictor):
|
|||
def _predict_pandas(
|
||||
self, data: pd.DataFrame, dtype: Union[TensorDtype, Dict[str, TensorDtype]]
|
||||
) -> pd.DataFrame:
|
||||
tensors = convert_pandas_to_batch_type(data, DataType.NUMPY)
|
||||
tensors = convert_pandas_to_batch_type(
|
||||
data,
|
||||
DataType.NUMPY,
|
||||
self._cast_tensor_columns,
|
||||
)
|
||||
model_input = self._arrays_to_tensors(tensors, dtype)
|
||||
|
||||
output = self.call_model(model_input)
|
||||
|
|
|
@ -6,6 +6,7 @@ import ray
|
|||
from ray.air import Checkpoint
|
||||
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas
|
||||
from ray.data import Preprocessor
|
||||
from ray.data.context import DatasetContext
|
||||
from ray.data.preprocessors import BatchMapper
|
||||
from ray.train.predictor import Predictor
|
||||
from ray.util.annotations import PublicAPI
|
||||
|
@ -169,12 +170,18 @@ class BatchPredictor:
|
|||
):
|
||||
predictor_kwargs["use_gpu"] = True
|
||||
|
||||
ctx = DatasetContext.get_current()
|
||||
cast_tensor_columns = ctx.enable_tensor_extension_casting
|
||||
|
||||
class ScoringWrapper:
|
||||
def __init__(self):
|
||||
checkpoint = Checkpoint.from_object_ref(checkpoint_ref)
|
||||
self._predictor = predictor_cls.from_checkpoint(
|
||||
checkpoint, **predictor_kwargs
|
||||
)
|
||||
if cast_tensor_columns:
|
||||
# Enable automatic tensor column casting at UDF boundaries.
|
||||
self._predictor._set_cast_tensor_columns()
|
||||
if override_prep:
|
||||
self._predictor.set_preprocessor(override_prep)
|
||||
|
||||
|
@ -188,7 +195,9 @@ class BatchPredictor:
|
|||
)
|
||||
if keep_columns:
|
||||
prediction_output[keep_columns] = batch[keep_columns]
|
||||
return convert_batch_type_to_pandas(prediction_output)
|
||||
return convert_batch_type_to_pandas(
|
||||
prediction_output, cast_tensor_columns
|
||||
)
|
||||
|
||||
compute = ray.data.ActorPoolStrategy(
|
||||
min_size=min_scoring_workers, max_size=max_scoring_workers
|
||||
|
|
|
@ -6,6 +6,7 @@ import pandas as pd
|
|||
|
||||
from ray.air.checkpoint import Checkpoint
|
||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
from ray.train.lightgbm.lightgbm_checkpoint import LightGBMCheckpoint
|
||||
from ray.train.predictor import Predictor
|
||||
from ray.util.annotations import PublicAPI
|
||||
|
@ -118,6 +119,7 @@ class LightGBMPredictor(Predictor):
|
|||
feature_names = None
|
||||
if TENSOR_COLUMN_NAME in data:
|
||||
data = data[TENSOR_COLUMN_NAME].to_numpy()
|
||||
data = _unwrap_ndarray_object_type_if_needed(data)
|
||||
if feature_columns:
|
||||
# In this case feature_columns is a list of integers
|
||||
data = data[:, feature_columns]
|
||||
|
|
|
@ -76,6 +76,9 @@ class Predictor(abc.ABC):
|
|||
def __init__(self, preprocessor: Optional[Preprocessor] = None):
|
||||
"""Subclasseses must call Predictor.__init__() to set a preprocessor."""
|
||||
self._preprocessor: Optional[Preprocessor] = preprocessor
|
||||
# Whether tensor columns should be automatically cast from/to the tensor
|
||||
# extension type at UDF boundaries. This can be overridden by subclasses.
|
||||
self._cast_tensor_columns = False
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
|
@ -120,6 +123,15 @@ class Predictor(abc.ABC):
|
|||
"""Set the preprocessor to use prior to executing predictions."""
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
def _set_cast_tensor_columns(self):
|
||||
"""Enable automatic tensor column casting.
|
||||
|
||||
If this is called on a predictor, the predictor will cast tensor columns to
|
||||
NumPy ndarrays in the input to the preprocessors and cast tensor columns back to
|
||||
the tensor extension type in the prediction outputs.
|
||||
"""
|
||||
self._cast_tensor_columns = True
|
||||
|
||||
def predict(self, data: DataBatchType, **kwargs) -> DataBatchType:
|
||||
"""Perform inference on a batch of data.
|
||||
|
||||
|
@ -132,7 +144,7 @@ class Predictor(abc.ABC):
|
|||
DataBatchType: Prediction result. The return type will be the same as the
|
||||
input type.
|
||||
"""
|
||||
data_df = convert_batch_type_to_pandas(data)
|
||||
data_df = convert_batch_type_to_pandas(data, self._cast_tensor_columns)
|
||||
|
||||
if not hasattr(self, "_preprocessor"):
|
||||
raise NotImplementedError(
|
||||
|
@ -144,7 +156,9 @@ class Predictor(abc.ABC):
|
|||
|
||||
predictions_df = self._predict_pandas(data_df, **kwargs)
|
||||
return convert_pandas_to_batch_type(
|
||||
predictions_df, type=TYPE_TO_ENUM[type(data)]
|
||||
predictions_df,
|
||||
type=TYPE_TO_ENUM[type(data)],
|
||||
cast_tensor_columns=self._cast_tensor_columns,
|
||||
)
|
||||
|
||||
@DeveloperAPI
|
||||
|
|
|
@ -5,6 +5,7 @@ import pandas as pd
|
|||
|
||||
from ray.air.checkpoint import Checkpoint
|
||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import EnvType
|
||||
from ray.train.predictor import Predictor
|
||||
|
@ -66,6 +67,7 @@ class RLPredictor(Predictor):
|
|||
def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame":
|
||||
if TENSOR_COLUMN_NAME in data:
|
||||
obs = data[TENSOR_COLUMN_NAME].to_numpy()
|
||||
obs = _unwrap_ndarray_object_type_if_needed(obs)
|
||||
else:
|
||||
obs = data.to_numpy()
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from sklearn.base import BaseEstimator
|
|||
|
||||
from ray.air.checkpoint import Checkpoint
|
||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
from ray.train.predictor import Predictor
|
||||
from ray.train.sklearn._sklearn_utils import _set_cpu_params
|
||||
from ray.train.sklearn.sklearn_checkpoint import SklearnCheckpoint
|
||||
|
@ -130,6 +131,7 @@ class SklearnPredictor(Predictor):
|
|||
|
||||
if TENSOR_COLUMN_NAME in data:
|
||||
data = data[TENSOR_COLUMN_NAME].to_numpy()
|
||||
data = _unwrap_ndarray_object_type_if_needed(data)
|
||||
if feature_columns:
|
||||
data = data[:, feature_columns]
|
||||
elif feature_columns:
|
||||
|
|
|
@ -70,10 +70,10 @@ def test_predict(convert_to_pandas_mock, convert_from_pandas_mock):
|
|||
predictor = DummyPredictor.from_checkpoint(checkpoint)
|
||||
|
||||
actual_output = predictor.predict(input)
|
||||
assert actual_output.equals(expected_output)
|
||||
pd.testing.assert_frame_equal(actual_output, expected_output)
|
||||
|
||||
# Ensure the proper conversion functions are called.
|
||||
convert_to_pandas_mock.assert_called_once_with(input)
|
||||
convert_to_pandas_mock.assert_called_once_with(input, False)
|
||||
convert_from_pandas_mock.assert_called_once()
|
||||
|
||||
pd.testing.assert_frame_equal(
|
||||
|
|
|
@ -6,6 +6,7 @@ import xgboost
|
|||
|
||||
from ray.air.checkpoint import Checkpoint
|
||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||
from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
|
||||
from ray.train.predictor import Predictor
|
||||
from ray.util.annotations import PublicAPI
|
||||
|
||||
|
@ -123,6 +124,7 @@ class XGBoostPredictor(Predictor):
|
|||
feature_names = None
|
||||
if TENSOR_COLUMN_NAME in data:
|
||||
data = data[TENSOR_COLUMN_NAME].to_numpy()
|
||||
data = _unwrap_ndarray_object_type_if_needed(data)
|
||||
if feature_columns:
|
||||
# In this case feature_columns is a list of integers
|
||||
data = data[:, feature_columns]
|
||||
|
|
Loading…
Add table
Reference in a new issue