mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[AIR - Datasets] Fix AIR release tests dealing with tensor columns. (#27221)
This PR fixes some AIR release tests that deal with tensor columns.
This commit is contained in:
parent
d25a3ff80a
commit
3730ec8cc9
2 changed files with 5 additions and 9 deletions
|
@ -8,7 +8,6 @@ from torchvision import transforms
|
|||
from torchvision.models import resnet18
|
||||
|
||||
import ray
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
from ray.train.torch import TorchCheckpoint, TorchPredictor
|
||||
from ray.train.batch_predictor import BatchPredictor
|
||||
from ray.data.preprocessors import BatchMapper
|
||||
|
@ -17,8 +16,7 @@ from ray.data.datasource import ImageFolderDatasource
|
|||
|
||||
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
User Pytorch code to transform user image. Note we still use TensorArray as
|
||||
intermediate format to hold images for now.
|
||||
User Pytorch code to transform user image.
|
||||
"""
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
|
@ -28,7 +26,7 @@ def preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
|||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
df["image"] = TensorArray([preprocess(image.to_numpy()) for image in df["image"]])
|
||||
df.loc[:, "image"] = [preprocess(image).numpy() for image in df["image"]]
|
||||
return df
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
|
||||
import ray
|
||||
from ray.air.util.tensor_extensions.pandas import TensorArray
|
||||
from ray.train.torch import TorchCheckpoint
|
||||
from ray.data.preprocessors import BatchMapper
|
||||
from ray import train
|
||||
|
@ -23,8 +22,7 @@ from ray.air.config import ScalingConfig
|
|||
|
||||
def preprocess_image_with_label(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
User Pytorch code to transform user image. Note we still use TensorArray as
|
||||
intermediate format to hold images for now.
|
||||
User Pytorch code to transform user image.
|
||||
"""
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
|
@ -34,9 +32,9 @@ def preprocess_image_with_label(df: pd.DataFrame) -> pd.DataFrame:
|
|||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
df["image"] = TensorArray([preprocess(image.to_numpy()) for image in df["image"]])
|
||||
df.loc[:, "image"] = [preprocess(image).numpy() for image in df["image"]]
|
||||
# Fix fixed synthetic value for perf benchmark purpose
|
||||
df["label"] = df["label"].map(lambda _: 1)
|
||||
df.loc[:, "label"] = df["label"].map(lambda _: 1)
|
||||
return df
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue