[AIR] Update TorchPredictor to new Predictor API (#25536)

Amog Kamsetty 2022-06-22 09:49:07 -07:00 committed by GitHub
parent 6552e096e6
commit d6e8b90236
2 changed files with 142 additions and 108 deletions


@@ -14,14 +14,24 @@ class DummyPreprocessor(Preprocessor):
return df * 2
class DummyModel(torch.nn.Linear):
class DummyModelSingleTensor(torch.nn.Module):
def forward(self, input):
return input * 2
class DummyModelMultiInput(torch.nn.Module):
def forward(self, input_dict):
return sum(input_dict.values())
class DummyModelMultiOutput(torch.nn.Module):
def forward(self, input_tensor):
return [input_tensor, input_tensor]
@pytest.fixture
def model():
return DummyModel(1, 1)
return DummyModelSingleTensor()
@pytest.fixture
@@ -53,11 +63,11 @@ def test_predict_model_not_training(model):
def test_predict_array(model):
predictor = TorchPredictor(model=model)
data_batch = np.array([[1], [2], [3]])
data_batch = np.asarray([[1], [2], [3]])
predictions = predictor.predict(data_batch)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
assert predictions.flatten().tolist() == [2, 4, 6]
def test_predict_array_with_preprocessor(model, preprocessor):
@@ -67,17 +77,31 @@ def test_predict_array_with_preprocessor(model, preprocessor):
predictions = predictor.predict(data_batch)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [4, 8, 12]
assert predictions.flatten().tolist() == [4, 8, 12]
def test_predict_dataframe():
predictor = TorchPredictor(model=torch.nn.Linear(2, 1, bias=False))
predictor = TorchPredictor(model=DummyModelMultiInput())
data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [0.0, 0.0, 0.0]})
data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [1.0, 2.0, 3.0]})
predictions = predictor.predict(data_batch, dtype=torch.float)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0]
assert predictions.to_numpy().flatten().tolist() == [1.0, 2.0, 3.0]
def test_predict_multi_output():
predictor = TorchPredictor(model=DummyModelMultiOutput())
data_batch = np.array([[1], [2], [3]])
predictions = predictor.predict(data_batch)
# Model outputs two tensors
assert len(predictions) == 2
for k, v in predictions.items():
# Each tensor is of size 3
assert len(v) == 3
assert v.flatten().tolist() == [1, 2, 3]
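# Illustrative note, not part of this commit: per the list/tuple handling
# added in _predict_pandas (second file below), these two outputs surface as
# DataFrame columns named "output_1" and "output_2".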
@pytest.mark.parametrize(
@@ -89,26 +113,13 @@ def test_predict_dataframe():
(torch.int64, np.int64),
),
)
def test_predict_array_with_different_dtypes(input_dtype, expected_output_dtype):
predictor = TorchPredictor(model=torch.nn.Identity())
def test_predict_array_with_different_dtypes(model, input_dtype, expected_output_dtype):
predictor = TorchPredictor(model=model)
data_batch = np.array([[1], [2], [3]])
predictions = predictor.predict(data_batch, dtype=input_dtype)
assert all(
prediction.dtype == expected_output_dtype
for prediction in predictions["predictions"]
)
def test_predict_dataframe_with_feature_columns():
predictor = TorchPredictor(model=torch.nn.Identity())
data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [1.0, 1.0, 1.0]})
predictions = predictor.predict(data_batch, feature_columns=["X0"])
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0]
assert predictions.dtype == expected_output_dtype
def test_predict_array_no_training(model):
@@ -119,7 +130,36 @@ def test_predict_array_no_training(model):
predictions = predictor.predict(data_batch)
assert len(predictions) == 3
assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
assert predictions.flatten().tolist() == [2, 4, 6]
def test_array_real_model():
model = torch.nn.Linear(2, 1)
predictor = TorchPredictor(model=model)
data = np.array([[1, 2], [3, 4]])
predictions = predictor.predict(data, dtype=torch.float)
assert len(predictions) == 2
def test_multi_modal_real_model():
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(1, 1)
self.linear2 = torch.nn.Linear(1, 1)
def forward(self, input_dict: dict):
out1 = self.linear1(input_dict["A"])
out2 = self.linear2(input_dict["B"])
return out1 + out2
predictor = TorchPredictor(model=CustomModule())
data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
predictions = predictor.predict(data, dtype=torch.float)
assert len(predictions) == 2
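# Illustrative sketch, not part of this commit: the new `dtype` argument also
# accepts a per-column mapping (see the updated `predict` signature below).
def test_predict_dataframe_dtype_mapping_sketch():
    predictor = TorchPredictor(model=DummyModelMultiInput())
    data_batch = pd.DataFrame({"X0": [1.0, 2.0, 3.0], "X1": [1.0, 1.0, 1.0]})
    predictions = predictor.predict(
        data_batch, dtype={"X0": torch.float, "X1": torch.float}
    )

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [2.0, 3.0, 4.0]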
if __name__ == "__main__":


@@ -1,12 +1,13 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Optional, Union
import numpy as np
import pandas as pd
import torch
from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor
from ray.air.checkpoint import Checkpoint
from ray.train.predictor import DataBatchType, Predictor
from ray.air.checkpoint import Checkpoint
from ray.air.util.data_batch_conversion import convert_pandas_to_batch_type, DataType
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.train.torch.utils import load_checkpoint
if TYPE_CHECKING:
@@ -47,81 +48,79 @@ class TorchPredictor(Predictor):
model, preprocessor = load_checkpoint(checkpoint, model)
return TorchPredictor(model=model, preprocessor=preprocessor)
# parity with Dataset.to_torch
def _convert_to_tensor(
def _predict_pandas(
self,
data: pd.DataFrame,
feature_columns: Optional[
Union[List[str], List[List[str]], List[int], List[List[int]]]
] = None,
dtypes: Optional[torch.dtype] = None,
unsqueeze: bool = True,
) -> torch.Tensor:
"""Handle conversion of data to tensor.
dtype: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
) -> pd.DataFrame:
def tensorize(numpy_array, dtype):
torch_tensor = torch.from_numpy(numpy_array).to(dtype)
# Off-the-shelf torch Modules expect the input to have at least 2
# dimensions (batch_size, feature_size). If the tensor for the column
# is flattened, then we unsqueeze it to add an extra dimension.
if len(torch_tensor.size()) == 1:
torch_tensor = torch_tensor.unsqueeze(dim=1)
return torch_tensor
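# Illustrative note, not part of this change: a flat length-3 column such as
# np.array([1, 2, 3]) becomes a (3, 1) tensor after the unsqueeze above,
# matching the (batch_size, feature_size) shape expected by torch Modules.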
tensors = convert_pandas_to_batch_type(data, DataType.NUMPY)
# Single numpy array.
if isinstance(tensors, np.ndarray):
column_name = data.columns[0]
if isinstance(dtype, dict):
dtype = dtype[column_name]
model_input = tensorize(tensors, dtype)
Same arguments as in ``convert_pandas_to_torch_tensor``."""
# TODO(amog): Add `_convert_numpy_to_torch_tensor` to use based on input type.
# Reduce conversion cost if the input is in Numpy.
if isinstance(feature_columns, dict):
features_tensor = {
key: convert_pandas_to_torch_tensor(
data,
feature_columns[key],
dtypes[key] if isinstance(dtypes, dict) else dtypes,
unsqueeze=unsqueeze,
)
for key in feature_columns
}
else:
features_tensor = convert_pandas_to_torch_tensor(
data,
columns=feature_columns,
column_dtypes=dtypes,
unsqueeze=unsqueeze,
)
return features_tensor
model_input = {
k: tensorize(v, dtype=dtype[k] if isinstance(dtype, dict) else dtype)
for k, v in tensors.items()
}
def _predict(self, tensor: torch.Tensor) -> pd.DataFrame:
"""Handle actual prediction."""
prediction = self.model(tensor).cpu().detach().numpy()
# If the model outputs a Numpy array (for example logits), it cannot be
# used directly as values in a Pandas DataFrame.
# We have to convert the outermost dimension to a Python list (but the values
# in the list can still be Numpy arrays).
return pd.DataFrame({"predictions": list(prediction)}, columns=["predictions"])
with torch.no_grad():
self.model.eval()
output = self.model(model_input)
def untensorize(torch_tensor):
numpy_array = torch_tensor.cpu().detach().numpy()
return TensorArray(numpy_array)
# Handle multi-output models, for example a model that outputs 2 images.
if isinstance(output, dict):
return pd.DataFrame({k: untensorize(v) for k, v in output.items()})
elif isinstance(output, (list, tuple)):
tensor_name = "output_"
output_dict = {}
for i in range(len(output)):
output_dict[tensor_name + str(i + 1)] = untensorize(output[i])
return pd.DataFrame(output_dict)
else:
return pd.DataFrame(
{"predictions": untensorize(output)}, columns=["predictions"]
)
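# Illustrative summary, not part of this change, of the output handling above:
#   dict output   {"a": t1, "b": t2} -> DataFrame columns "a", "b"
#   list/tuple    [t1, t2]           -> DataFrame columns "output_1", "output_2"
#   single tensor t                  -> DataFrame column "predictions"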
def predict(
self,
data: DataBatchType,
feature_columns: Optional[
Union[List[str], List[List[str]], List[int], List[List[int]]]
] = None,
dtype: Optional[torch.dtype] = None,
unsqueeze: bool = True,
dtype: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
) -> DataBatchType:
"""Run inference on data batch.
The data is converted into a torch Tensor before being inputted to
the model.
If the provided data is a single array or a dataframe/table with a single
column, it will be converted into a single PyTorch tensor before being
inputted to the model.
If the provided data is a multi-column table or a dict of numpy arrays,
it will be converted into a dict of tensors before being inputted to the
model. This is useful for multi-modal inputs (for example, a model that
accepts both image and text).
Args:
data: A batch of input data. Either a pandas DataFrame or numpy
array.
feature_columns: The names or indices of the columns in the
data to use as features to predict on. If this arg is a
list of lists or a dict of string-list pairs, then the
data batch will be converted into
multiple tensors, which are then concatenated before being fed
into the model. This is useful for multi-input models. If
None, then use all columns in ``data``.
dtype: The dtypes to use for the tensors. This should match the
format of ``feature_columns``, or be a single dtype, in which
case it will be applied to all tensors.
If None, then automatically infer the dtype.
unsqueeze: If set to True, the features tensors will be unsqueezed
(reshaped to (N, 1)) before being concatenated into the final features
tensor. Otherwise, they will be left as is, that is (N, ).
Defaults to True.
data: A batch of input data of ``DataBatchType``.
dtype: The dtypes to use for the tensors. Either a single dtype for all
tensors or a mapping from column name to dtype.
Examples:
@@ -135,7 +134,7 @@ class TorchPredictor(Predictor):
predictor = TorchPredictor(model=model)
data = np.array([[1, 2], [3, 4]])
predictions = predictor.predict(data)
predictions = predictor.predict(data, dtype=torch.float)
.. code-block:: python
@@ -143,30 +142,25 @@ class TorchPredictor(Predictor):
import torch
from ray.train.torch import TorchPredictor
model = torch.nn.Linear(1, 1)
predictor = TorchPredictor(model=model)
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(1, 1)
self.linear2 = torch.nn.Linear(1, 1)
def forward(self, input_dict: dict):
out1 = self.linear1(input_dict["A"])
out2 = self.linear2(input_dict["B"])
return out1 + out2
predictor = TorchPredictor(model=CustomModule())
# Pandas dataframe.
data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
predictions = predictor.predict(data)
# Only use first column as the feature
predictions = predictor.predict(data, feature_columns=["A"])
Returns:
DataBatchType: Prediction result.
"""
self.model.eval()
if self.preprocessor:
data = self.preprocessor.transform_batch(data)
if isinstance(data, np.ndarray):
tensor = torch.tensor(data, dtype=dtype)
else:
tensor = self._convert_to_tensor(
data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze
)
return self._predict(tensor)
return super(TorchPredictor, self).predict(data=data, dtype=dtype)
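
For reference, a minimal end-to-end sketch of the updated call path (an illustration, not part of this commit, assuming the base ``Predictor.predict`` converts the incoming batch to pandas, delegates to ``_predict_pandas``, and converts the result back to the type of the input batch, so a numpy batch yields a numpy result as in the updated tests above):

import numpy as np
import torch

from ray.train.torch import TorchPredictor

# Mirrors test_array_real_model above; dtype selects the tensor dtype used
# when the numpy batch is tensorized inside _predict_pandas.
predictor = TorchPredictor(model=torch.nn.Linear(2, 1))
batch = np.array([[1.0, 2.0], [3.0, 4.0]])
predictions = predictor.predict(batch, dtype=torch.float)
assert len(predictions) == 2  # one prediction per input row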