[AIR] Add _predict_arrow interface for Predictor (#25579)

* add interface

* update docstring
This commit is contained in:
Amog Kamsetty 2022-06-08 10:27:29 -07:00 committed by GitHub
parent 0bbc3379bd
commit 1be32e5977
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -50,8 +50,9 @@ class Predictor(abc.ABC):
To implement a new Predictor for your particular framework, you should subclass
the base ``Predictor`` and implement the following two methods:
1. ``_predict_pandas``: Given a pandas.DataFrame input, return a
pandas.DataFrame containing predictions.
1. ``_predict_pandas`` or ``_predict_arrow``: Given a
pandas.DataFrame/pyarrow.Table input, return a
pandas.DataFrame/pyarrow.Table containing predictions.
2. ``from_checkpoint``: Logic for creating a Predictor from an
:ref:`AIR Checkpoint <air-checkpoint-ref>`.
"""
@ -111,6 +112,24 @@ class Predictor(abc.ABC):
"""
raise NotImplementedError
@DeveloperAPI
def _predict_arrow(self, data: "pyarrow.Table", **kwargs) -> "pyarrow.Table":
"""Perform inference on an Arrow Table.
Predictors can implement this method instead of ``_predict_pandas``
for better performance when the input batch type is a Numpy array, dict of
numpy arrays, or an Arrow Table as conversion from these types are zero copy.
Args:
data: An Arrow Table to perform predictions on.
kwargs: Arguments specific to the predictor implementation.
Returns:
An Arrow Table containing the prediction result.
"""
raise NotImplementedError
def __reduce__(self):
raise PredictorNotSerializableException(
"Predictor instances are not serializable. Instead, you may want "