Revert "[Datasets] Automatically cast tensor columns when building Pandas blocks. (#26684)" (#26921)

This reverts commit 0c139914bb.
Author: Chen Shen
Date: 2022-07-22 22:26:40 -07:00 (committed by GitHub)
parent 170bde40a0
commit 042450d319
10 changed files with 89 additions and 243 deletions


@@ -29,7 +29,7 @@ def test_numpy_pandas():
actual_output = convert_batch_type_to_pandas(input_data)
assert expected_output.equals(actual_output)
np.testing.assert_array_equal(
assert np.array_equal(
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
)
@@ -40,18 +40,18 @@ def test_numpy_multi_dim_pandas():
actual_output = convert_batch_type_to_pandas(input_data)
assert expected_output.equals(actual_output)
np.testing.assert_array_equal(
assert np.array_equal(
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
)
def test_numpy_object_pandas():
input_data = np.array([[1, 2, 3], [1]], dtype=object)
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: input_data})
expected_output = pd.DataFrame({TENSOR_COLUMN_NAME: TensorArray(input_data)})
actual_output = convert_batch_type_to_pandas(input_data)
assert expected_output.equals(actual_output)
np.testing.assert_array_equal(
assert np.array_equal(
convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY), input_data
)
@@ -69,7 +69,7 @@ def test_dict_pandas():
assert expected_output.equals(actual_output)
output_array = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
np.testing.assert_array_equal(output_array, input_data["x"])
assert np.array_equal(output_array, input_data["x"])
def test_dict_multi_dim_to_pandas():
@@ -80,7 +80,7 @@ def test_dict_multi_dim_to_pandas():
assert expected_output.equals(actual_output)
output_array = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
np.testing.assert_array_equal(output_array, input_data["x"])
assert np.array_equal(output_array, input_data["x"])
def test_dict_pandas_multi_column():
@@ -91,7 +91,7 @@ def test_dict_pandas_multi_column():
output_dict = convert_pandas_to_batch_type(actual_output, type=DataType.NUMPY)
for k, v in output_dict.items():
np.testing.assert_array_equal(v, array_dict[k])
assert np.array_equal(v, array_dict[k])
def test_arrow_pandas():
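The hunks above swap np.testing.assert_array_equal back to bare assert np.array_equal(...). The two idioms are not equivalent: the former raises an AssertionError with an element-wise mismatch report and treats NaNs in matching positions as equal, while the latter just returns a bool. A standalone illustration (not part of the diff):

import numpy as np

a = np.array([1.0, np.nan])
b = np.array([1.0, np.nan])

# Passes: assert_array_equal treats NaNs in the same positions as equal and
# prints a detailed element-wise diff on failure.
np.testing.assert_array_equal(a, b)

# np.array_equal returns a plain bool; NaN != NaN unless equal_nan=True
# (available since NumPy 1.19).
assert not np.array_equal(a, b)
assert np.array_equal(a, b, equal_nan=True)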


@@ -1,10 +1,8 @@
from enum import Enum, auto
import logging
import numpy as np
import pandas as pd
import ray
from ray.air.data_batch_type import DataBatchType
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.util.annotations import DeveloperAPI
@@ -14,8 +12,6 @@ try:
except ImportError:
pyarrow = None
logger = logging.getLogger(__name__)
@DeveloperAPI
class DataType(Enum):
@@ -35,24 +31,13 @@ def convert_batch_type_to_pandas(data: DataBatchType) -> pd.DataFrame:
A pandas Dataframe representation of the input data.
"""
global _tensor_cast_failed_warned
from ray.air.util.tensor_extensions.pandas import TensorArray
if isinstance(data, pd.DataFrame):
return data
elif isinstance(data, np.ndarray):
try:
# Try to convert numpy arrays to TensorArrays.
data = TensorArray(data)
except TypeError as e:
# Fall back to existing NumPy array.
if ray.util.log_once("datasets_tensor_array_cast_warning"):
logger.warning(
"Tried to transparently convert ndarray batch to a TensorArray "
f"but the conversion failed, leaving ndarray batch as-is: {e}"
)
return pd.DataFrame({TENSOR_COLUMN_NAME: data})
return pd.DataFrame({TENSOR_COLUMN_NAME: TensorArray(data)})
elif isinstance(data, dict):
tensor_dict = {}
@@ -63,18 +48,8 @@ def convert_batch_type_to_pandas(data: DataBatchType) -> pd.DataFrame:
f"np.ndarray. Found type {type(v)} for key {k} "
f"instead."
)
try:
# Try to convert numpy arrays to TensorArrays.
v = TensorArray(v)
except TypeError as e:
# Fall back to existing NumPy array.
if ray.util.log_once("datasets_tensor_array_cast_warning"):
logger.warning(
f"Tried to transparently convert column ndarray {k} of batch "
"to a TensorArray but the conversion failed, leaving column "
f"as-is: {e}"
)
tensor_dict[k] = v
# Convert numpy arrays to TensorArray.
tensor_dict[k] = TensorArray(v)
return pd.DataFrame(tensor_dict)
elif pyarrow is not None and isinstance(data, pyarrow.Table):
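After the revert, convert_batch_type_to_pandas wraps an ndarray batch, and every ndarray value of a dict batch, in TensorArray unconditionally; the try/except fallback to a plain object column is gone. A minimal sketch, assuming Ray at this commit (the import path for the function is an assumption based on the module shown above):

import numpy as np
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas

# A bare ndarray batch becomes one tensor column named TENSOR_COLUMN_NAME.
df = convert_batch_type_to_pandas(np.arange(12).reshape(4, 3))
print(df.dtypes[TENSOR_COLUMN_NAME])  # TensorDtype (parameterless again)

# A dict batch gets one TensorArray-backed column per key, even for 1-D values.
df = convert_batch_type_to_pandas({"x": np.zeros((4, 2, 2)), "y": np.arange(4)})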


@@ -45,7 +45,7 @@ class ArrowTensorType(pa.PyExtensionType):
"""
from ray.air.util.tensor_extensions.pandas import TensorDtype
return TensorDtype(self._shape, self.storage_type.value_type.to_pandas_dtype())
return TensorDtype()
def __reduce__(self):
return ArrowTensorType, (self._shape, self.storage_type.value_type)
@@ -60,14 +60,11 @@ class ArrowTensorType(pa.PyExtensionType):
"""
return ArrowTensorArray
def __str__(self) -> str:
return (
f"ArrowTensorType(shape={self.shape}, dtype={self.storage_type.value_type})"
def __str__(self):
return "<ArrowTensorType: shape={}, dtype={}>".format(
self.shape, self.storage_type.value_type
)
def __repr__(self) -> str:
return str(self)
@PublicAPI(stability="beta")
class ArrowTensorArray(pa.ExtensionArray):
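The restored ArrowTensorType stringifies with the older angle-bracket format and converts to the parameterless TensorDtype. Roughly, assuming the arrow tensor-extension module path mirrors the pandas one used elsewhere in this diff:

import pyarrow as pa
from ray.air.util.tensor_extensions.arrow import ArrowTensorType

# Constructor arguments inferred from __reduce__ above: (shape, value dtype).
t = ArrowTensorType((3, 5), pa.int64())
print(str(t))               # <ArrowTensorType: shape=(3, 5), dtype=int64>
print(t.to_pandas_dtype())  # TensorDtype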


@@ -204,12 +204,12 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
dtype: object
>>> # Cast column to our TensorDtype extension type.
>>> from ray.data.extensions import TensorDtype
>>> df["two"] = df["two"].astype(TensorDtype((3, 2, 2, 2), np.int64))
>>> df["two"] = df["two"].astype(TensorDtype())
>>> # Note that the column dtype is now TensorDtype instead of
>>> # np.object.
>>> df.dtypes # doctest: +SKIP
one int64
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
two TensorDtype
dtype: object
>>> # Pandas is now aware of this tensor column, and we can do the
>>> # typical DataFrame operations on this column.
@@ -231,7 +231,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
[38 40]]
[[42 44]
[46 48]]]
Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
Name: two, dtype: TensorDtype
>>> # Once you do an aggregation on that column that returns a single
>>> # row's value, you get back our TensorArrayElement type.
>>> tensor = col.mean()
@@ -264,7 +264,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
>>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
>>> read_df.dtypes # doctest: +SKIP
one int64
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
two TensorDtype
dtype: object
>>> # The tensor extension type is preserved along the
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
@@ -278,10 +278,6 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
# https://github.com/CODAIT/text-extensions-for-pandas/issues/166
base = None
def __init__(self, shape: Tuple[int, ...], dtype: np.dtype):
self._shape = shape
self._dtype = dtype
@property
def type(self):
"""
@@ -299,7 +295,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
A string identifying the data type.
Will be used for display in, e.g. ``Series.dtype``
"""
return f"{type(self).__name__}(shape={self._shape}, dtype={self._dtype})"
return "TensorDtype"
@classmethod
def construct_from_string(cls, string: str):
@@ -346,26 +342,16 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
... f"Cannot construct a '{cls.__name__}' from '{string}'"
... )
"""
import ast
import re
if not isinstance(string, str):
raise TypeError(
f"'construct_from_string' expects a string, got {type(string)}"
)
# Upstream code uses exceptions as part of its normal control flow and
# will pass this method bogus class names.
regex = r"^TensorDtype\(shape=(\(\d+,(?:\s\d+,?)*\)), dtype=(\w+)\)$"
m = re.search(regex, string)
if m is None:
raise TypeError(
f"Cannot construct a '{cls.__name__}' from '{string}'; expected a "
"string like 'TensorDtype(shape=(1, 2, 3), dtype=int64)'."
)
shape, dtype = m.groups()
shape = ast.literal_eval(shape)
dtype = np.dtype(dtype)
return cls(shape, dtype)
if string == cls.__name__:
return cls()
else:
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
@classmethod
def construct_array_type(cls):
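The revert reduces construct_from_string to exact name matching: only the bare class name parses, and the parameterized form introduced by the reverted commit is rejected. A quick check against this post-revert build:

from ray.air.util.tensor_extensions.pandas import TensorDtype

assert isinstance(TensorDtype.construct_from_string("TensorDtype"), TensorDtype)
try:
    # Parameterized strings were only accepted by the reverted commit.
    TensorDtype.construct_from_string("TensorDtype(shape=(1, 2), dtype=int64)")
except TypeError:
    pass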
@@ -402,33 +388,6 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
return TensorArray(values)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return str(self)
@property
def _is_boolean(self):
"""
Whether this extension array should be considered boolean.
By default, ExtensionArrays are assumed to be non-numeric.
Setting this to True will affect the behavior of several places,
e.g.
* is_bool
* boolean indexing
Returns
-------
bool
"""
# This is needed to support returning a TensorArray from .isnan().
from pandas.core.dtypes.common import is_bool_dtype
return is_bool_dtype(self._dtype)
class TensorOpsMixin(pd.api.extensions.ExtensionScalarOpsMixin):
"""
@@ -593,7 +552,7 @@ class TensorArray(
>>> # Note that the column dtype is TensorDtype.
>>> df.dtypes # doctest: +SKIP
one int64
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
two TensorDtype
dtype: object
>>> # Pandas is aware of this tensor column, and we can do the
>>> # typical DataFrame operations on this column.
@@ -615,7 +574,7 @@ class TensorArray(
[38 40]]
[[42 44]
[46 48]]]
Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
Name: two, dtype: TensorDtype
>>> # Once you do an aggregation on that column that returns a single
>>> # row's value, you get back our TensorArrayElement type.
>>> tensor = col.mean() # doctest: +SKIP
@@ -649,7 +608,7 @@ class TensorArray(
>>> read_df = ray.get(read_ds.to_pandas_refs())[0] # doctest: +SKIP
>>> read_df.dtypes # doctest: +SKIP
one int64
two TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
two TensorDtype
dtype: object
>>> # The tensor extension type is preserved along the
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
@@ -692,40 +651,24 @@ class TensorArray(
# Convert series to ndarray and passthrough to ndarray handling
# logic.
values = values.to_numpy()
elif isinstance(values, Sequence):
values = np.array([np.asarray(v) for v in values])
if isinstance(values, np.ndarray):
if values.dtype.type is np.object_:
if len(values) == 0 or (
not isinstance(values[0], str)
and isinstance(
values[0], (np.ndarray, TensorArrayElement, Sequence)
)
if (
values.dtype.type is np.object_
and len(values) > 0
and isinstance(values[0], (np.ndarray, TensorArrayElement))
):
# Convert ndarrays of ndarrays/TensorArrayElements
# with an opaque object type to a properly typed ndarray of
# ndarrays.
self._tensor = np.array([np.asarray(v) for v in values])
if self._tensor.dtype.type is np.object_:
subndarray_types = [v.dtype for v in self._tensor]
raise TypeError(
"Tried to convert an ndarray of ndarray pointers (object "
"dtype) to a well-typed ndarray but this failed; convert "
"the ndarray to a well-typed ndarray before casting it as "
"a TensorArray, and note that ragged tensors are NOT "
"supported by TensorArray. subndarray types: "
f"{subndarray_types}"
)
else:
raise TypeError(
"Expected a well-typed ndarray or an object-typed ndarray of "
"ndarray pointers, but got an object-typed ndarray whose "
f"subndarrays are of type {type(values[0])}."
)
else:
# ndarray is well-typed, use it directly as the backing tensor.
self._tensor = values
elif isinstance(values, Sequence):
if len(values) == 0:
self._tensor = np.array([])
else:
self._tensor = np.stack([np.asarray(v) for v in values], axis=0)
elif isinstance(values, TensorArrayElement):
self._tensor = np.array([np.asarray(values)])
elif np.isscalar(values):
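The restored constructor uses a well-typed ndarray as the backing tensor directly and stacks a sequence of equally-shaped arrays along a new axis 0; the descriptive TypeError for ragged object-dtype input goes away with the rest of the reverted validation. A small sketch, assuming this post-revert build:

import numpy as np
from ray.air.util.tensor_extensions.pandas import TensorArray

arr = TensorArray(np.zeros((4, 2, 2)))                   # backing tensor used as-is
seq = TensorArray([np.ones((2, 2)) for _ in range(4)])   # stacked via np.stack
assert arr.numpy_shape == seq.numpy_shape == (4, 2, 2)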
@@ -857,7 +800,7 @@ class TensorArray(
"""
An instance of 'ExtensionDtype'.
"""
return TensorDtype(self.numpy_shape[1:], self.numpy_dtype)
return TensorDtype()
@property
def nbytes(self) -> int:
@@ -1241,6 +1184,27 @@ class TensorArray(
"""
return self._tensor.size
@property
def _is_boolean(self):
"""
Whether this extension array should be considered boolean.
By default, ExtensionArrays are assumed to be non-numeric.
Setting this to True will affect the behavior of several places,
e.g.
* is_bool
* boolean indexing
Returns
-------
bool
"""
# This is needed to support returning a TensorArray from .isnan().
# TODO(Clark): Propagate tensor dtype to extension TensorDtype and
# move this property there.
return np.issubdtype(self._tensor.dtype, np.bool)
def astype(self, dtype, copy=True):
"""
Cast to a NumPy array with 'dtype'.
@@ -1325,25 +1289,6 @@ class TensorArray(
return ArrowTensorArray.from_numpy(self._tensor)
@property
def _is_boolean(self):
"""
Whether this extension array should be considered boolean.
By default, ExtensionArrays are assumed to be non-numeric.
Setting this to True will affect the behavior of several places,
e.g.
* is_bool
* boolean indexing
Returns
-------
bool
"""
# This is needed to support returning a TensorArray from .isnan().
return self.dtype._is_boolean()
# Add operators from the mixin to the TensorArrayElement and TensorArray
# classes.


@@ -1,4 +1,3 @@
import logging
from typing import (
Callable,
Dict,
@@ -16,7 +15,6 @@ import collections
import heapq
import numpy as np
import ray
from ray.data.block import (
Block,
BlockAccessor,
@@ -42,7 +40,6 @@ if TYPE_CHECKING:
T = TypeVar("T")
_pandas = None
logger = logging.getLogger(__name__)
def lazy_import_pandas():
@@ -98,29 +95,7 @@ class PandasBlockBuilder(TableBlockBuilder[T]):
def _concat_tables(self, tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
pandas = lazy_import_pandas()
from ray.data.extensions.tensor_extension import TensorArray
df = pandas.concat(tables, ignore_index=True)
# Try to convert any ndarray columns to TensorArray columns.
# TODO(Clark): Once Pandas supports registering extension types for type
# inference on construction, implement as much for NumPy ndarrays and remove
# this. See https://github.com/pandas-dev/pandas/issues/41848
for col_name, col in df.items():
if (
col.dtype.type is np.object_
and not col.empty
and isinstance(col[0], np.ndarray)
):
try:
df[col_name] = TensorArray(col)
except Exception as e:
if ray.util.log_once("datasets_tensor_array_cast_warning"):
logger.warning(
f"Tried to transparently convert column {col_name} to a "
"TensorArray but the conversion failed, leaving column "
f"as-is: {e}"
)
return df
return pandas.concat(tables, ignore_index=True)
@staticmethod
def _empty_table() -> "pandas.DataFrame":
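With the auto-cast loop removed, _concat_tables is plain pandas.concat, so a column only carries the tensor extension dtype if the blocks were built with it. A hedged illustration of what is and is not preserved:

import numpy as np
import pandas as pd
from ray.air.util.tensor_extensions.pandas import TensorArray

df1 = pd.DataFrame({"a": TensorArray(np.ones((2, 4, 4)))})
df2 = pd.DataFrame({"a": TensorArray(np.zeros((2, 4, 4)))})

# The extension dtype survives concat; plain object-dtype ndarray columns are
# no longer upgraded for you.
merged = pd.concat([df1, df2], ignore_index=True)
assert str(merged.dtypes["a"]) == "TensorDtype"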


@@ -109,6 +109,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Whether we have warned of Datasets containing multiple epochs of data.
_epoch_warned = False
# Whether we have warned about using slow Dataset transforms.
_slow_warned = False
TensorflowFeatureTypeSpec = Union[
"tf.TypeSpec", List["tf.TypeSpec"], Dict[str, "tf.TypeSpec"]
]
@@ -1378,13 +1384,15 @@ class Dataset(Generic[T]):
epochs = [ds._get_epoch() for ds in datasets]
max_epoch = max(*epochs)
if len(set(epochs)) > 1:
if ray.util.log_once("datasets_epoch_warned"):
global _epoch_warned
if not _epoch_warned:
logger.warning(
"Dataset contains data from multiple epochs: {}, "
"likely due to a `rewindow()` call. The higher epoch "
"number {} will be used. This warning will not "
"be shown again.".format(set(epochs), max_epoch)
)
_epoch_warned = True
dataset_stats = DatasetStats(
stages={"union": []},
parent=[d._plan.stats() for d in datasets],
@@ -3723,7 +3731,7 @@ class Dataset(Generic[T]):
for n, t in zip(schema.names, schema.types):
if hasattr(t, "__name__"):
t = t.__name__
schema_str.append(f"{n}: {t}")
schema_str.append("{}: {}".format(n, t))
schema_str = ", ".join(schema_str)
schema_str = "{" + schema_str + "}"
count = self._meta_count()
@@ -3771,7 +3779,9 @@ class Dataset(Generic[T]):
self._epoch = epoch
def _warn_slow(self):
if ray.util.log_once("datasets_slow_warned"):
global _slow_warned
if not _slow_warned:
_slow_warned = True
logger.warning(
"The `map`, `flat_map`, and `filter` operations are unvectorized and "
"can be very slow. Consider using `.map_batches()` instead."


@@ -1,10 +1,7 @@
import logging
import pathlib
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import ray
from ray.data.datasource.binary_datasource import BinaryDatasource
from ray.data.datasource.datasource import Reader
from ray.data.datasource.file_based_datasource import (
@@ -18,7 +15,6 @@ if TYPE_CHECKING:
import pyarrow
from ray.data.block import T
logger = logging.getLogger(__name__)
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "tiff", "bmp", "gif"]
@@ -129,20 +125,10 @@ class ImageFolderDatasource(BinaryDatasource):
image = iio.imread(data)
label = _get_class_from_path(path, self.root)
try:
# Try to convert image ndarray to TensorArrays.
image = TensorArray([np.array(image)])
except TypeError as e:
# Fall back to existing NumPy array.
if ray.util.log_once("datasets_tensor_array_cast_warning"):
logger.warning(
"Tried to transparently convert image ndarray to a TensorArray "
f"but the conversion failed, leaving image ndarray as-is: {e}"
)
return pd.DataFrame(
{
"image": image,
"image": TensorArray([np.array(image)]),
"label": [label],
}
)


@@ -90,11 +90,7 @@ class Concatenator(Preprocessor):
columns_to_concat = list(included_columns - set(self.excluded_columns))
concatenated = df[columns_to_concat].to_numpy(dtype=self.dtype)
df = df.drop(columns=columns_to_concat)
try:
concatenated = TensorArray(concatenated)
except TypeError:
pass
df[self.output_column_name] = concatenated
df[self.output_column_name] = TensorArray(concatenated)
return df
def __repr__(self):
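Post-revert, Concatenator always wraps the concatenated matrix in TensorArray and will raise instead of silently falling back if the cast fails. A usage sketch; the import path and constructor keyword are assumptions based on the attribute names visible in the hunk:

import pandas as pd
from ray.data.preprocessors import Concatenator  # import path assumed

df = pd.DataFrame({"x": [1.0, 2.0], "y": [3.0, 4.0]})
prep = Concatenator(output_column_name="features")  # kwarg assumed
out = prep._transform_pandas(df)  # direct private call, purely for illustration
print(out.dtypes["features"])     # TensorDtype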


@@ -458,29 +458,6 @@ def test_range_table(ray_start_regular_shared):
assert ds.take() == [{"value": i} for i in range(10)]
def test_tensor_array_validation():
# Test unknown input type raises TypeError.
with pytest.raises(TypeError):
TensorArray(object())
# Test ragged tensor raises TypeError.
with pytest.raises(TypeError):
TensorArray(np.array([np.ones((2, 2)), np.ones((3, 3))], dtype=object))
with pytest.raises(TypeError):
TensorArray([np.ones((2, 2)), np.ones((3, 3))])
with pytest.raises(TypeError):
TensorArray(pd.Series([np.ones((2, 2)), np.ones((3, 3))]))
# Test non-primitive element raises TypeError.
with pytest.raises(TypeError):
TensorArray(np.array([object(), object()]))
with pytest.raises(TypeError):
TensorArray([object(), object()])
def test_tensor_array_block_slice():
# Test that ArrowBlock slicing works with tensor column extension type.
def check_for_copy(table1, table2, a, b, is_copy):
@@ -627,7 +604,7 @@ def test_tensors_basic(ray_start_regular_shared):
ds = ray.data.range_tensor(6, shape=tensor_shape, parallelism=6)
assert str(ds) == (
"Dataset(num_blocks=6, num_rows=6, "
"schema={__value__: ArrowTensorType(shape=(3, 5), dtype=int64)})"
"schema={__value__: <ArrowTensorType: shape=(3, 5), dtype=int64>})"
)
assert ds.size_bytes() == 5 * 3 * 6 * 8
@@ -818,7 +795,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
ds = ray.data.range(10, parallelism=10).map(lambda _: np.ones((4, 4)))
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=10, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
)
# Test map_batches.
@@ -827,7 +804,7 @@
)
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=24, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
)
# Test flat_map.
@@ -836,24 +813,9 @@
)
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=20, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
"schema={__value__: <ArrowTensorType: shape=(4, 4), dtype=double>})"
)
# Test map_batches ndarray column.
ds = ray.data.range(16, parallelism=4).map_batches(
lambda _: pd.DataFrame({"a": [np.ones((4, 4))] * 3}), batch_size=2
)
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=24, "
"schema={a: TensorDtype(shape=(4, 4), dtype=float64)})"
)
# Test map_batches ragged ndarray column falls back to opaque object-typed column.
ds = ray.data.range(16, parallelism=4).map_batches(
lambda _: pd.DataFrame({"a": [np.ones((2, 2)), np.ones((3, 3))]}), batch_size=2
)
assert str(ds) == ("Dataset(num_blocks=4, num_rows=16, schema={a: object})")
def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
outer_dim = 3
@@ -863,7 +825,7 @@ def test_tensors_in_tables_from_pandas(ray_start_regular_shared):
arr = np.arange(num_items).reshape(shape)
df = pd.DataFrame({"one": list(range(outer_dim)), "two": list(arr)})
# Cast column to tensor extension dtype.
df["two"] = df["two"].astype(TensorDtype(shape, np.int64))
df["two"] = df["two"].astype(TensorDtype())
ds = ray.data.from_pandas([df])
values = [[s["one"], s["two"]] for s in ds.take()]
expected = list(zip(list(range(outer_dim)), arr))
@@ -942,7 +904,7 @@ def test_tensors_in_tables_parquet_pickle_manual_serde(
# extension type.
def deser_mapper(batch: pd.DataFrame):
batch["two"] = [pickle.loads(a) for a in batch["two"]]
batch["two"] = batch["two"].astype(TensorDtype(shape, np.int64))
batch["two"] = batch["two"].astype(TensorDtype())
return batch
casted_ds = ds.map_batches(deser_mapper, batch_format="pandas")


@@ -1170,7 +1170,7 @@ def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path):
ds = ray.data.read_numpy(data_path, filesystem=fs)
assert str(ds) == (
"Dataset(num_blocks=2, num_rows=None, "
"schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)})"
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})"
)
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])])
@@ -1182,7 +1182,7 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
ds = ray.data.read_numpy(path)
assert str(ds) == (
"Dataset(num_blocks=1, num_rows=10, "
"schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)})"
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})"
)
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])])
@@ -1195,7 +1195,7 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
assert ds.count() == 10
assert str(ds) == (
"Dataset(num_blocks=1, num_rows=10, "
"schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)})"
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})"
)
assert [v.item() for v in ds.take(2)] == [0, 1]
@@ -1208,7 +1208,7 @@ def test_numpy_read_meta_provider(ray_start_regular_shared, tmp_path):
ds = ray.data.read_numpy(path, meta_provider=FastFileMetadataProvider())
assert str(ds) == (
"Dataset(num_blocks=1, num_rows=10, "
"schema={__value__: ArrowTensorType(shape=(1,), dtype=int64)})"
"schema={__value__: <ArrowTensorType: shape=(1,), dtype=int64>})"
)
np.testing.assert_equal(ds.take(2), [np.array([0]), np.array([1])])
@@ -1265,7 +1265,7 @@ def test_numpy_read_partitioned_with_filter(
val_str = "".join(f"array({v}, dtype=int8), " for v in vals)[:-2]
assert_base_partitioned_ds(
ds,
schema="{__value__: ArrowTensorType(shape=(2,), dtype=int8)}",
schema="{__value__: <ArrowTensorType: shape=(2,), dtype=int8>}",
sorted_values=f"[[{val_str}]]",
ds_take_transform_fn=lambda taken: [taken],
sorted_values_transform_fn=lambda sorted_values: str(sorted_values),