diff --git a/python/ray/air/predictor.py b/python/ray/air/predictor.py index a4d76f97c..2d0d9854e 100644 --- a/python/ray/air/predictor.py +++ b/python/ray/air/predictor.py @@ -50,8 +50,9 @@ class Predictor(abc.ABC): To implement a new Predictor for your particular framework, you should subclass the base ``Predictor`` and implement the following two methods: - 1. ``_predict_pandas``: Given a pandas.DataFrame input, return a - pandas.DataFrame containing predictions. + 1. ``_predict_pandas`` or ``_predict_arrow``: Given a + pandas.DataFrame/pyarrow.Table input, return a + pandas.DataFrame/pyarrow.Table containing predictions. 2. ``from_checkpoint``: Logic for creating a Predictor from an :ref:`AIR Checkpoint `. """ @@ -111,6 +112,24 @@ class Predictor(abc.ABC): """ raise NotImplementedError + @DeveloperAPI + def _predict_arrow(self, data: "pyarrow.Table", **kwargs) -> "pyarrow.Table": + """Perform inference on an Arrow Table. + + Predictors can implement this method instead of ``_predict_pandas`` + for better performance when the input batch type is a Numpy array, dict of + numpy arrays, or an Arrow Table as conversion from these types are zero copy. + + Args: + data: An Arrow Table to perform predictions on. + kwargs: Arguments specific to the predictor implementation. + + Returns: + An Arrow Table containing the prediction result. + """ + + raise NotImplementedError + def __reduce__(self): raise PredictorNotSerializableException( "Predictor instances are not serializable. Instead, you may want "