diff --git a/python/ray/air/_internal/tensorflow_utils.py b/python/ray/air/_internal/tensorflow_utils.py index 13e947dc5..d81e600ab 100644 --- a/python/ray/air/_internal/tensorflow_utils.py +++ b/python/ray/air/_internal/tensorflow_utils.py @@ -47,6 +47,13 @@ def convert_pandas_to_tf_tensor( # them. If the columns contain different types (for example, `float32`s # and `int32`s), then `tf.concat` raises an error. dtype: np.dtype = np.find_common_type(df.dtypes, []) + + # if the columns are `ray.data.extensions.tensor_extension.TensorArray`, + # the dtype will be `object`. In this case, we need to set the dtype to + # none, and use the automatic type casting of `tf.convert_to_tensor`. + if isinstance(dtype, object): + dtype = None + except TypeError: # `find_common_type` fails if a series has `TensorDtype`. In this case, # don't cast any of the series and continue.