Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
Annotate more api (#26501)
parent 12d038b0e2
commit 8ca5584b9f
15 changed files with 35 additions and 0 deletions
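Every file in the diff below receives the same two-line treatment: an import of PublicAPI from ray.util.annotations, plus an @PublicAPI(stability="alpha") decorator on the class or function being exposed. As a minimal sketch of that pattern (ExamplePredictor is a made-up placeholder, not one of the classes touched by this commit):

from ray.util.annotations import PublicAPI


# Hypothetical placeholder class; the real commit decorates existing
# trainers, predictors, and checkpoint helpers in ray.train.
@PublicAPI(stability="alpha")
class ExamplePredictor:
    """A predictor exposed as alpha-stability public API."""

    def predict(self, batch):
        # Identity "prediction" just to keep the sketch runnable.
        return batch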
@@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
 from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.train.horovod.config import HorovodConfig
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class HorovodTrainer(DataParallelTrainer):
     """A Trainer for data parallel Horovod training.

@@ -11,11 +11,13 @@ from ray.air._internal.checkpointing import load_preprocessor_from_dir
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.train.predictor import Predictor
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class HuggingFacePredictor(Predictor):
     """A predictor for HuggingFace Transformers PyTorch models.

@@ -13,11 +13,13 @@ from ray.air._internal.checkpointing import (
 )
 from ray.air._internal.torch_utils import load_torch_model
 from ray.air.checkpoint import Checkpoint
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
     model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],

@@ -8,11 +8,13 @@ from ray.air.checkpoint import Checkpoint
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.train.lightgbm.utils import load_checkpoint
 from ray.train.predictor import Predictor
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class LightGBMPredictor(Predictor):
     """A predictor for LightGBM models.

@@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
 )
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def to_air_checkpoint(
     path: str,
     booster: lightgbm.Booster,

@@ -39,6 +41,7 @@ def to_air_checkpoint(
     return checkpoint
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
 ) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:

@@ -9,11 +9,13 @@ from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.typing import EnvType
 from ray.train.predictor import Predictor
 from ray.train.rl.utils import load_checkpoint
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class RLPredictor(Predictor):
     """A predictor for RLlib policies.

@@ -8,6 +8,7 @@ from ray.air._internal.checkpointing import (
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.typing import EnvType
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor

@@ -16,6 +17,7 @@ RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
 RL_CONFIG_FILE = "config.pkl"
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
     env: Optional[EnvType] = None,

@@ -10,11 +10,13 @@ from ray.train.predictor import Predictor
 from ray.train.sklearn._sklearn_utils import _set_cpu_params
 from ray.train.sklearn.utils import load_checkpoint
 from ray.util.joblib import register_ray
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class SklearnPredictor(Predictor):
     """A predictor for scikit-learn compatible estimators.

@@ -10,11 +10,13 @@ from ray.air._internal.checkpointing import (
 )
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def to_air_checkpoint(
     path: str,
     estimator: BaseEstimator,

@@ -41,6 +43,7 @@ def to_air_checkpoint(
     return checkpoint
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
 ) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:

@@ -10,6 +10,7 @@ from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
 from ray.air.checkpoint import Checkpoint
 from ray.train.data_parallel_trainer import _load_checkpoint
 from ray.train._internal.dl_predictor import DLPredictor
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor

@@ -17,6 +18,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@PublicAPI(stability="alpha")
 class TensorflowPredictor(DLPredictor):
     """A predictor for TensorFlow models.

@@ -6,11 +6,13 @@ from tensorflow import keras
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
 from ray.train.data_parallel_trainer import _load_checkpoint
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def to_air_checkpoint(
     model: keras.Model, preprocessor: Optional["Preprocessor"] = None
 ) -> Checkpoint:

@@ -29,6 +31,7 @@ def to_air_checkpoint(
     return checkpoint
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
     model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],

@@ -9,6 +9,7 @@ from ray.train.predictor import DataBatchType
 from ray.air.checkpoint import Checkpoint
 from ray.train.torch.utils import load_checkpoint
 from ray.train._internal.dl_predictor import DLPredictor
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor

@@ -16,6 +17,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@PublicAPI(stability="alpha")
 class TorchPredictor(DLPredictor):
     """A predictor for PyTorch models.

@@ -6,11 +6,13 @@ from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
 from ray.train.data_parallel_trainer import _load_checkpoint
 from ray.air._internal.torch_utils import load_torch_model
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def to_air_checkpoint(
     model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
 ) -> Checkpoint:

@@ -29,6 +31,7 @@ def to_air_checkpoint(
     return checkpoint
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
 ) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:

@@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
 )
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 def to_air_checkpoint(
     path: str,
     booster: xgboost.Booster,

@@ -39,6 +41,7 @@ def to_air_checkpoint(
     return checkpoint
 
 
+@PublicAPI(stability="alpha")
 def load_checkpoint(
     checkpoint: Checkpoint,
 ) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:

@@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.train.predictor import Predictor
 from ray.train.xgboost.utils import load_checkpoint
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
 
 
+@PublicAPI(stability="alpha")
 class XGBoostPredictor(Predictor):
     """A predictor for XGBoost models.
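The annotation is not expected to change the runtime behavior of the decorated objects. A small, self-contained check of that, assuming Ray is installed; demo_load_checkpoint is a hypothetical helper used only for this sketch, and how PublicAPI records the stability level (for example in the docstring) is an implementation detail of ray.util.annotations that may vary across Ray versions:

from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
def demo_load_checkpoint(path: str) -> str:
    """Hypothetical helper, used only to demonstrate the annotation."""
    return path


# The decorated function still behaves as before; the decorator only marks
# it as public alpha API (and may note that in the docstring).
assert demo_load_checkpoint("/tmp/ckpt") == "/tmp/ckpt"
print(demo_load_checkpoint.__doc__)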