mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[ml] Don't require preprocessor in TorchPredictor (#23163)
This commit is contained in:
parent
6a1e336b24
commit
154edce2a4
2 changed files with 21 additions and 2 deletions
|
@ -44,7 +44,13 @@ class TorchPredictor(Predictor):
|
|||
``model``.
|
||||
"""
|
||||
checkpoint_dict = checkpoint.to_dict()
|
||||
preprocessor = checkpoint_dict[PREPROCESSOR_KEY]
|
||||
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
|
||||
if MODEL_KEY not in checkpoint_dict:
|
||||
raise RuntimeError(
|
||||
f"No item with key: {MODEL_KEY} is found in the "
|
||||
f"Checkpoint. Make sure this key exists when saving the "
|
||||
f"checkpoint in ``TorchTrainer``."
|
||||
)
|
||||
model = load_torch_model(
|
||||
saved_model=checkpoint_dict[MODEL_KEY], model_definition=model
|
||||
)
|
||||
|
@ -113,7 +119,9 @@ class TorchPredictor(Predictor):
|
|||
Returns:
|
||||
DataBatchType: Prediction result.
|
||||
"""
|
||||
data = self.preprocessor.transform_batch(data)
|
||||
if self.preprocessor:
|
||||
data = self.preprocessor.transform_batch(data)
|
||||
|
||||
if isinstance(data, np.ndarray):
|
||||
# If numpy array, then convert to pandas dataframe.
|
||||
data = pd.DataFrame(data)
|
||||
|
|
|
@ -53,3 +53,14 @@ def test_predict_feature_columns():
|
|||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().tolist() == [4, 8, 12]
|
||||
|
||||
|
||||
def test_predict_no_preprocessor():
|
||||
checkpoint = Checkpoint.from_dict({MODEL_KEY: model})
|
||||
predictor = TorchPredictor.from_checkpoint(checkpoint)
|
||||
|
||||
data_batch = np.array([[1], [2], [3]])
|
||||
predictions = predictor.predict(data_batch)
|
||||
|
||||
assert len(predictions) == 3
|
||||
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
|
||||
|
|
Loading…
Add table
Reference in a new issue