mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Datasets] Fix __array__
protocol on TensorArrayElement
and TensorArray
. (#25647)
This PR fixes two issues with the __array__ protocol on the tensor extension: 1. The __array__ protocol on TensorArrayElement was missing the dtype parameter, causing np.asarray(tae, dtype=some_dtype) calls to fail. This PR adds support for the dtype argument. 2. TensorArray and TensorArrayElement didn't support NumPy's scalar casting semantics for single-element tensors. This PR adds support for these scalar casting semantics.
This commit is contained in:
parent
1dd714e0fa
commit
4fb92dd2f1
2 changed files with 111 additions and 6 deletions
|
@ -316,7 +316,31 @@ class TensorOpsMixin(pd.api.extensions.ExtensionScalarOpsMixin):
|
|||
return cls._create_method(op)
|
||||
|
||||
|
||||
class TensorArrayElement(TensorOpsMixin):
|
||||
class TensorScalarCastMixin:
|
||||
"""
|
||||
Mixin for casting scalar tensors to a particular numeric type.
|
||||
"""
|
||||
|
||||
def _scalarfunc(self, func: Callable[[Any], Any]):
|
||||
return func(self._tensor)
|
||||
|
||||
def __complex__(self):
|
||||
return self._scalarfunc(complex)
|
||||
|
||||
def __float__(self):
|
||||
return self._scalarfunc(float)
|
||||
|
||||
def __int__(self):
|
||||
return self._scalarfunc(int)
|
||||
|
||||
def __hex__(self):
|
||||
return self._scalarfunc(hex)
|
||||
|
||||
def __oct__(self):
|
||||
return self._scalarfunc(oct)
|
||||
|
||||
|
||||
class TensorArrayElement(TensorOpsMixin, TensorScalarCastMixin):
|
||||
"""
|
||||
Single element of a TensorArray, wrapping an underlying ndarray.
|
||||
"""
|
||||
|
@ -336,18 +360,54 @@ class TensorArrayElement(TensorOpsMixin):
|
|||
def __str__(self):
|
||||
return self._tensor.__str__()
|
||||
|
||||
@property
|
||||
def numpy_dtype(self):
|
||||
"""
|
||||
Get the dtype of the tensor.
|
||||
:return: The numpy dtype of the backing ndarray
|
||||
"""
|
||||
return self._tensor.dtype
|
||||
|
||||
@property
|
||||
def numpy_ndim(self):
|
||||
"""
|
||||
Get the number of tensor dimensions.
|
||||
:return: integer for the number of dimensions
|
||||
"""
|
||||
return self._tensor.ndim
|
||||
|
||||
@property
|
||||
def numpy_shape(self):
|
||||
"""
|
||||
Get the shape of the tensor.
|
||||
:return: A tuple of integers for the numpy shape of the backing ndarray
|
||||
"""
|
||||
return self._tensor.shape
|
||||
|
||||
@property
|
||||
def numpy_size(self):
|
||||
"""
|
||||
Get the size of the tensor.
|
||||
:return: integer for the number of elements in the tensor
|
||||
"""
|
||||
return self._tensor.size
|
||||
|
||||
def to_numpy(self):
|
||||
"""
|
||||
Return the values of this element as a NumPy ndarray.
|
||||
"""
|
||||
return np.asarray(self._tensor)
|
||||
|
||||
def __array__(self):
|
||||
return np.asarray(self._tensor)
|
||||
def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
|
||||
return np.asarray(self._tensor, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
|
||||
class TensorArray(
|
||||
pd.api.extensions.ExtensionArray,
|
||||
TensorOpsMixin,
|
||||
TensorScalarCastMixin,
|
||||
):
|
||||
"""
|
||||
Pandas `ExtensionArray` representing a tensor column, i.e. a column
|
||||
consisting of ndarrays as elements. All tensors in a column must have the
|
||||
|
@ -891,8 +951,8 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
|
|||
except KeyError:
|
||||
raise NotImplementedError(f"'{name}' aggregate not implemented.") from None
|
||||
|
||||
def __array__(self, dtype: np.dtype = None):
|
||||
return np.asarray(self._tensor, dtype=dtype)
|
||||
def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
|
||||
return np.asarray(self._tensor, dtype=dtype, **kwargs)
|
||||
|
||||
def __array_ufunc__(self, ufunc: Callable, method: str, *inputs, **kwargs):
|
||||
"""
|
||||
|
@ -989,6 +1049,14 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
|
|||
"""
|
||||
return self._tensor.shape
|
||||
|
||||
@property
|
||||
def numpy_size(self):
|
||||
"""
|
||||
Get the size of the tensor.
|
||||
:return: integer for the number of elements in the tensor
|
||||
"""
|
||||
return self._tensor.size
|
||||
|
||||
@property
|
||||
def _is_boolean(self):
|
||||
"""
|
||||
|
|
|
@ -604,6 +604,43 @@ def test_tensor_array_ops(ray_start_regular_shared):
|
|||
np.testing.assert_equal(apply_logical_ops(arr), apply_logical_ops(df["two"]))
|
||||
|
||||
|
||||
def test_tensor_array_array_protocol(ray_start_regular_shared):
|
||||
outer_dim = 3
|
||||
inner_shape = (2, 2, 2)
|
||||
shape = (outer_dim,) + inner_shape
|
||||
num_items = np.prod(np.array(shape))
|
||||
arr = np.arange(num_items).reshape(shape)
|
||||
|
||||
t_arr = TensorArray(arr)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
np.asarray(t_arr, dtype=np.float32), arr.astype(np.float32)
|
||||
)
|
||||
|
||||
t_arr_elem = t_arr[0]
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
np.asarray(t_arr_elem, dtype=np.float32), arr[0].astype(np.float32)
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_array_scalar_cast(ray_start_regular_shared):
|
||||
outer_dim = 3
|
||||
inner_shape = (1,)
|
||||
shape = (outer_dim,) + inner_shape
|
||||
num_items = np.prod(np.array(shape))
|
||||
arr = np.arange(num_items).reshape(shape)
|
||||
|
||||
t_arr = TensorArray(arr)
|
||||
|
||||
for t_arr_elem, arr_elem in zip(t_arr, arr):
|
||||
assert float(t_arr_elem) == float(arr_elem)
|
||||
|
||||
arr = np.arange(1).reshape((1, 1, 1))
|
||||
t_arr = TensorArray(arr)
|
||||
assert float(t_arr) == float(arr)
|
||||
|
||||
|
||||
def test_tensor_array_reductions(ray_start_regular_shared):
|
||||
outer_dim = 3
|
||||
inner_shape = (2, 2, 2)
|
||||
|
|
Loading…
Add table
Reference in a new issue