Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[air] Accessors for preprocessor in Predictor class (#26600)
parent e7ab969f61, commit b0eb051282
19 changed files with 316 additions and 80 deletions
@@ -8,19 +8,26 @@ import tarfile
import tempfile
import traceback
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union, TYPE_CHECKING

import ray
from ray import cloudpickle as pickle
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.air._internal.remote_storage import (
    download_from_uri,
    fs_hint,
    is_non_local_path_uri,
    upload_to_uri,
)
from ray.air.constants import PREPROCESSOR_KEY
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.ml_utils.filelock import TempFileLock


if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


_DICT_CHECKPOINT_FILE_NAME = "dict_checkpoint.pkl"
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY = "_ray_additional_checkpoint_files"
_METADATA_CHECKPOINT_SUFFIX = ".meta.pkl"

@@ -600,6 +607,23 @@ class Checkpoint:
    def __setstate__(self, state):
        self.__dict__.update(state)

    def get_preprocessor(self) -> Optional["Preprocessor"]:
        """Return the saved preprocessor, if one exists."""

        # The preprocessor will either be stored in an in-memory dict or
        # written to storage. In either case, it will use the PREPROCESSOR_KEY key.

        # First try converting to dictionary.
        checkpoint_dict = self.to_dict()
        preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)

        if preprocessor is None:
            # Fallback to reading from directory.
            with self.as_directory() as checkpoint_path:
                preprocessor = load_preprocessor_from_dir(checkpoint_path)

        return preprocessor


def _get_local_path(path: Optional[str]) -> Optional[str]:
    """Check if path is a local path. Otherwise return None."""
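Taken together, the two hunks above let a caller ask any Checkpoint for its preprocessor without knowing whether the checkpoint is dict-backed or directory-backed. A minimal sketch of that call path, reusing the DummyPreprocessor helper defined in the tests below (the multiplier value is illustrative):

    from ray.air.checkpoint import Checkpoint
    from ray.air.constants import PREPROCESSOR_KEY
    from ray.data import Preprocessor


    class DummyPreprocessor(Preprocessor):
        def __init__(self, multiplier):
            self.multiplier = multiplier

        def transform_batch(self, df):
            return df * self.multiplier


    # The preprocessor rides inside the checkpoint under PREPROCESSOR_KEY;
    # get_preprocessor() first checks the in-memory dict representation,
    # then falls back to reading the materialized directory.
    checkpoint = Checkpoint.from_dict({"metric": 5, PREPROCESSOR_KEY: DummyPreprocessor(2)})
    assert checkpoint.get_preprocessor().multiplier == 2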
@@ -6,8 +6,18 @@ import unittest
from typing import Any

import ray
from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
from ray.air._internal.remote_storage import delete_at_uri, _ensure_directory
from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
from ray.air.constants import PREPROCESSOR_KEY
from ray.data import Preprocessor


class DummyPreprocessor(Preprocessor):
    def __init__(self, multiplier):
        self.multiplier = multiplier

    def transform_batch(self, df):
        return df * self.multiplier


class CheckpointsConversionTest(unittest.TestCase):

@@ -470,6 +480,71 @@ class CheckpointsSerdeTest(unittest.TestCase):
        self._testCheckpointSerde(checkpoint, *checkpoint.get_internal_representation())


class PreprocessorCheckpointTest(unittest.TestCase):
    def testDictCheckpointWithoutPreprocessor(self):
        data = {"metric": 5}
        checkpoint = Checkpoint.from_dict(data)
        preprocessor = checkpoint.get_preprocessor()
        assert preprocessor is None

    def testDictCheckpointWithPreprocessor(self):
        preprocessor = DummyPreprocessor(1)
        data = {"metric": 5, PREPROCESSOR_KEY: preprocessor}
        checkpoint = Checkpoint.from_dict(data)
        preprocessor = checkpoint.get_preprocessor()
        assert preprocessor.multiplier == 1

    def testDictCheckpointWithPreprocessorAsDir(self):
        preprocessor = DummyPreprocessor(1)
        data = {"metric": 5, PREPROCESSOR_KEY: preprocessor}
        checkpoint = Checkpoint.from_dict(data)
        checkpoint_path = checkpoint.to_directory()
        checkpoint = Checkpoint.from_directory(checkpoint_path)
        preprocessor = checkpoint.get_preprocessor()
        assert preprocessor.multiplier == 1

    def testDirCheckpointWithoutPreprocessor(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            data = {"metric": 5}
            checkpoint_dir = os.path.join(tmpdir, "existing_checkpoint")
            os.mkdir(checkpoint_dir, 0o755)
            with open(os.path.join(checkpoint_dir, "test_data.pkl"), "wb") as fp:
                pickle.dump(data, fp)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            preprocessor = checkpoint.get_preprocessor()
            assert preprocessor is None

    def testDirCheckpointWithPreprocessor(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            preprocessor = DummyPreprocessor(1)
            data = {"metric": 5}
            checkpoint_dir = os.path.join(tmpdir, "existing_checkpoint")
            os.mkdir(checkpoint_dir, 0o755)
            with open(os.path.join(checkpoint_dir, "test_data.pkl"), "wb") as fp:
                pickle.dump(data, fp)
            with open(os.path.join(checkpoint_dir, PREPROCESSOR_KEY), "wb") as fp:
                pickle.dump(preprocessor, fp)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            preprocessor = checkpoint.get_preprocessor()
            assert preprocessor.multiplier == 1

    def testDirCheckpointWithPreprocessorAsDict(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            preprocessor = DummyPreprocessor(1)
            data = {"metric": 5}
            checkpoint_dir = os.path.join(tmpdir, "existing_checkpoint")
            os.mkdir(checkpoint_dir, 0o755)
            with open(os.path.join(checkpoint_dir, "test_data.pkl"), "wb") as fp:
                pickle.dump(data, fp)
            with open(os.path.join(checkpoint_dir, PREPROCESSOR_KEY), "wb") as fp:
                pickle.dump(preprocessor, fp)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            checkpoint_dict = checkpoint.to_dict()
            checkpoint = checkpoint.from_dict(checkpoint_dict)
            preprocessor = checkpoint.get_preprocessor()
            assert preprocessor.multiplier == 1


if __name__ == "__main__":
    import pytest
    import sys
@@ -5,6 +5,7 @@ import pandas as pd
import ray
from ray.air import Checkpoint
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas
from ray.data import Preprocessor
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI

@@ -18,23 +19,17 @@ class BatchPredictor:

    This batch predictor wraps around a predictor class and executes it
    in a distributed way when calling ``predict()``.

    Attributes:
        checkpoint: Checkpoint loaded by the distributed predictor objects.
        predictor_cls: Predictor class reference. When scoring, each scoring worker
            will create an instance of this class and call ``predict(batch)`` on it.
        **predictor_kwargs: Keyword arguments passed to the predictor on
            initialization.

    """

    def __init__(
        self, checkpoint: Checkpoint, predictor_cls: Type[Predictor], **predictor_kwargs
    ):
        self._checkpoint = checkpoint
        # Store as object ref so we only serialize it once for all map workers
        self.checkpoint_ref = checkpoint.to_object_ref()
        self.predictor_cls = predictor_cls
        self.predictor_kwargs = predictor_kwargs
        self._checkpoint_ref = checkpoint.to_object_ref()
        self._predictor_cls = predictor_cls
        self._predictor_kwargs = predictor_kwargs
        self._override_preprocessor: Optional[Preprocessor] = None

    @classmethod
    def from_checkpoint(

@@ -66,6 +61,17 @@ class BatchPredictor:
            predictor_cls=PandasUDFPredictor,
        )

    def get_preprocessor(self) -> Preprocessor:
        """Get the preprocessor to use prior to executing predictions."""
        if self._override_preprocessor:
            return self._override_preprocessor

        return self._checkpoint.get_preprocessor()

    def set_preprocessor(self, preprocessor: Preprocessor) -> None:
        """Set the preprocessor to use prior to executing predictions."""
        self._override_preprocessor = preprocessor

    def predict(
        self,
        data: Union[ray.data.Dataset, ray.data.DatasetPipeline],
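With these accessors, the preprocessor attached to the checkpoint can be inspected or swapped out before scoring; an explicit override takes precedence over whatever the checkpoint carries. A rough usage sketch, reusing the DummyPredictor/DummyPreprocessor test helpers defined later in this commit and assuming BatchPredictor's import path under ray.train:

    from ray.train.batch_predictor import BatchPredictor

    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor(1)}),
        DummyPredictor,
    )

    # Returns the override if one was set, else the checkpoint's preprocessor.
    assert batch_predictor.get_preprocessor().multiplier == 1

    # The override is forwarded to every ScoringWrapper worker (see the hunk below).
    batch_predictor.set_preprocessor(DummyPreprocessor(2))
    assert batch_predictor.get_preprocessor().multiplier == 2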
@@ -130,9 +136,10 @@ class BatchPredictor:
            Dataset containing scoring results.

        """
        predictor_cls = self.predictor_cls
        checkpoint_ref = self.checkpoint_ref
        predictor_kwargs = self.predictor_kwargs
        predictor_cls = self._predictor_cls
        checkpoint_ref = self._checkpoint_ref
        predictor_kwargs = self._predictor_kwargs
        override_prep = self._override_preprocessor
        # Automatically set use_gpu in the predictor constructor if the user
        # provided explicit GPU resources
        if (

@@ -144,16 +151,18 @@ class BatchPredictor:
        class ScoringWrapper:
            def __init__(self):
                checkpoint = Checkpoint.from_object_ref(checkpoint_ref)
                self.predictor = predictor_cls.from_checkpoint(
                self._predictor = predictor_cls.from_checkpoint(
                    checkpoint, **predictor_kwargs
                )
                if override_prep:
                    self._predictor.set_preprocessor(override_prep)

            def __call__(self, batch):
                if feature_columns:
                    prediction_batch = batch[feature_columns]
                else:
                    prediction_batch = batch
                prediction_output = self.predictor.predict(
                prediction_output = self._predictor.predict(
                    prediction_batch, **predict_kwargs
                )
                if keep_columns:
@@ -7,7 +7,6 @@ from transformers.pipelines.table_question_answering import (
    TableQuestionAnsweringPipeline,
)

from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor

@@ -35,7 +34,7 @@ class HuggingFacePredictor(Predictor):
        preprocessor: Optional["Preprocessor"] = None,
    ):
        self.pipeline = pipeline
        self.preprocessor = preprocessor
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(

@@ -66,8 +65,8 @@ class HuggingFacePredictor(Predictor):
                "If `pipeline_cls` is not specified, 'task' must be passed as a kwarg."
            )
        pipeline_cls = pipeline_cls or pipeline_factory
        preprocessor = checkpoint.get_preprocessor()
        with checkpoint.as_directory() as checkpoint_path:
            preprocessor = load_preprocessor_from_dir(checkpoint_path)
            # Tokenizer will be loaded automatically (no need to specify
            # `tokenizer=checkpoint_path`)
            pipeline = pipeline_cls(model=checkpoint_path, **pipeline_kwargs)
@@ -28,7 +28,7 @@ class LightGBMPredictor(Predictor):
        self, model: lightgbm.Booster, preprocessor: Optional["Preprocessor"] = None
    ):
        self.model = model
        self.preprocessor = preprocessor
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "LightGBMPredictor":

@@ -42,7 +42,8 @@ class LightGBMPredictor(Predictor):
            ``LightGBMTrainer`` run.

        """
        bst, preprocessor = load_checkpoint(checkpoint)
        bst, _ = load_checkpoint(checkpoint)
        preprocessor = checkpoint.get_preprocessor()
        return cls(model=bst, preprocessor=preprocessor)

    def _predict_pandas(
@@ -1,5 +1,5 @@
import abc
from typing import Dict, Type, Callable
from typing import Dict, Type, Optional, Callable

import numpy as np
import pandas as pd

@@ -11,6 +11,7 @@ from ray.air.util.data_batch_conversion import (
    convert_batch_type_to_pandas,
    convert_pandas_to_batch_type,
)
from ray.data import Preprocessor
from ray.util.annotations import DeveloperAPI, PublicAPI

try:

@@ -72,8 +73,11 @@ class Predictor(abc.ABC):
    tensor data to avoid extra copies from Pandas conversions.
    """

    def __init__(self, preprocessor: Optional[Preprocessor] = None):
        """Subclasses must call Predictor.__init__() to set a preprocessor."""
        self._preprocessor: Optional[Preprocessor] = preprocessor

    @classmethod
    @PublicAPI(stability="alpha")
    @abc.abstractmethod
    def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
        """Create a specific predictor from a checkpoint.

@@ -108,7 +112,14 @@ class Predictor(abc.ABC):

        return PandasUDFPredictor.from_checkpoint(Checkpoint.from_dict({"dummy": 1}))

    @PublicAPI(stability="alpha")
    def get_preprocessor(self) -> Optional[Preprocessor]:
        """Get the preprocessor to use prior to executing predictions."""
        return self._preprocessor

    def set_preprocessor(self, preprocessor: Optional[Preprocessor]) -> None:
        """Set the preprocessor to use prior to executing predictions."""
        self._preprocessor = preprocessor

    def predict(self, data: DataBatchType, **kwargs) -> DataBatchType:
        """Perform inference on a batch of data.
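The new base-class constructor establishes a simple contract: every subclass hands its preprocessor to Predictor.__init__(), which stores it as self._preprocessor for the accessors above. A minimal conforming subclass, modeled on the DummyPredictor test helpers elsewhere in this commit (the class name and factor field are illustrative):

    import pandas as pd

    from ray.air.checkpoint import Checkpoint
    from ray.train.predictor import Predictor


    class FactorPredictor(Predictor):  # hypothetical subclass for illustration
        def __init__(self, factor: float = 1.0, preprocessor=None):
            self.factor = factor
            super().__init__(preprocessor)  # required: sets self._preprocessor

        @classmethod
        def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "FactorPredictor":
            # Pull the preprocessor out of the checkpoint via the new accessor.
            return cls(checkpoint.to_dict()["factor"], checkpoint.get_preprocessor())

        def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
            return data * self.factor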
@@ -123,8 +134,13 @@
        """
        data_df = convert_batch_type_to_pandas(data)

        if getattr(self, "preprocessor", None):
            data_df = self.preprocessor.transform_batch(data_df)
        if not hasattr(self, "_preprocessor"):
            raise NotImplementedError(
                "Subclasses of Predictor must call Predictor.__init__(preprocessor)."
            )

        if self._preprocessor:
            data_df = self._preprocessor.transform_batch(data_df)

        predictions_df = self._predict_pandas(data_df, **kwargs)
        return convert_pandas_to_batch_type(
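The hasattr guard turns a silent mismatch into a fast failure: a subclass that forgets to call super().__init__() now raises on its first predict() call instead of quietly skipping preprocessing. A sketch of that failure mode, building on the hypothetical FactorPredictor above:

    class BrokenPredictor(FactorPredictor):
        def __init__(self):
            self.factor = 2.0  # bug: never calls super().__init__()

    predictor = BrokenPredictor()
    predictor.predict(pd.DataFrame({"x": [1]}))  # raises NotImplementedError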
@@ -31,7 +31,7 @@ class RLPredictor(Predictor):
        preprocessor: Optional["Preprocessor"] = None,
    ):
        self.policy = policy
        self.preprocessor = preprocessor
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(

@@ -52,7 +52,8 @@ class RLPredictor(Predictor):
            it is parsed from the saved trainer configuration instead.

        """
        policy, preprocessor = load_checkpoint(checkpoint, env)
        policy, _ = load_checkpoint(checkpoint, env)
        preprocessor = checkpoint.get_preprocessor()
        return cls(policy=policy, preprocessor=preprocessor)

    def _predict_pandas(self, data: "pd.DataFrame", **kwargs) -> "pd.DataFrame":
@@ -33,7 +33,7 @@ class SklearnPredictor(Predictor):
        preprocessor: Optional["Preprocessor"] = None,
    ):
        self.estimator = estimator
        self.preprocessor = preprocessor
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "SklearnPredictor":

@@ -46,7 +46,8 @@ class SklearnPredictor(Predictor):
            preprocessor from. It is expected to be from the result of a
            ``SklearnTrainer`` run.
        """
        estimator, preprocessor = load_checkpoint(checkpoint)
        estimator, _ = load_checkpoint(checkpoint)
        preprocessor = checkpoint.get_preprocessor()
        return cls(estimator=estimator, preprocessor=preprocessor)

    def _predict_pandas(
@@ -41,7 +41,6 @@ class TensorflowPredictor(DLPredictor):
    ):
        self.model_definition = model_definition
        self.model_weights = model_weights
        self.preprocessor = preprocessor

        self.use_gpu = use_gpu
        # TensorFlow model objects cannot be pickled, therefore we use

@@ -73,6 +72,7 @@ class TensorflowPredictor(DLPredictor):

        if model_weights is not None:
            self._model.set_weights(model_weights)
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(

@@ -94,7 +94,8 @@ class TensorflowPredictor(DLPredictor):
        """
        # Cannot use TensorFlow load_checkpoint here
        # due to instantiated models not being pickleable
        model_weights, preprocessor = _load_checkpoint(checkpoint, "TensorflowTrainer")
        model_weights, _ = _load_checkpoint(checkpoint, "TensorflowTrainer")
        preprocessor = checkpoint.get_preprocessor()
        return cls(
            model_definition=model_definition,
            model_weights=model_weights,
@@ -1,7 +1,9 @@
import time
from typing import Optional

import pandas as pd
import pytest
from ray.air.constants import PREPROCESSOR_KEY

import ray
from ray.air.checkpoint import Checkpoint

@@ -11,22 +13,33 @@ from ray.train.predictor import Predictor


class DummyPreprocessor(Preprocessor):
    def __init__(self, multiplier=2):
        self.multiplier = multiplier

    def transform_batch(self, df):
        return df * 2
        return df * self.multiplier


class DummyPredictor(Predictor):
    def __init__(self, factor: float = 1.0, use_gpu: bool = False):
    def __init__(
        self,
        factor: float = 1.0,
        preprocessor: Optional[Preprocessor] = None,
        use_gpu: bool = False,
    ):
        self.factor = factor
        self.preprocessor = DummyPreprocessor()
        self.use_gpu = use_gpu
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(
        cls, checkpoint: Checkpoint, use_gpu: bool = False, **kwargs
    ) -> "DummyPredictor":
        checkpoint_data = checkpoint.to_dict()
        return cls(**checkpoint_data, use_gpu=use_gpu)
        preprocessor = checkpoint.get_preprocessor()
        return cls(
            checkpoint_data["factor"], preprocessor=preprocessor, use_gpu=use_gpu
        )

    def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        # Need to throw exception here instead of constructor to surface the
@@ -44,12 +57,14 @@ class DummyPredictorFS(DummyPredictor):
        # simulate reading
        time.sleep(1)
        checkpoint_data = checkpoint.to_dict()
        return cls(**checkpoint_data)
        preprocessor = checkpoint.get_preprocessor()
        return cls(checkpoint_data["factor"], preprocessor=preprocessor)


def test_batch_prediction():
    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}),
        DummyPredictor,
    )

    test_dataset = ray.data.range(4)

@@ -76,7 +91,8 @@ def test_batch_prediction():

def test_batch_prediction_fs():
    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0}), DummyPredictorFS
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}),
        DummyPredictorFS,
    )

    test_dataset = ray.data.from_items([1.0, 2.0, 3.0, 4.0] * 32).repartition(8)

@@ -98,7 +114,8 @@ def test_batch_prediction_fs():

def test_batch_prediction_feature_cols():
    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}),
        DummyPredictor,
    )

    test_dataset = ray.data.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))

@@ -110,7 +127,8 @@ def test_batch_prediction_feature_cols():

def test_batch_prediction_keep_cols():
    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}),
        DummyPredictor,
    )

    test_dataset = ray.data.from_pandas(

@@ -153,7 +171,8 @@ def test_automatic_enable_gpu_from_num_gpus_per_worker():
    """

    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}),
        DummyPredictor,
    )
    test_dataset = ray.data.range(4)
@@ -163,6 +182,38 @@ def test_automatic_enable_gpu_from_num_gpus_per_worker():
    _ = batch_predictor.predict(test_dataset, num_gpus_per_worker=1)


def test_get_and_set_preprocessor():
    """Test that the preprocessor can be set and retrieved."""

    preprocessor = DummyPreprocessor(1)
    batch_predictor = BatchPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: preprocessor}),
        DummyPredictor,
    )
    assert batch_predictor.get_preprocessor() == preprocessor

    test_dataset = ray.data.range(4)
    output_ds = batch_predictor.predict(test_dataset)
    assert output_ds.to_pandas().to_numpy().squeeze().tolist() == [
        0.0,
        2.0,
        4.0,
        6.0,
    ]

    preprocessor2 = DummyPreprocessor(2)
    batch_predictor.set_preprocessor(preprocessor2)
    assert batch_predictor.get_preprocessor() == preprocessor2

    output_ds = batch_predictor.predict(test_dataset)
    assert output_ds.to_pandas().to_numpy().squeeze().tolist() == [
        0.0,
        4.0,
        8.0,
        12.0,
    ]


if __name__ == "__main__":
    import sys
@@ -67,7 +67,7 @@ def test_predict(tmpdir, ray_start_runtime_env, batch_type):

        assert len(predictions) == 3
        if preprocessor:
            assert hasattr(predictor.preprocessor, "_batch_transformed")
            assert hasattr(predictor.get_preprocessor(), "_batch_transformed")

    ray.get(test.remote(use_preprocessor=True))
    ray.get(test.remote(use_preprocessor=False))
@@ -49,7 +49,10 @@ def test_init():
    checkpoint_predictor = LightGBMPredictor.from_checkpoint(checkpoint)

    assert get_num_trees(checkpoint_predictor.model) == get_num_trees(predictor.model)
    assert checkpoint_predictor.preprocessor.attr == predictor.preprocessor.attr
    assert (
        checkpoint_predictor.get_preprocessor().attr
        == predictor.get_preprocessor().attr
    )


@pytest.mark.parametrize("batch_type", [np.ndarray, pd.DataFrame, pa.Table, dict])

@@ -62,7 +65,7 @@ def test_predict(batch_type):
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_feature_columns():

@@ -73,7 +76,7 @@ def test_predict_feature_columns():
    predictions = predictor.predict(data_batch, feature_columns=[0, 1])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_feature_columns_pandas():

@@ -90,7 +93,7 @@ def test_predict_feature_columns_pandas():
    predictions = predictor.predict(data_batch, feature_columns=["A", "B"])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_no_preprocessor_no_training():
@@ -1,28 +1,37 @@
from typing import Optional
from unittest import mock

import pandas as pd
import pytest
from ray.air.util.data_batch_conversion import DataType

import ray
from ray.air.checkpoint import Checkpoint
from ray.air.constants import PREPROCESSOR_KEY
from ray.data import Preprocessor
from ray.train.predictor import Predictor, PredictorNotSerializableException


class DummyPreprocessor(Preprocessor):
    def __init__(self, multiplier=2):
        self.multiplier = multiplier

    def transform_batch(self, df):
        return df * 2
        return df * self.multiplier


class DummyPredictor(Predictor):
    def __init__(self, factor: float = 1.0):
    def __init__(
        self, factor: float = 1.0, preprocessor: Optional[Preprocessor] = None
    ):
        self.factor = factor
        self.preprocessor = DummyPreprocessor()
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "DummyPredictor":
        checkpoint_data = checkpoint.to_dict()
        return cls(**checkpoint_data)
        preprocessor = checkpoint.get_preprocessor()
        return cls(checkpoint_data["factor"], preprocessor)

    def _predict_pandas(self, data: pd.DataFrame, **kwargs) -> pd.DataFrame:
        return data * self.factor
@@ -45,27 +54,33 @@ def test_from_checkpoint():
    assert DummyPredictor.from_checkpoint(checkpoint).factor == 2.0


@mock.patch(
    "ray.train.predictor.convert_batch_type_to_pandas",
    return_value=mock.DEFAULT,
)
@mock.patch(
    "ray.train.predictor.convert_pandas_to_batch_type",
    return_value=mock.DEFAULT,
)
def test_predict(convert_from_pandas_mock, convert_to_pandas_mock):
    checkpoint = Checkpoint.from_dict({"factor": 2.0})
    predictor = DummyPredictor.from_checkpoint(checkpoint)
@mock.patch("ray.train.predictor.convert_pandas_to_batch_type")
@mock.patch("ray.train.predictor.convert_batch_type_to_pandas")
def test_predict(convert_to_pandas_mock, convert_from_pandas_mock):

    input = pd.DataFrame({"x": [1, 2, 3]})
    expected_output = input * 4.0

    convert_to_pandas_mock.return_value = input
    convert_from_pandas_mock.return_value = expected_output

    checkpoint = Checkpoint.from_dict(
        {"factor": 2.0, PREPROCESSOR_KEY: DummyPreprocessor()}
    )
    predictor = DummyPredictor.from_checkpoint(checkpoint)

    actual_output = predictor.predict(input)
    assert actual_output.equals(expected_output)

    # Ensure the proper conversion functions are called.
    convert_to_pandas_mock.assert_called_once()
    convert_to_pandas_mock.assert_called_once_with(input)
    convert_from_pandas_mock.assert_called_once()

    pd.testing.assert_frame_equal(
        convert_from_pandas_mock.call_args[0][0], expected_output
    )
    assert convert_from_pandas_mock.call_args[1]["type"] == DataType.PANDAS


def test_from_udf():
    def check_truth(df, all_true=False):
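One subtlety the rewritten test relies on: stacked mock.patch decorators are applied bottom-up, so the decorator nearest the function supplies the first mock argument. That is why test_predict(convert_to_pandas_mock, convert_from_pandas_mock) pairs correctly with the swapped decorator order above. A standalone illustration (the patched targets are arbitrary stdlib functions):

    from unittest import mock

    @mock.patch("os.path.exists")  # outermost: bound to the second argument
    @mock.patch("os.path.isdir")   # innermost: bound to the first argument
    def check(isdir_mock, exists_mock):
        assert isdir_mock is not exists_mock

    check()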
@@ -99,6 +114,37 @@ def test_kwargs(predict_pandas_mock):
    assert predict_pandas_mock.call_args[1]["extra_arg"] == 1


def test_get_and_set_preprocessor():
    """Test that the preprocessor can be set and retrieved."""

    preprocessor = DummyPreprocessor(1)
    predictor = DummyPredictor.from_checkpoint(
        Checkpoint.from_dict({"factor": 2.0, PREPROCESSOR_KEY: preprocessor}),
    )
    assert predictor.get_preprocessor() == preprocessor

    test_dataset = pd.DataFrame(range(4))
    output_df = predictor.predict(test_dataset)
    assert output_df.to_numpy().squeeze().tolist() == [
        0.0,
        2.0,
        4.0,
        6.0,
    ]

    preprocessor2 = DummyPreprocessor(2)
    predictor.set_preprocessor(preprocessor2)
    assert predictor.get_preprocessor() == preprocessor2

    output_df = predictor.predict(test_dataset)
    assert output_df.to_numpy().squeeze().tolist() == [
        0.0,
        4.0,
        8.0,
        12.0,
    ]


if __name__ == "__main__":
    import sys
@@ -57,7 +57,10 @@ def test_init():
        checkpoint_predictor.estimator.feature_importances_,
        predictor.estimator.feature_importances_,
    )
    assert checkpoint_predictor.preprocessor.attr == predictor.preprocessor.attr
    assert (
        checkpoint_predictor.get_preprocessor().attr
        == predictor.get_preprocessor().attr
    )


@pytest.mark.parametrize("batch_type", [np.ndarray, pd.DataFrame, pa.Table, dict])

@@ -70,7 +73,7 @@ def test_predict(batch_type):
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_set_cpus(ray_start_4_cpus):

@@ -81,7 +84,7 @@ def test_predict_set_cpus(ray_start_4_cpus):
    predictions = predictor.predict(data_batch, num_estimator_cpus=2)

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")
    assert predictor.estimator.n_jobs == 2


@@ -93,7 +96,7 @@ def test_predict_feature_columns():
    predictions = predictor.predict(data_batch, feature_columns=[0, 1])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_feature_columns_pandas():

@@ -110,7 +113,7 @@ def test_predict_feature_columns_pandas():
    predictions = predictor.predict(data_batch, feature_columns=["A", "B"])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_no_preprocessor():
@@ -63,7 +63,7 @@ def test_init():

    assert checkpoint_predictor.model_definition == predictor.model_definition
    assert checkpoint_predictor.model_weights == predictor.model_weights
    assert checkpoint_predictor.preprocessor == predictor.preprocessor
    assert checkpoint_predictor.get_preprocessor() == predictor.get_preprocessor()


@pytest.mark.parametrize("use_gpu", [False, True])
@@ -54,7 +54,7 @@ def test_init(model, preprocessor):
    )

    assert checkpoint_predictor.model == predictor.model
    assert checkpoint_predictor.preprocessor == predictor.preprocessor
    assert checkpoint_predictor.get_preprocessor() == predictor.get_preprocessor()


@pytest.mark.parametrize("use_gpu", [False, True])
@@ -51,7 +51,10 @@ def test_init():
    checkpoint_predictor = XGBoostPredictor.from_checkpoint(checkpoint)

    assert get_num_trees(checkpoint_predictor.model) == get_num_trees(predictor.model)
    assert checkpoint_predictor.preprocessor.attr == predictor.preprocessor.attr
    assert (
        checkpoint_predictor.get_preprocessor().attr
        == predictor.get_preprocessor().attr
    )


@pytest.mark.parametrize("batch_type", [np.ndarray, pd.DataFrame, pa.Table, dict])

@@ -64,7 +67,7 @@ def test_predict(batch_type):
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_feature_columns():

@@ -75,7 +78,7 @@ def test_predict_feature_columns():
    predictions = predictor.predict(data_batch, feature_columns=[0, 1])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_feature_columns_pandas():

@@ -92,7 +95,7 @@ def test_predict_feature_columns_pandas():
    predictions = predictor.predict(data_batch, feature_columns=["A", "B"])

    assert len(predictions) == 3
    assert hasattr(predictor.preprocessor, "_batch_transformed")
    assert hasattr(predictor.get_preprocessor(), "_batch_transformed")


def test_predict_no_preprocessor_no_training():
@@ -37,7 +37,6 @@ class TorchPredictor(DLPredictor):
    ):
        self.model = model
        self.model.eval()
        self.preprocessor = preprocessor

        # TODO (jiaodong): #26249 Use multiple GPU devices with sharded input
        self.use_gpu = use_gpu

@@ -59,6 +58,8 @@ class TorchPredictor(DLPredictor):
                "enable GPU prediction."
            )

        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(
        cls,

@@ -80,7 +81,8 @@ class TorchPredictor(DLPredictor):
            use_gpu: If set, the model will be moved to GPU on instantiation and
                prediction happens on GPU.
        """
        model, preprocessor = load_checkpoint(checkpoint, model)
        model, _ = load_checkpoint(checkpoint, model)
        preprocessor = checkpoint.get_preprocessor()
        return cls(model=model, preprocessor=preprocessor, use_gpu=use_gpu)

    def _array_to_tensor(
@@ -27,7 +27,7 @@ class XGBoostPredictor(Predictor):
        self, model: xgboost.Booster, preprocessor: Optional["Preprocessor"] = None
    ):
        self.model = model
        self.preprocessor = preprocessor
        super().__init__(preprocessor)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "XGBoostPredictor":

@@ -41,7 +41,8 @@ class XGBoostPredictor(Predictor):
            ``XGBoostTrainer`` run.

        """
        bst, preprocessor = load_checkpoint(checkpoint)
        bst, _ = load_checkpoint(checkpoint)
        preprocessor = checkpoint.get_preprocessor()
        return cls(model=bst, preprocessor=preprocessor)

    def _predict_pandas(