[AIR] Preprocessors feature guide (#25302)

matthewdeng 2022-06-03 11:43:51 -07:00 committed by GitHub
parent 313e8730a2
commit 2e05b62236
6 changed files with 336 additions and 0 deletions

@ -180,6 +180,7 @@ parts:
- file: ray-air/key-concepts
- file: ray-air/deployment
- file: ray-air/check-ingest
- file: ray-air/preprocessors
- file: ray-air/examples/index
sections:
- file: ray-air/examples/analyze_tuning_results

@ -0,0 +1,136 @@
# flake8: noqa
# __preprocessor_setup_start__
import pandas as pd
import ray
from ray.ml.preprocessors import MinMaxScaler
# Generate two simple datasets.
dataset = ray.data.range_table(8)
dataset1, dataset2 = dataset.split(2)
print(dataset1.take())
# [{'value': 0}, {'value': 1}, {'value': 2}, {'value': 3}]
print(dataset2.take())
# [{'value': 4}, {'value': 5}, {'value': 6}, {'value': 7}]
# __preprocessor_setup_end__
# __preprocessor_fit_transform_start__
# Fit the preprocessor on dataset1, and transform both dataset1 and dataset2.
preprocessor = MinMaxScaler(["value"])
dataset1_transformed = preprocessor.fit_transform(dataset1)
print(dataset1_transformed.take())
# [{'value': 0.0}, {'value': 0.3333333333333333}, {'value': 0.6666666666666666}, {'value': 1.0}]
dataset2_transformed = preprocessor.transform(dataset2)
print(dataset2_transformed.take())
# [{'value': 1.3333333333333333}, {'value': 1.6666666666666667}, {'value': 2.0}, {'value': 2.3333333333333335}]
# __preprocessor_fit_transform_end__
# __preprocessor_transform_batch_start__
batch = pd.DataFrame({"value": list(range(8, 12))})
batch_transformed = preprocessor.transform_batch(batch)
print(batch_transformed)
# value
# 0 2.666667
# 1 3.000000
# 2 3.333333
# 3 3.666667
# __preprocessor_transform_batch_end__
# __trainer_start__
import ray
from ray.ml.train.integrations.xgboost import XGBoostTrainer
from ray.ml.preprocessors import MinMaxScaler
train_dataset = ray.data.from_items([{"x": x, "y": 2 * x} for x in range(0, 32, 3)])
valid_dataset = ray.data.from_items([{"x": x, "y": 2 * x} for x in range(1, 32, 3)])
preprocessor = MinMaxScaler(["x"])
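# The trainer will fit this preprocessor on the "train" dataset and then
# transform both datasets with the fitted state before training.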
trainer = XGBoostTrainer(
label_column="y",
params={"objective": "reg:squarederror"},
scaling_config={"num_workers": 2},
datasets={"train": train_dataset, "valid": valid_dataset},
preprocessor=preprocessor,
)
result = trainer.fit()
# __trainer_end__
# __checkpoint_start__
from ray.ml.utils.checkpointing import load_preprocessor_from_dir
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_path:
preprocessor = load_preprocessor_from_dir(checkpoint_path)
print(preprocessor)
# MinMaxScaler(columns=['x'], stats={'min(x)': 0, 'max(x)': 30})
# __checkpoint_end__
# __predictor_start__
from ray.ml.batch_predictor import BatchPredictor
from ray.ml.predictors.integrations.xgboost import XGBoostPredictor
test_dataset = ray.data.from_items([{"x": x} for x in range(2, 32, 3)])
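# The fitted preprocessor stored in the checkpoint is applied to each batch
# before inference.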
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)
predicted_labels = batch_predictor.predict(test_dataset)
print(predicted_labels.to_pandas())
# predictions
# 0 0.098437
# 1 5.604667
# 2 11.405312
# 3 15.684700
# 4 23.990948
# 5 29.900211
# 6 34.599442
# 7 40.696899
# 8 45.681076
# 9 50.290031
# __predictor_end__
# __chain_start__
import ray
from ray.ml.preprocessors import Chain, MinMaxScaler, SimpleImputer
# Generate one simple dataset.
dataset = ray.data.from_items(
[{"value": 0}, {"value": 1}, {"value": 2}, {"value": 3}, {"value": None}]
)
print(dataset.take())
# [{'value': 0}, {'value': 1}, {'value': 2}, {'value': 3}, {'value': None}]
preprocessor = Chain(SimpleImputer(["value"]), MinMaxScaler(["value"]))
dataset_transformed = preprocessor.fit_transform(dataset)
print(dataset_transformed.take())
# [{'value': 0.0}, {'value': 0.3333333333333333}, {'value': 0.6666666666666666}, {'value': 1.0}, {'value': 0.5}]
# __chain_end__
# __custom_stateless_start__
import ray
from ray.ml.preprocessors import BatchMapper
# Generate a simple dataset.
dataset = ray.data.range_table(4)
print(dataset.take())
# [{'value': 0}, {'value': 1}, {'value': 2}, {'value': 3}]
# Create a stateless preprocessor that multiplies values by 2.
preprocessor = BatchMapper(lambda df: df * 2)
dataset_transformed = preprocessor.transform(dataset)
print(dataset_transformed.take())
# [{'value': 0}, {'value': 2}, {'value': 4}, {'value': 6}]
# __custom_stateless_end__

New image: images/air-preprocessor.svg (17 KiB; diff suppressed)

@ -71,6 +71,7 @@ You can take a trained model and do batch inference using the BatchPredictor obj
:start-after: __air_batch_predictor_start__
:end-before: __air_batch_predictor_end__
.. _air-key-concepts-online-inference:
Online Inference
----------------

@ -81,6 +81,8 @@ Predictors
.. autoclass:: ray.ml.predictor.Predictor
:members:
.. autoclass:: ray.ml.predictor.DataBatchType
.. autoclass:: ray.ml.batch_predictor.BatchPredictor
:members:

@ -0,0 +1,192 @@
.. _air-preprocessors:
Preprocessing Data
==================
This page describes how to perform data preprocessing in Ray AIR.
Data preprocessing is a common technique for transforming raw data into features that will be input to a machine learning model.
In general, you may want to apply the same preprocessing logic to your offline training data and online inference data.
Ray AIR provides several common preprocessors out of the box as well as interfaces that enable you to define your own custom logic.
Overview
--------
Ray AIR exposes a ``Preprocessor`` class for preprocessing. The ``Preprocessor`` has four methods that make up its core interface.
#. ``fit()``: Compute state information about a :class:`Dataset <ray.data.Dataset>` (e.g., the mean or standard deviation of a column)
and save it to the ``Preprocessor``. This fitted state is then used by ``transform()``.
*This is typically called on the training dataset.*
#. ``transform()``: Apply a transformation to a ``Dataset``.
If the ``Preprocessor`` is stateful, then ``fit()`` must be called first.
*This is typically called on the training, validation, and test datasets.*
#. ``transform_batch()``: Apply a transformation to a single :class:`batch <ray.ml.predictor.DataBatchType>` of data.
*This is typically called on online or offline inference data.*
#. ``fit_transform()``: Syntactic sugar for calling both ``fit()`` and ``transform()`` on a ``Dataset``.
To show these in action, let's walk through a basic example. First we'll set up two simple Ray ``Dataset``\s.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __preprocessor_setup_start__
:end-before: __preprocessor_setup_end__
Next, ``fit`` the ``Preprocessor`` on one ``Dataset``, and ``transform`` both ``Dataset``\s with this fitted information.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __preprocessor_fit_transform_start__
:end-before: __preprocessor_fit_transform_end__
Finally, call ``transform_batch`` on a single batch of data.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __preprocessor_transform_batch_start__
:end-before: __preprocessor_transform_batch_end__
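One detail from the interface above is worth demonstrating: because ``MinMaxScaler`` is stateful, calling ``transform`` before ``fit`` raises an error. Here's a minimal sketch (the exact exception type is an implementation detail, so we catch broadly):

.. code-block:: python

    from ray.ml.preprocessors import MinMaxScaler

    unfitted = MinMaxScaler(["value"])
    try:
        # Fails: no state has been computed yet.
        unfitted.transform(dataset2)
    except Exception as e:
        print(type(e).__name__)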
Life of an AIR Preprocessor
---------------------------
Now that we've gone over the basics, let's dive into how ``Preprocessor``\s fit into an end-to-end application built with AIR.
The diagram below depicts the main steps in the life of a ``Preprocessor``:
#. Passed into a ``Trainer`` to ``fit`` and ``transform`` input ``Dataset``\s.
#. Saved as a ``Checkpoint``.
#. Reconstructed in a ``Predictor`` to ``transform_batch`` on batches of data.
.. figure:: images/air-preprocessor.svg
Throughout this section, we'll walk through this workflow in more detail, with code examples using XGBoost.
The same logic is applicable to other integrations as well.
Trainer
~~~~~~~
The journey of the ``Preprocessor`` starts with the :class:`Trainer <ray.ml.trainer.Trainer>`.
If the ``Trainer`` is instantiated with a ``Preprocessor``, then the following logic will be executed when ``Trainer.fit()`` is called:
#. If a ``"train"`` ``Dataset`` is passed in, then the ``Preprocessor`` will call ``fit()`` on it.
#. The ``Preprocessor`` will then call ``transform()`` on *all* ``Dataset``\s, including the ``"train"`` ``Dataset``.
#. The ``Trainer`` will then perform training on the preprocessed ``Dataset``\s.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __trainer_start__
:end-before: __trainer_end__
.. note::
If you're passing a ``Preprocessor`` that is already fitted, it will be refitted on the ``"train"`` ``Dataset``.
Adding the functionality to support passing in a fitted Preprocessor is being tracked
`here <https://github.com/ray-project/ray/issues/25299>`__.
.. TODO: Remove the note above once the issue is resolved.
Tune
~~~~
If you're using Ray Tune for hyperparameter optimization, be aware that each ``Trial`` will instantiate its own copy of
the ``Preprocessor``, and the fitting and transformation logic will occur once per ``Trial``.
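For illustration, here is a minimal sketch of tuning the trainer from the next section's example. This assumes the in-development ``Tuner`` API; the exact import path and ``param_space`` handling may differ.

.. code-block:: python

    from ray import tune
    from ray.tune.tuner import Tuner

    # Each trial samples its own max_depth and fits its own copy of the
    # preprocessor on the "train" dataset.
    tuner = Tuner(
        trainer,  # an XGBoostTrainer constructed with a preprocessor
        param_space={"params": {"max_depth": tune.randint(2, 8)}},
    )
    results = tuner.fit()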
Checkpoint
~~~~~~~~~~
``Trainer.fit()`` returns a ``Result`` object which contains a ``Checkpoint``.
If a ``Preprocessor`` was passed into the ``Trainer``, then it will be saved in the ``Checkpoint`` along with any fitted state.
As a sanity check, let's confirm the ``Preprocessor`` is available in the ``Checkpoint``. In practice you should not need to do this.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __checkpoint_start__
:end-before: __checkpoint_end__
Predictor
~~~~~~~~~
A ``Predictor`` can be constructed from a saved ``Checkpoint``. If the ``Checkpoint`` contains a ``Preprocessor``,
then the ``Preprocessor`` will be used to call ``transform_batch`` on input batches prior to performing inference.
In the following example, we show the Batch Predictor flow. The same logic applies to the :ref:`Online Inference flow <air-key-concepts-online-inference>`.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __predictor_start__
:end-before: __predictor_end__
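For a taste of the online flow, a ``Predictor`` can also be constructed directly from the ``Checkpoint`` and called on a single batch. This is a minimal sketch using the ``from_checkpoint`` and ``predict`` interface:

.. code-block:: python

    import pandas as pd
    from ray.ml.predictors.integrations.xgboost import XGBoostPredictor

    predictor = XGBoostPredictor.from_checkpoint(checkpoint)
    # The stored preprocessor's transform_batch runs on this batch before
    # the model is invoked.
    print(predictor.predict(pd.DataFrame({"x": [5, 10, 15]})))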
Types of Preprocessors
----------------------
Basic Preprocessors
~~~~~~~~~~~~~~~~~~~
Ray AIR provides a handful of ``Preprocessor``\s that you can use out of the box, and more will be added over time.
`Contributions <https://docs.ray.io/en/master/getting-involved.html>`__ are welcome!
.. tabbed:: Common APIs
#. :class:`Preprocessor <ray.ml.preprocessor.Preprocessor>`
#. :class:`BatchMapper <ray.ml.preprocessors.BatchMapper>`
#. :class:`Chain <ray.ml.preprocessors.Chain>`
.. tabbed:: Tabular
#. :class:`Categorizer <ray.ml.preprocessors.Categorizer>`
#. :class:`FeatureHasher <ray.ml.preprocessors.FeatureHasher>`
#. :class:`LabelEncoder <ray.ml.preprocessors.LabelEncoder>`
#. :class:`MaxAbsScaler <ray.ml.preprocessors.MaxAbsScaler>`
#. :class:`MinMaxScaler <ray.ml.preprocessors.MinMaxScaler>`
#. :class:`Normalizer <ray.ml.preprocessors.Normalizer>`
#. :class:`OneHotEncoder <ray.ml.preprocessors.OneHotEncoder>`
#. :class:`OrdinalEncoder <ray.ml.preprocessors.OrdinalEncoder>`
#. :class:`PowerTransformer <ray.ml.preprocessors.PowerTransformer>`
#. :class:`RobustScaler <ray.ml.preprocessors.RobustScaler>`
#. :class:`SimpleImputer <ray.ml.preprocessors.SimpleImputer>`
#. :class:`StandardScaler <ray.ml.preprocessors.StandardScaler>`
.. tabbed:: Text
#. :class:`CountVectorizer <ray.ml.preprocessors.CountVectorizer>`
#. :class:`HashingVectorizer <ray.ml.preprocessors.HashingVectorizer>`
#. :class:`Tokenizer <ray.ml.preprocessors.Tokenizer>`
.. tabbed:: Image
Coming soon!
.. tabbed:: Utilities
#. :func:`train_test_split <ray.ml.train_test_split>`
Chaining Preprocessors
~~~~~~~~~~~~~~~~~~~~~~
More often than not, your preprocessing logic will contain multiple logical steps or apply different transformations to each column.
A simple ``Chain`` ``Preprocessor`` is provided, which applies individual ``Preprocessor`` operations sequentially.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __chain_start__
:end-before: __chain_end__
.. tip::
Keep in mind that the operations are sequential. For example, if you define a ``Preprocessor``
``Chain([preprocessorA, preprocessorB])``, then ``preprocessorB.transform()`` will be applied
to the result of ``preprocessorA.transform()``.
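To make the order-dependence concrete, here is a small sketch that chains a doubling ``BatchMapper`` with a ``MinMaxScaler`` in both orders:

.. code-block:: python

    import ray
    from ray.ml.preprocessors import BatchMapper, Chain, MinMaxScaler

    # Scale to [0, 1] first, then double: values land in [0, 2].
    chain_a = Chain(MinMaxScaler(["value"]), BatchMapper(lambda df: df * 2))
    print(chain_a.fit_transform(ray.data.range_table(4)).take())
    # [{'value': 0.0}, {'value': 0.6666666666666666}, {'value': 1.3333333333333333}, {'value': 2.0}]

    # Double first, then scale: values land in [0, 1].
    chain_b = Chain(BatchMapper(lambda df: df * 2), MinMaxScaler(["value"]))
    print(chain_b.fit_transform(ray.data.range_table(4)).take())
    # [{'value': 0.0}, {'value': 0.3333333333333333}, {'value': 0.6666666666666666}, {'value': 1.0}]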
Custom Preprocessors
~~~~~~~~~~~~~~~~~~~~
**Stateless Preprocessors:** Stateless preprocessors can be implemented with the ``BatchMapper``.
.. literalinclude:: doc_code/preprocessors.py
:language: python
:start-after: __custom_stateless_start__
:end-before: __custom_stateless_end__
**Stateful Preprocessors:** Coming soon!
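In the meantime, one workaround is to compute any state you need up front and capture it in a stateless ``BatchMapper``. This is a sketch of that pattern, not an official API:

.. code-block:: python

    import ray
    from ray.ml.preprocessors import BatchMapper

    dataset = ray.data.range_table(4)
    # Compute the "fitted" state yourself (here, the column max)...
    max_value = dataset.to_pandas()["value"].max()
    # ...then close over it in a stateless transform.
    preprocessor = BatchMapper(lambda df: df / max_value)
    print(preprocessor.transform(dataset).take())
    # [{'value': 0.0}, {'value': 0.3333333333333333}, {'value': 0.6666666666666666}, {'value': 1.0}]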