mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Ray dataset] detect dataframe dtype as object (#25811)
* fix ci * not break master
This commit is contained in:
parent
f34cd2fd8f
commit
b2e9aea908
1 changed file with 8 additions and 0 deletions
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from pandas.api.types import is_object_dtype
|
||||
|
||||
|
||||
def convert_pandas_to_tf_tensor(
|
||||
|
@@ -47,6 +48,13 @@ def convert_pandas_to_tf_tensor(
|
|||
# them. If the columns contain different types (for example, `float32`s
|
||||
# and `int32`s), then `tf.concat` raises an error.
|
||||
dtype: np.dtype = np.find_common_type(df.dtypes, [])
|
||||
|
||||
# if the columns are `ray.data.extensions.tensor_extension.TensorArray`,
|
||||
# the dtype will be `object`. In this case, we need to set the dtype to
|
||||
# `None`, and use the automatic type casting of `tf.convert_to_tensor`.
|
||||
if is_object_dtype(dtype):
|
||||
dtype = None
|
||||
|
||||
except TypeError:
|
||||
# `find_common_type` fails if a series has `TensorDtype`. In this case,
|
||||
# don't cast any of the series and continue.
|
||||
|
|
Loading…
Add table
Reference in a new issue