[ml] Don't require preprocessor in TorchPredictor (#23163)

This commit is contained in:
Amog Kamsetty 2022-03-14 16:33:22 -07:00 committed by GitHub
parent 6a1e336b24
commit 154edce2a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 2 deletions

View file

@ -44,7 +44,13 @@ class TorchPredictor(Predictor):
``model``.
"""
checkpoint_dict = checkpoint.to_dict()
preprocessor = checkpoint_dict[PREPROCESSOR_KEY]
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
if MODEL_KEY not in checkpoint_dict:
raise RuntimeError(
f"No item with key: {MODEL_KEY} is found in the "
f"Checkpoint. Make sure this key exists when saving the "
f"checkpoint in ``TorchTrainer``."
)
model = load_torch_model(
saved_model=checkpoint_dict[MODEL_KEY], model_definition=model
)
@ -113,7 +119,9 @@ class TorchPredictor(Predictor):
Returns:
DataBatchType: Prediction result.
"""
data = self.preprocessor.transform_batch(data)
if self.preprocessor:
data = self.preprocessor.transform_batch(data)
if isinstance(data, np.ndarray):
# If numpy array, then convert to pandas dataframe.
data = pd.DataFrame(data)

View file

@ -53,3 +53,14 @@ def test_predict_feature_columns():
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [4, 8, 12]
def test_predict_no_preprocessor():
checkpoint = Checkpoint.from_dict({MODEL_KEY: model})
predictor = TorchPredictor.from_checkpoint(checkpoint)
data_batch = np.array([[1], [2], [3]])
predictions = predictor.predict(data_batch)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]