Annotate more api (#26501)

This commit is contained in:
Amog Kamsetty 2022-07-12 22:29:14 -07:00 committed by GitHub
parent 12d038b0e2
commit 8ca5584b9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 35 additions and 0 deletions

View file

@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.horovod.config import HorovodConfig
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class HorovodTrainer(DataParallelTrainer):
"""A Trainer for data parallel Horovod training.

View file

@ -11,11 +11,13 @@ from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class HuggingFacePredictor(Predictor):
"""A predictor for HuggingFace Transformers PyTorch models.

View file

@ -13,11 +13,13 @@ from ray.air._internal.checkpointing import (
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],

View file

@ -8,11 +8,13 @@ from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.lightgbm.utils import load_checkpoint
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class LightGBMPredictor(Predictor):
"""A predictor for LightGBM models.

View file

@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
booster: lightgbm.Booster,
@ -39,6 +41,7 @@ def to_air_checkpoint(
return checkpoint
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:

View file

@ -9,11 +9,13 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
from ray.train.predictor import Predictor
from ray.train.rl.utils import load_checkpoint
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class RLPredictor(Predictor):
"""A predictor for RLlib policies.

View file

@ -8,6 +8,7 @@ from ray.air._internal.checkpointing import (
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@ -16,6 +17,7 @@ RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
env: Optional[EnvType] = None,

View file

@ -10,11 +10,13 @@ from ray.train.predictor import Predictor
from ray.train.sklearn._sklearn_utils import _set_cpu_params
from ray.train.sklearn.utils import load_checkpoint
from ray.util.joblib import register_ray
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class SklearnPredictor(Predictor):
"""A predictor for scikit-learn compatible estimators.

View file

@ -10,11 +10,13 @@ from ray.air._internal.checkpointing import (
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
estimator: BaseEstimator,
@ -41,6 +43,7 @@ def to_air_checkpoint(
return checkpoint
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:

View file

@ -10,6 +10,7 @@ from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
from ray.air.checkpoint import Checkpoint
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@ -17,6 +18,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@PublicAPI(stability="alpha")
class TensorflowPredictor(DLPredictor):
"""A predictor for TensorFlow models.

View file

@ -6,11 +6,13 @@ from tensorflow import keras
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def to_air_checkpoint(
model: keras.Model, preprocessor: Optional["Preprocessor"] = None
) -> Checkpoint:
@ -29,6 +31,7 @@ def to_air_checkpoint(
return checkpoint
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],

View file

@ -9,6 +9,7 @@ from ray.train.predictor import DataBatchType
from ray.air.checkpoint import Checkpoint
from ray.train.torch.utils import load_checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@ -16,6 +17,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@PublicAPI(stability="alpha")
class TorchPredictor(DLPredictor):
"""A predictor for PyTorch models.

View file

@ -6,11 +6,13 @@ from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.air._internal.torch_utils import load_torch_model
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def to_air_checkpoint(
model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
) -> Checkpoint:
@ -29,6 +31,7 @@ def to_air_checkpoint(
return checkpoint
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:

View file

@ -9,11 +9,13 @@ from ray.air._internal.checkpointing import (
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
booster: xgboost.Booster,
@ -39,6 +41,7 @@ def to_air_checkpoint(
return checkpoint
@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:

View file

@ -7,11 +7,13 @@ from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.xgboost.utils import load_checkpoint
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class XGBoostPredictor(Predictor):
"""A predictor for XGBoost models.