mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[air] Update to beta (#27393)
Update API references to beta. Needed as we are going to beta in 2.0. I left out RL/Scikit-Learn/HuggingFace.
This commit is contained in:
parent
4d87e8112a
commit
6dc3dbdd37
23 changed files with 30 additions and 30 deletions
|
@ -38,7 +38,7 @@ _CHECKPOINT_DIR_PREFIX = "checkpoint_tmp_"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class Checkpoint:
|
||||
"""Ray AIR Checkpoint.
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) ->
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class ScalingConfig:
|
||||
"""Configuration for scaling training.
|
||||
|
||||
|
@ -264,7 +264,7 @@ class ScalingConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class DatasetConfig:
|
||||
"""Configuration for ingest of a single Dataset.
|
||||
|
||||
|
@ -428,7 +428,7 @@ class DatasetConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class FailureConfig:
|
||||
"""Configuration related to failure handling of each run/trial.
|
||||
|
||||
|
@ -463,7 +463,7 @@ class FailureConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class CheckpointConfig:
|
||||
"""Configurable parameters for defining the checkpointing strategy.
|
||||
|
||||
|
@ -543,7 +543,7 @@ class CheckpointConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class RunConfig:
|
||||
"""Runtime configuration for individual trials that are run.
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import pandas as pd
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class Result:
|
||||
"""The final result of a ML training run or a Tune trial.
|
||||
|
||||
|
|
|
@ -13,14 +13,14 @@ if TYPE_CHECKING:
|
|||
from ray.air.data_batch_type import DataBatchType
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class PreprocessorNotFittedException(RuntimeError):
|
||||
"""Error raised when the preprocessor needs to be fitted first."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class Preprocessor(abc.ABC):
|
||||
"""Implements an ML preprocessing operation.
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ GenDataset = Union["Dataset", Callable[[], "Dataset"]]
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TrainingFailedError(RuntimeError):
|
||||
"""An error indicating that training has failed."""
|
||||
|
||||
|
@ -322,7 +322,7 @@ class BaseTrainer(abc.ABC):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
def fit(self) -> Result:
|
||||
"""Runs training.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from ray.train.predictor import Predictor
|
|||
from ray.util.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class BatchPredictor:
|
||||
"""Batch predictor class.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class HorovodTrainer(DataParallelTrainer):
|
||||
"""A Trainer for data parallel Horovod training.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class LightGBMCheckpoint(Checkpoint):
|
||||
"""A :py:class:`~ray.air.checkpoint.Checkpoint` with LightGBM-specific
|
||||
functionality.
|
||||
|
|
|
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class LightGBMPredictor(Predictor):
|
||||
"""A predictor for LightGBM models.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class LightGBMTrainer(GBDTTrainer):
|
||||
"""A Trainer for data parallel LightGBM training.
|
||||
|
||||
|
|
|
@ -29,14 +29,14 @@ TYPE_TO_ENUM: Dict[Type[DataBatchType], DataType] = {
|
|||
}
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class PredictorNotSerializableException(RuntimeError):
|
||||
"""Error raised when trying to serialize a Predictor instance."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class Predictor(abc.ABC):
|
||||
"""Predictors load models from checkpoints to perform inference.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TensorflowCheckpoint(Checkpoint):
|
||||
"""A :py:class:`~ray.air.checkpoint.Checkpoint` with TensorFlow-specific
|
||||
functionality.
|
||||
|
|
|
@ -19,7 +19,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TensorflowPredictor(DLPredictor):
|
||||
"""A predictor for TensorFlow models.
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TensorflowTrainer(DataParallelTrainer):
|
||||
"""A Trainer for data parallel Tensorflow training.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TorchCheckpoint(Checkpoint):
|
||||
"""A :py:class:`~ray.air.checkpoint.Checkpoint` with Torch-specific
|
||||
functionality.
|
||||
|
|
|
@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TorchPredictor(DLPredictor):
|
||||
"""A predictor for PyTorch models.
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TorchTrainer(DataParallelTrainer):
|
||||
"""A Trainer for data parallel PyTorch training.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class XGBoostCheckpoint(Checkpoint):
|
||||
"""A :py:class:`~ray.air.checkpoint.Checkpoint` with XGBoost-specific
|
||||
functionality.
|
||||
|
|
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class XGBoostPredictor(Predictor):
|
||||
"""A predictor for XGBoost models.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
from ray.data.preprocessor import Preprocessor
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class XGBoostTrainer(GBDTTrainer):
|
||||
"""A Trainer for data parallel XGBoost training.
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def generate_variants(
|
|||
yield resolved_vars, spec
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
def grid_search(values: Iterable) -> Dict[str, List]:
|
||||
"""Convenience method for specifying grid search over a value.
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from ray.util import PublicAPI
|
|||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class TuneConfig:
|
||||
"""Tune specific configs.
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ _TUNER_INTERNAL = "_tuner_internal"
|
|||
_SELF = "self"
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
@PublicAPI(stability="beta")
|
||||
class Tuner:
|
||||
"""Tuner is the recommended way of launching hyperparameter tuning jobs with Ray Tune.
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue