[Datasets] Fix __array__ protocol on TensorArrayElement and TensorArray. (#25647)

This PR fixes two issues with the __array__ protocol on the tensor extension:

1. The __array__ protocol on TensorArrayElement was missing the dtype parameter, causing np.asarray(tae, dtype=some_dtype) calls to fail. This PR adds support for the dtype argument.
2. TensorArray and TensorArrayElement didn't support NumPy's scalar casting semantics for single-element tensors. This PR adds support for these scalar casting semantics.
This commit is contained in:
Clark Zinzow 2022-06-10 16:42:16 -07:00 committed by GitHub
parent 1dd714e0fa
commit 4fb92dd2f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 111 additions and 6 deletions

View file

@ -316,7 +316,31 @@ class TensorOpsMixin(pd.api.extensions.ExtensionScalarOpsMixin):
return cls._create_method(op)
class TensorArrayElement(TensorOpsMixin):
class TensorScalarCastMixin:
    """
    Mixin for casting scalar tensors to a particular numeric type.

    The host class is expected to expose its backing ndarray as
    ``self._tensor``. Casting succeeds when that tensor holds exactly one
    element, regardless of how many dimensions it has.
    """

    def _scalarfunc(self, func: Callable[[Any], Any]):
        """
        Apply a scalar conversion function to the backing tensor.

        :param func: Conversion callable such as ``float``, ``int`` or
            ``complex``.
        :return: Whatever ``func`` returns for the backing tensor.
        """
        tensor = self._tensor
        # NumPy >= 1.25 deprecates converting arrays with ndim > 0 to Python
        # scalars. Collapse single-element tensors to a 0-d view first so the
        # cast keeps working without warnings. Tensors holding more than one
        # element are passed through unchanged so that the conversion
        # function raises exactly as it did before.
        if tensor.ndim > 0 and tensor.size == 1:
            tensor = tensor.reshape(())
        return func(tensor)

    def __complex__(self):
        """Support ``complex(obj)`` for single-element tensors."""
        return self._scalarfunc(complex)

    def __float__(self):
        """Support ``float(obj)`` for single-element tensors."""
        return self._scalarfunc(float)

    def __int__(self):
        """Support ``int(obj)`` for single-element tensors."""
        return self._scalarfunc(int)

    # NOTE: ``__hex__`` and ``__oct__`` are Python 2 hooks. Python 3's built-in
    # ``hex()``/``oct()`` consult ``__index__`` instead, so these methods are
    # only reachable through an explicit call. They are kept so that existing
    # direct callers continue to work.
    def __hex__(self):
        """Return the hexadecimal string of a single-element integer tensor."""
        return self._scalarfunc(hex)

    def __oct__(self):
        """Return the octal string of a single-element integer tensor."""
        return self._scalarfunc(oct)
class TensorArrayElement(TensorOpsMixin, TensorScalarCastMixin):
"""
Single element of a TensorArray, wrapping an underlying ndarray.
"""
@ -336,18 +360,54 @@ class TensorArrayElement(TensorOpsMixin):
def __str__(self):
    """Return the string form of the backing tensor."""
    return str(self._tensor)
@property
def numpy_dtype(self):
    """
    Data type of the wrapped tensor.

    :return: The ``dtype`` attribute of the backing ndarray
    """
    backing = self._tensor
    return backing.dtype
@property
def numpy_ndim(self):
    """
    Dimensionality of the wrapped tensor.

    :return: The ``ndim`` attribute of the backing ndarray, i.e. its number
        of dimensions
    """
    backing = self._tensor
    return backing.ndim
@property
def numpy_shape(self):
    """
    Shape of the wrapped tensor.

    :return: The ``shape`` attribute of the backing ndarray, a tuple of
        integers
    """
    backing = self._tensor
    return backing.shape
@property
def numpy_size(self):
    """
    Element count of the wrapped tensor.

    :return: The ``size`` attribute of the backing ndarray, i.e. the total
        number of elements it holds
    """
    backing = self._tensor
    return backing.size
def to_numpy(self):
    """
    Expose this element's values as a NumPy ndarray.

    :return: The result of ``np.asarray`` on the backing tensor
    """
    backing = self._tensor
    return np.asarray(backing)
def __array__(self):
return np.asarray(self._tensor)
def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
    """
    NumPy array protocol hook for a single tensor element.

    :param dtype: Optional target dtype, forwarded to ``np.asarray``
    :param kwargs: Any further keyword arguments, forwarded to
        ``np.asarray`` unchanged
    :return: The backing tensor converted via ``np.asarray``
    """
    backing = self._tensor
    return np.asarray(backing, dtype=dtype, **kwargs)
@PublicAPI(stability="beta")
class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
class TensorArray(
pd.api.extensions.ExtensionArray,
TensorOpsMixin,
TensorScalarCastMixin,
):
"""
Pandas `ExtensionArray` representing a tensor column, i.e. a column
consisting of ndarrays as elements. All tensors in a column must have the
@ -891,8 +951,8 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
except KeyError:
raise NotImplementedError(f"'{name}' aggregate not implemented.") from None
def __array__(self, dtype: np.dtype = None):
return np.asarray(self._tensor, dtype=dtype)
def __array__(self, dtype: np.dtype = None, **kwargs) -> np.ndarray:
    """
    NumPy array protocol hook for the whole tensor column.

    :param dtype: Optional target dtype, forwarded to ``np.asarray``
    :param kwargs: Any further keyword arguments, forwarded to
        ``np.asarray`` unchanged
    :return: The backing tensor converted via ``np.asarray``
    """
    backing = self._tensor
    return np.asarray(backing, dtype=dtype, **kwargs)
def __array_ufunc__(self, ufunc: Callable, method: str, *inputs, **kwargs):
"""
@ -989,6 +1049,14 @@ class TensorArray(pd.api.extensions.ExtensionArray, TensorOpsMixin):
"""
return self._tensor.shape
@property
def numpy_size(self):
    """
    Element count of the backing tensor.

    :return: The ``size`` attribute of the backing ndarray, i.e. the total
        number of elements it holds
    """
    backing = self._tensor
    return backing.size
@property
def _is_boolean(self):
"""

View file

@ -604,6 +604,43 @@ def test_tensor_array_ops(ray_start_regular_shared):
np.testing.assert_equal(apply_logical_ops(arr), apply_logical_ops(df["two"]))
def test_tensor_array_array_protocol(ray_start_regular_shared):
    """
    Check that ``np.asarray(..., dtype=...)`` honors the requested dtype for
    both a TensorArray and one of its elements.
    """
    tensor_shape = (3,) + (2, 2, 2)
    element_count = np.prod(np.array(tensor_shape))
    expected = np.arange(element_count).reshape(tensor_shape)

    tensor_array = TensorArray(expected)

    # Whole-column conversion with an explicit dtype.
    column_as_float = np.asarray(tensor_array, dtype=np.float32)
    np.testing.assert_array_equal(column_as_float, expected.astype(np.float32))

    # Single-element conversion with an explicit dtype.
    first_element = tensor_array[0]
    element_as_float = np.asarray(first_element, dtype=np.float32)
    np.testing.assert_array_equal(element_as_float, expected[0].astype(np.float32))
def test_tensor_array_scalar_cast(ray_start_regular_shared):
    """
    Check that single-element tensors can be cast to Python scalars, both
    per element of a TensorArray and for a TensorArray as a whole.
    """
    outer_dim = 3
    inner_shape = (1,)
    shape = (outer_dim,) + inner_shape
    num_items = np.prod(np.array(shape))
    arr = np.arange(num_items).reshape(shape)

    t_arr = TensorArray(arr)

    for t_arr_elem, arr_elem in zip(t_arr, arr):
        # Use .item() for the expected value: calling float() directly on an
        # ndarray with ndim > 0 is deprecated in recent NumPy releases.
        assert float(t_arr_elem) == float(arr_elem.item())

    # Use a non-zero value so that a cast returning a default 0 cannot pass.
    arr = np.full((1, 1, 1), 7)
    t_arr = TensorArray(arr)
    assert float(t_arr) == float(arr.item())
def test_tensor_array_reductions(ray_start_regular_shared):
outer_dim = 3
inner_shape = (2, 2, 2)