[Ray dataset] detect dataframe dtype as object (#25811)

* fix ci

* not break master
This commit is contained in:
Jimmy Yao 2022-06-16 11:23:03 -07:00 committed by GitHub
parent f34cd2fd8f
commit b2e9aea908
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,6 +3,7 @@ from typing import Optional
import numpy as np
import pandas as pd
import tensorflow as tf
from pandas.api.types import is_object_dtype
def convert_pandas_to_tf_tensor(
@ -47,6 +48,13 @@ def convert_pandas_to_tf_tensor(
# them. If the columns contain different types (for example, `float32`s
# and `int32`s), then `tf.concat` raises an error.
dtype: np.dtype = np.find_common_type(df.dtypes, [])
# If the columns are `ray.data.extensions.tensor_extension.TensorArray`,
# the common dtype will be `object`. In this case, set the dtype to
# `None` and rely on the automatic type casting of `tf.convert_to_tensor`.
if is_object_dtype(dtype):
dtype = None
except TypeError:
# `find_common_type` fails if a series has `TensorDtype`. In this case,
# don't cast any of the series and continue.