mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Annotate more api (#26501)
This commit is contained in:
parent
12d038b0e2
commit
8ca5584b9f
15 changed files with 35 additions and 0 deletions
|
@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
|
||||||
|
|
||||||
from ray.train.data_parallel_trainer import DataParallelTrainer
|
from ray.train.data_parallel_trainer import DataParallelTrainer
|
||||||
from ray.train.horovod.config import HorovodConfig
|
from ray.train.horovod.config import HorovodConfig
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class HorovodTrainer(DataParallelTrainer):
|
class HorovodTrainer(DataParallelTrainer):
|
||||||
"""A Trainer for data parallel Horovod training.
|
"""A Trainer for data parallel Horovod training.
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,13 @@ from ray.air._internal.checkpointing import load_preprocessor_from_dir
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||||
from ray.train.predictor import Predictor
|
from ray.train.predictor import Predictor
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class HuggingFacePredictor(Predictor):
|
class HuggingFacePredictor(Predictor):
|
||||||
"""A predictor for HuggingFace Transformers PyTorch models.
|
"""A predictor for HuggingFace Transformers PyTorch models.
|
||||||
|
|
||||||
|
|
|
@ -13,11 +13,13 @@ from ray.air._internal.checkpointing import (
|
||||||
)
|
)
|
||||||
from ray.air._internal.torch_utils import load_torch_model
|
from ray.air._internal.torch_utils import load_torch_model
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
|
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
|
||||||
|
|
|
@ -8,11 +8,13 @@ from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||||
from ray.train.lightgbm.utils import load_checkpoint
|
from ray.train.lightgbm.utils import load_checkpoint
|
||||||
from ray.train.predictor import Predictor
|
from ray.train.predictor import Predictor
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class LightGBMPredictor(Predictor):
|
class LightGBMPredictor(Predictor):
|
||||||
"""A predictor for LightGBM models.
|
"""A predictor for LightGBM models.
|
||||||
|
|
||||||
|
|
|
@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
|
||||||
)
|
)
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import MODEL_KEY
|
from ray.air.constants import MODEL_KEY
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def to_air_checkpoint(
|
def to_air_checkpoint(
|
||||||
path: str,
|
path: str,
|
||||||
booster: lightgbm.Booster,
|
booster: lightgbm.Booster,
|
||||||
|
@ -39,6 +41,7 @@ def to_air_checkpoint(
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
|
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
|
||||||
|
|
|
@ -9,11 +9,13 @@ from ray.rllib.policy.policy import Policy
|
||||||
from ray.rllib.utils.typing import EnvType
|
from ray.rllib.utils.typing import EnvType
|
||||||
from ray.train.predictor import Predictor
|
from ray.train.predictor import Predictor
|
||||||
from ray.train.rl.utils import load_checkpoint
|
from ray.train.rl.utils import load_checkpoint
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class RLPredictor(Predictor):
|
class RLPredictor(Predictor):
|
||||||
"""A predictor for RLlib policies.
|
"""A predictor for RLlib policies.
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ from ray.air._internal.checkpointing import (
|
||||||
)
|
)
|
||||||
from ray.rllib.policy.policy import Policy
|
from ray.rllib.policy.policy import Policy
|
||||||
from ray.rllib.utils.typing import EnvType
|
from ray.rllib.utils.typing import EnvType
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
@ -16,6 +17,7 @@ RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
|
||||||
RL_CONFIG_FILE = "config.pkl"
|
RL_CONFIG_FILE = "config.pkl"
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
env: Optional[EnvType] = None,
|
env: Optional[EnvType] = None,
|
||||||
|
|
|
@ -10,11 +10,13 @@ from ray.train.predictor import Predictor
|
||||||
from ray.train.sklearn._sklearn_utils import _set_cpu_params
|
from ray.train.sklearn._sklearn_utils import _set_cpu_params
|
||||||
from ray.train.sklearn.utils import load_checkpoint
|
from ray.train.sklearn.utils import load_checkpoint
|
||||||
from ray.util.joblib import register_ray
|
from ray.util.joblib import register_ray
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class SklearnPredictor(Predictor):
|
class SklearnPredictor(Predictor):
|
||||||
"""A predictor for scikit-learn compatible estimators.
|
"""A predictor for scikit-learn compatible estimators.
|
||||||
|
|
||||||
|
|
|
@ -10,11 +10,13 @@ from ray.air._internal.checkpointing import (
|
||||||
)
|
)
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import MODEL_KEY
|
from ray.air.constants import MODEL_KEY
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def to_air_checkpoint(
|
def to_air_checkpoint(
|
||||||
path: str,
|
path: str,
|
||||||
estimator: BaseEstimator,
|
estimator: BaseEstimator,
|
||||||
|
@ -41,6 +43,7 @@ def to_air_checkpoint(
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
|
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.train.data_parallel_trainer import _load_checkpoint
|
from ray.train.data_parallel_trainer import _load_checkpoint
|
||||||
from ray.train._internal.dl_predictor import DLPredictor
|
from ray.train._internal.dl_predictor import DLPredictor
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
@ -17,6 +18,7 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class TensorflowPredictor(DLPredictor):
|
class TensorflowPredictor(DLPredictor):
|
||||||
"""A predictor for TensorFlow models.
|
"""A predictor for TensorFlow models.
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,13 @@ from tensorflow import keras
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
|
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
|
||||||
from ray.train.data_parallel_trainer import _load_checkpoint
|
from ray.train.data_parallel_trainer import _load_checkpoint
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def to_air_checkpoint(
|
def to_air_checkpoint(
|
||||||
model: keras.Model, preprocessor: Optional["Preprocessor"] = None
|
model: keras.Model, preprocessor: Optional["Preprocessor"] = None
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
|
@ -29,6 +31,7 @@ def to_air_checkpoint(
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
|
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ray.train.predictor import DataBatchType
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.train.torch.utils import load_checkpoint
|
from ray.train.torch.utils import load_checkpoint
|
||||||
from ray.train._internal.dl_predictor import DLPredictor
|
from ray.train._internal.dl_predictor import DLPredictor
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
@ -16,6 +17,7 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class TorchPredictor(DLPredictor):
|
class TorchPredictor(DLPredictor):
|
||||||
"""A predictor for PyTorch models.
|
"""A predictor for PyTorch models.
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,13 @@ from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
|
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
|
||||||
from ray.train.data_parallel_trainer import _load_checkpoint
|
from ray.train.data_parallel_trainer import _load_checkpoint
|
||||||
from ray.air._internal.torch_utils import load_torch_model
|
from ray.air._internal.torch_utils import load_torch_model
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def to_air_checkpoint(
|
def to_air_checkpoint(
|
||||||
model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
|
model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
|
||||||
) -> Checkpoint:
|
) -> Checkpoint:
|
||||||
|
@ -29,6 +31,7 @@ def to_air_checkpoint(
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
|
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
|
||||||
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
|
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
|
||||||
|
|
|
@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
|
||||||
)
|
)
|
||||||
from ray.air.checkpoint import Checkpoint
|
from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import MODEL_KEY
|
from ray.air.constants import MODEL_KEY
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def to_air_checkpoint(
|
def to_air_checkpoint(
|
||||||
path: str,
|
path: str,
|
||||||
booster: xgboost.Booster,
|
booster: xgboost.Booster,
|
||||||
|
@ -39,6 +41,7 @@ def to_air_checkpoint(
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
checkpoint: Checkpoint,
|
checkpoint: Checkpoint,
|
||||||
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
|
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
|
||||||
|
|
|
@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
|
||||||
from ray.air.constants import TENSOR_COLUMN_NAME
|
from ray.air.constants import TENSOR_COLUMN_NAME
|
||||||
from ray.train.predictor import Predictor
|
from ray.train.predictor import Predictor
|
||||||
from ray.train.xgboost.utils import load_checkpoint
|
from ray.train.xgboost.utils import load_checkpoint
|
||||||
|
from ray.util.annotations import PublicAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI(stability="alpha")
|
||||||
class XGBoostPredictor(Predictor):
|
class XGBoostPredictor(Predictor):
|
||||||
"""A predictor for XGBoost models.
|
"""A predictor for XGBoost models.
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue