[AIR] Update TorchPredictor to new Predictor API (#25536)
This commit is contained in:
parent 6552e096e6
commit d6e8b90236
2 changed files with 142 additions and 108 deletions
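
This change moves TorchPredictor onto the shared Predictor API: predict() now delegates to the base class, the conversion and inference logic lives in a new _predict_pandas() override, the feature_columns and unsqueeze arguments are dropped, and dtype accepts either a single torch dtype or a per-column mapping. A minimal usage sketch of the updated API, based on the examples added in this diff and assuming a Ray build that includes this change:

import numpy as np
import torch

from ray.train.torch import TorchPredictor

# A plain nn.Module (or nn.Linear, as in the docstring example) works directly.
model = torch.nn.Linear(2, 1)
predictor = TorchPredictor(model=model)

# A single ndarray is converted to one tensor before being fed to the model.
# dtype may be a single torch dtype or a dict mapping column names to dtypes.
data = np.array([[1, 2], [3, 4]])
predictions = predictor.predict(data, dtype=torch.float)
assert len(predictions) == 2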
@@ -14,14 +14,24 @@ class DummyPreprocessor(Preprocessor):
        return df * 2


class DummyModel(torch.nn.Linear):
class DummyModelSingleTensor(torch.nn.Module):
    def forward(self, input):
        return input * 2


class DummyModelMultiInput(torch.nn.Module):
    def forward(self, input_dict):
        return sum(input_dict.values())


class DummyModelMultiOutput(torch.nn.Module):
    def forward(self, input_tensor):
        return [input_tensor, input_tensor]


@pytest.fixture
def model():
    return DummyModel(1, 1)
    return DummyModelSingleTensor()


@pytest.fixture
@@ -53,11 +63,11 @@ def test_predict_model_not_training(model):
def test_predict_array(model):
    predictor = TorchPredictor(model=model)

    data_batch = np.array([[1], [2], [3]])
    data_batch = np.asarray([[1], [2], [3]])
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
    assert predictions.flatten().tolist() == [2, 4, 6]


def test_predict_array_with_preprocessor(model, preprocessor):
@@ -67,17 +77,31 @@ def test_predict_array_with_preprocessor(model, preprocessor):
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [4, 8, 12]
    assert predictions.flatten().tolist() == [4, 8, 12]


def test_predict_dataframe():
    predictor = TorchPredictor(model=torch.nn.Linear(2, 1, bias=False))
    predictor = TorchPredictor(model=DummyModelMultiInput())

    data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [0.0, 0.0, 0.0]})
    data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [1.0, 2.0, 3.0]})
    predictions = predictor.predict(data_batch, dtype=torch.float)

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0]
    assert predictions.to_numpy().flatten().tolist() == [1.0, 2.0, 3.0]


def test_predict_multi_output():
    predictor = TorchPredictor(model=DummyModelMultiOutput())

    data_batch = np.array([[1], [2], [3]])
    predictions = predictor.predict(data_batch)

    # Model outputs two tensors
    assert len(predictions) == 2
    for k, v in predictions.items():
        # Each tensor is of size 3
        assert len(v) == 3
        assert v.flatten().tolist() == [1, 2, 3]


@pytest.mark.parametrize(
@@ -89,26 +113,13 @@ def test_predict_dataframe():
        (torch.int64, np.int64),
    ),
)
def test_predict_array_with_different_dtypes(input_dtype, expected_output_dtype):
    predictor = TorchPredictor(model=torch.nn.Identity())
def test_predict_array_with_different_dtypes(model, input_dtype, expected_output_dtype):
    predictor = TorchPredictor(model=model)

    data_batch = np.array([[1], [2], [3]])
    predictions = predictor.predict(data_batch, dtype=input_dtype)

    assert all(
        prediction.dtype == expected_output_dtype
        for prediction in predictions["predictions"]
    )


def test_predict_dataframe_with_feature_columns():
    predictor = TorchPredictor(model=torch.nn.Identity())

    data_batch = pd.DataFrame({"X0": [0.0, 0.0, 0.0], "X1": [1.0, 1.0, 1.0]})
    predictions = predictor.predict(data_batch, feature_columns=["X0"])

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [0.0, 0.0, 0.0]
    assert predictions.dtype == expected_output_dtype


def test_predict_array_no_training(model):
@@ -119,7 +130,36 @@ def test_predict_array_no_training(model):
    predictions = predictor.predict(data_batch)

    assert len(predictions) == 3
    assert predictions.to_numpy().flatten().tolist() == [2, 4, 6]
    assert predictions.flatten().tolist() == [2, 4, 6]


def test_array_real_model():
    model = torch.nn.Linear(2, 1)
    predictor = TorchPredictor(model=model)

    data = np.array([[1, 2], [3, 4]])
    predictions = predictor.predict(data, dtype=torch.float)
    assert len(predictions) == 2


def test_multi_modal_real_model():
    class CustomModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(1, 1)
            self.linear2 = torch.nn.Linear(1, 1)

        def forward(self, input_dict: dict):
            out1 = self.linear1(input_dict["A"])
            out2 = self.linear2(input_dict["B"])
            return out1 + out2

    predictor = TorchPredictor(model=CustomModule())

    data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])

    predictions = predictor.predict(data, dtype=torch.float)
    assert len(predictions) == 2


if __name__ == "__main__":
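
The new test_predict_multi_output above covers models that return several tensors. Per the implementation change in the second file below, a list or tuple output is unpacked into one column per tensor, named output_1, output_2, and so on. A self-contained sketch of that behavior, assuming a Ray build that includes this change:

import numpy as np
import torch

from ray.train.torch import TorchPredictor


class DummyModelMultiOutput(torch.nn.Module):
    def forward(self, input_tensor):
        # Returning a list (or tuple) produces one output column per tensor.
        return [input_tensor, input_tensor]


predictor = TorchPredictor(model=DummyModelMultiOutput())
predictions = predictor.predict(np.array([[1], [2], [3]]))

# Two outputs, each holding one value per input row.
assert len(predictions) == 2
for _, v in predictions.items():
    assert len(v) == 3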
@@ -1,12 +1,13 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Optional, Union

import numpy as np
import pandas as pd
import torch

from ray.air._internal.torch_utils import convert_pandas_to_torch_tensor
from ray.air.checkpoint import Checkpoint
from ray.train.predictor import DataBatchType, Predictor
from ray.air.checkpoint import Checkpoint
from ray.air.util.data_batch_conversion import convert_pandas_to_batch_type, DataType
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.train.torch.utils import load_checkpoint

if TYPE_CHECKING:
@@ -47,81 +48,79 @@ class TorchPredictor(Predictor):
        model, preprocessor = load_checkpoint(checkpoint, model)
        return TorchPredictor(model=model, preprocessor=preprocessor)

    # parity with Datset.to_torch
    def _convert_to_tensor(
    def _predict_pandas(
        self,
        data: pd.DataFrame,
        feature_columns: Optional[
            Union[List[str], List[List[str]], List[int], List[List[int]]]
        ] = None,
        dtypes: Optional[torch.dtype] = None,
        unsqueeze: bool = True,
    ) -> torch.Tensor:
        """Handle conversion of data to tensor.
        dtype: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    ) -> pd.DataFrame:
        def tensorize(numpy_array, dtype):
            torch_tensor = torch.from_numpy(numpy_array).to(dtype)

            # Off-the-shelf torch Modules expect the input size to have at least 2
            # dimensions (batch_size, feature_size). If the tensor for the column
            # is flattened, then we unqueeze it to add an extra dimension.
            if len(torch_tensor.size()) == 1:
                torch_tensor = torch_tensor.unsqueeze(dim=1)

            return torch_tensor

        tensors = convert_pandas_to_batch_type(data, DataType.NUMPY)

        # Single numpy array.
        if isinstance(tensors, np.ndarray):
            column_name = data.columns[0]
            if isinstance(dtype, dict):
                dtype = dtype[column_name]
            model_input = tensorize(tensors, dtype)

        Same arguments as in ``convert_pandas_to_torch_tensor``."""
        # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type.
        # Reduce conversion cost if input is in Numpy
        if isinstance(feature_columns, dict):
            features_tensor = {
                key: convert_pandas_to_torch_tensor(
                    data,
                    feature_columns[key],
                    dtypes[key] if isinstance(dtypes, dict) else dtypes,
                    unsqueeze=unsqueeze,
                )
                for key in feature_columns
            }
        else:
            features_tensor = convert_pandas_to_torch_tensor(
                data,
                columns=feature_columns,
                column_dtypes=dtypes,
                unsqueeze=unsqueeze,
            )
        return features_tensor
            model_input = {
                k: tensorize(v, dtype=dtype[k] if isinstance(dtype, dict) else dtype)
                for k, v in tensors.items()
            }

    def _predict(self, tensor: torch.Tensor) -> pd.DataFrame:
        """Handle actual prediction."""
        prediction = self.model(tensor).cpu().detach().numpy()
        # If model has outputs a Numpy array (for example outputting logits),
        # these cannot be used as values in a Pandas Dataframe.
        # We have to convert the outermost dimension to a python list (but the values
        # in the list can still be Numpy arrays).
        return pd.DataFrame({"predictions": list(prediction)}, columns=["predictions"])
        with torch.no_grad():
            self.model.eval()
            output = self.model(model_input)

        def untensorize(torch_tensor):
            numpy_array = torch_tensor.cpu().detach().numpy()
            return TensorArray(numpy_array)

        # Handle model multi-output. For example if model outputs 2 images.
        if isinstance(output, dict):
            return pd.DataFrame({k: untensorize(v) for k, v in output})
        elif isinstance(output, list) or isinstance(output, tuple):
            tensor_name = "output_"
            output_dict = {}
            for i in range(len(output)):
                output_dict[tensor_name + str(i + 1)] = untensorize(output[i])
            return pd.DataFrame(output_dict)
        else:
            return pd.DataFrame(
                {"predictions": untensorize(output)}, columns=["predictions"]
            )

    def predict(
        self,
        data: DataBatchType,
        feature_columns: Optional[
            Union[List[str], List[List[str]], List[int], List[List[int]]]
        ] = None,
        dtype: Optional[torch.dtype] = None,
        unsqueeze: bool = True,
        dtype: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    ) -> DataBatchType:
        """Run inference on data batch.

        The data is converted into a torch Tensor before being inputted to
        the model.
        If the provided data is a single array or a dataframe/table with a single
        column, it will be converted into a single PyTorch tensor before being
        inputted to the model.

        If the provided data is a multi-column table or a dict of numpy arrays,
        it will be converted into a dict of tensors before being inputted to the
        model. This is useful for multi-modal inputs (for example your model accepts
        both image and text).

        Args:
            data: A batch of input data. Either a pandas DataFrame or numpy
                array.
            feature_columns: The names or indices of the columns in the
                data to use as features to predict on. If this arg is a
                list of lists or a dict of string-list pairs, then the
                data batch will be converted into a
                multiple tensors which are then concatenated before feeding
                into the model. This is useful for multi-input models. If
                None, then use all columns in ``data``.
            dtype: The dtypes to use for the tensors. This should match the
                format of ``feature_columns``, or be a single dtype, in which
                case it will be applied to all tensors.
                If None, then automatically infer the dtype.
            unsqueeze: If set to True, the features tensors will be unsqueezed
                (reshaped to (N, 1)) before being concatenated into the final features
                tensor. Otherwise, they will be left as is, that is (N, ).
                Defaults to True.
            data: A batch of input data of ``DataBatchType``.
            dtype: The dtypes to use for the tensors. Either a single dtype for all
                tensors or a mapping from column name to dtype.

        Examples:

@@ -135,7 +134,7 @@ class TorchPredictor(Predictor):
            predictor = TorchPredictor(model=model)

            data = np.array([[1, 2], [3, 4]])
            predictions = predictor.predict(data)
            predictions = predictor.predict(data, dtype=torch.float)

        .. code-block:: python

@@ -143,30 +142,25 @@ class TorchPredictor(Predictor):
            import torch
            from ray.train.torch import TorchPredictor

            model = torch.nn.Linear(1, 1)
            predictor = TorchPredictor(model=model)
            class CustomModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear1 = torch.nn.Linear(1, 1)
                    self.linear2 = torch.nn.Linear(1, 1)

                def forward(self, input_dict: dict):
                    out1 = self.linear1(input_dict["A"])
                    out2 = self.linear2(input_dict["B"])
                    return out1 + out2

            predictor = TorchPredictor(model=CustomModule())

            # Pandas dataframe.
            data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])

            predictions = predictor.predict(data)

            # Only use first column as the feature
            predictions = predictor.predict(data, feature_columns=["A"])

        Returns:
            DataBatchType: Prediction result.
        """
        self.model.eval()

        if self.preprocessor:
            data = self.preprocessor.transform_batch(data)

        if isinstance(data, np.ndarray):
            tensor = torch.tensor(data, dtype=dtype)
        else:
            tensor = self._convert_to_tensor(
                data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze
            )

        return self._predict(tensor)
        return super(TorchPredictor, self).predict(data=data, dtype=dtype)
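
For multi-column input, the new _predict_pandas builds a dict of tensors keyed by column name and passes that dict to the model's forward(), as the docstring example above shows. A runnable sketch of the multi-input path, assuming a Ray build that includes this change:

import pandas as pd
import torch

from ray.train.torch import TorchPredictor


class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(1, 1)
        self.linear2 = torch.nn.Linear(1, 1)

    def forward(self, input_dict: dict):
        # The dict keys match the DataFrame column names.
        out1 = self.linear1(input_dict["A"])
        out2 = self.linear2(input_dict["B"])
        return out1 + out2


predictor = TorchPredictor(model=CustomModule())
data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])

# dtype=torch.float casts the integer columns so they can pass through nn.Linear.
predictions = predictor.predict(data, dtype=torch.float)
assert len(predictions) == 2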