Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[air] Better exception handling (#23695)
What: Raise meaningful exceptions when invalid parameters are passed. Why: We want to catch invalid parameters early and guide users toward correct use of the API.
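The new behavior is exercised by the test file added below; a minimal sketch of what it looks like from the user's side (`DummyTrainer` is the test-only subclass defined in `python/ray/ml/tests/test_api.py`, not public API):

    import pytest
    import ray
    from ray.ml.trainer import Trainer

    class DummyTrainer(Trainer):
        def training_loop(self) -> None:
            pass

    # Passing anything that is not a RunConfig now fails fast at construction:
    with pytest.raises(ValueError):
        DummyTrainer(run_config="invalid")

    # A proper config (or None, which falls back to the default) constructs fine:
    DummyTrainer(run_config=ray.ml.RunConfig())
    DummyTrainer(run_config=None)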
This commit is contained in:
parent
252596af58
commit
fb50e0a70b
10 changed files with 218 additions and 10 deletions
@@ -113,6 +113,14 @@ py_test(
 # Please keep these sorted alphabetically.
 # --------------------------------------------------------------------
 
+py_test(
+    name = "test_api",
+    size = "small",
+    srcs = ["tests/test_api.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "test_checkpoints",
     size = "small",
@@ -0,0 +1,5 @@
+from ray.ml.checkpoint import Checkpoint
+from ray.ml.config import RunConfig, ScalingConfig
+from ray.ml.preprocessor import Preprocessor
+
+__all__ = ["Checkpoint", "Preprocessor", "RunConfig", "ScalingConfig"]
@@ -1,10 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, List
+from typing import Dict, Any, Optional, List, TYPE_CHECKING
 
 from ray.util import PublicAPI
 
-from ray.tune.trainable import PlacementGroupFactory
-from ray.tune.callback import Callback
+if TYPE_CHECKING:
+    from ray.tune.trainable import PlacementGroupFactory
+    from ray.tune.callback import Callback
 
 
 ScalingConfig = Dict[str, Any]
@@ -81,8 +82,10 @@ class ScalingConfigDataClass:
             if k not in ["CPU", "GPU"]
         }
 
-    def as_placement_group_factory(self) -> PlacementGroupFactory:
+    def as_placement_group_factory(self) -> "PlacementGroupFactory":
         """Returns a PlacementGroupFactory to specify resources for Tune."""
+        from ray.tune.trainable import PlacementGroupFactory
+
         trainer_resources = (
             self.trainer_resources if self.trainer_resources else {"CPU": 1}
         )
@@ -144,5 +147,5 @@ class RunConfig:
     # TODO(xwjiang): Add more.
     name: Optional[str] = None
     local_dir: Optional[str] = None
-    callbacks: Optional[List[Callback]] = None
+    callbacks: Optional[List["Callback"]] = None
     failure: Optional[FailureConfig] = None
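The import changes in the config hunks above follow the standard `typing.TYPE_CHECKING` pattern: `ray.tune` symbols are imported only for static type checkers, and annotations that reference them become strings, so the module no longer imports `ray.tune` at runtime. A minimal sketch of the pattern (the `notify` function is hypothetical, for illustration only):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers (e.g. mypy), never executed at runtime.
        from ray.tune.callback import Callback

    def notify(callback: "Callback") -> None:
        # The string annotation is resolved lazily, so no runtime import is needed.
        ...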
python/ray/ml/tests/test_api.py  (new file, 119 lines)
@@ -0,0 +1,119 @@
+import pytest
+
+import ray
+from ray.ml import Checkpoint
+from ray.ml.trainer import Trainer
+from ray.ml.preprocessor import Preprocessor
+
+
+class DummyTrainer(Trainer):
+    def training_loop(self) -> None:
+        pass
+
+
+class DummyDataset(ray.data.Dataset):
+    def __init__(self):
+        pass
+
+
+def test_run_config():
+    with pytest.raises(ValueError):
+        DummyTrainer(run_config="invalid")
+
+    with pytest.raises(ValueError):
+        DummyTrainer(run_config=False)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(run_config=True)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(run_config={})
+
+    # Succeed
+    DummyTrainer(run_config=None)
+
+    # Succeed
+    DummyTrainer(run_config=ray.ml.RunConfig())
+
+
+def test_scaling_config():
+    with pytest.raises(ValueError):
+        DummyTrainer(scaling_config="invalid")
+
+    with pytest.raises(ValueError):
+        DummyTrainer(scaling_config=False)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(scaling_config=True)
+
+    # Succeed
+    DummyTrainer(scaling_config={})
+
+    # Succeed
+    DummyTrainer(scaling_config=None)
+
+
+def test_datasets():
+    with pytest.raises(ValueError):
+        DummyTrainer(datasets="invalid")
+
+    with pytest.raises(ValueError):
+        DummyTrainer(datasets=False)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(datasets=True)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(datasets={"test": "invalid"})
+
+    # Succeed
+    DummyTrainer(datasets=None)
+
+    # Succeed
+    DummyTrainer(datasets={"test": DummyDataset()})
+
+
+def test_preprocessor():
+    with pytest.raises(ValueError):
+        DummyTrainer(preprocessor="invalid")
+
+    with pytest.raises(ValueError):
+        DummyTrainer(preprocessor=False)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(preprocessor=True)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(preprocessor={})
+
+    # Succeed
+    DummyTrainer(preprocessor=None)
+
+    # Succeed
+    DummyTrainer(preprocessor=Preprocessor())
+
+
+def test_resume_from_checkpoint():
+    with pytest.raises(ValueError):
+        DummyTrainer(resume_from_checkpoint="invalid")
+
+    with pytest.raises(ValueError):
+        DummyTrainer(resume_from_checkpoint=False)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(resume_from_checkpoint=True)
+
+    with pytest.raises(ValueError):
+        DummyTrainer(resume_from_checkpoint={})
+
+    # Succeed
+    DummyTrainer(resume_from_checkpoint=None)
+
+    # Succeed
+    DummyTrainer(resume_from_checkpoint=Checkpoint.from_dict({"empty": ""}))
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))
@@ -183,6 +183,7 @@ class DataParallelTrainer(Trainer):
 
     def __init__(
         self,
+        *,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
         train_loop_config: Optional[Dict] = None,
         backend_config: Optional[BackendConfig] = None,
@@ -198,6 +199,11 @@ class DataParallelTrainer(Trainer):
         self.train_loop_per_worker = train_loop_per_worker
         self.train_loop_config = train_loop_config
 
+        backend_config = (
+            backend_config if backend_config is not None else BackendConfig()
+        )
+        self.backend_config = backend_config
+
         super(DataParallelTrainer, self).__init__(
             scaling_config=scaling_config,
             run_config=run_config,
@@ -206,6 +212,9 @@ class DataParallelTrainer(Trainer):
             resume_from_checkpoint=resume_from_checkpoint,
         )
 
+    def _validate_attributes(self):
+        super()._validate_attributes()
+
         if (
             not self.scaling_config.get("use_gpu", False)
             and "GPU" in ray.available_resources()
@@ -72,6 +72,7 @@ class GBDTTrainer(Trainer):
 
     def __init__(
         self,
+        *,
         datasets: Dict[str, GenDataset],
         label_column: str,
         params: Dict[str, Any],
@@ -87,8 +88,15 @@ class GBDTTrainer(Trainer):
         self.dmatrix_params = dmatrix_params or {}
         self.train_kwargs = train_kwargs
         super().__init__(
-            scaling_config, run_config, datasets, preprocessor, resume_from_checkpoint
+            scaling_config=scaling_config,
+            run_config=run_config,
+            datasets=datasets,
+            preprocessor=preprocessor,
+            resume_from_checkpoint=resume_from_checkpoint,
         )
 
+    def _validate_attributes(self):
+        super()._validate_attributes()
+        self._validate_config_and_datasets()
+
     def _validate_config_and_datasets(self) -> None:
@@ -164,6 +164,7 @@ class HorovodTrainer(DataParallelTrainer):
 
     def __init__(
         self,
+        *,
         train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
         train_loop_config: Optional[Dict] = None,
         horovod_config: Optional[HorovodConfig] = None,
@@ -174,7 +175,7 @@ class HorovodTrainer(DataParallelTrainer):
         resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
         super().__init__(
-            train_loop_per_worker,
+            train_loop_per_worker=train_loop_per_worker,
             train_loop_config=train_loop_config,
             backend_config=horovod_config or HorovodConfig(),
             scaling_config=scaling_config,
@@ -156,6 +156,7 @@ class TensorflowTrainer(DataParallelTrainer):
 
     def __init__(
         self,
+        *,
         train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
         train_loop_config: Optional[Dict] = None,
         tensorflow_config: Optional[TensorflowConfig] = None,
@@ -165,6 +165,7 @@ class TorchTrainer(DataParallelTrainer):
 
     def __init__(
         self,
+        *,
         train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
         train_loop_config: Optional[Dict] = None,
         torch_config: Optional[TorchConfig] = None,
@@ -3,6 +3,8 @@ import inspect
 import logging
 from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
 
+import ray
+
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint
 from ray.ml.result import Result
@@ -133,6 +135,7 @@ class Trainer(abc.ABC):
 
     def __init__(
         self,
+        *,
         scaling_config: Optional[ScalingConfig] = None,
         run_config: Optional[RunConfig] = None,
         datasets: Optional[Dict[str, GenDataset]] = None,
@@ -140,12 +143,14 @@ class Trainer(abc.ABC):
         resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
 
-        self.scaling_config = scaling_config if scaling_config else {}
-        self.run_config = run_config if run_config else RunConfig()
-        self.datasets = datasets if datasets else {}
+        self.scaling_config = scaling_config if scaling_config is not None else {}
+        self.run_config = run_config if run_config is not None else RunConfig()
+        self.datasets = datasets if datasets is not None else {}
         self.preprocessor = preprocessor
         self.resume_from_checkpoint = resume_from_checkpoint
 
+        self._validate_attributes()
+
     def __new__(cls, *args, **kwargs):
         """Store the init args as attributes so this can be merged with Tune hparams."""
         trainer = super(Trainer, cls).__new__(cls)
@@ -157,6 +162,54 @@ class Trainer(abc.ABC):
         trainer._param_dict = {**arg_dict, **kwargs}
         return trainer
 
+    def _validate_attributes(self):
+        """Called on __init__() to validate trainer attributes."""
+        # Run config
+        if not isinstance(self.run_config, RunConfig):
+            raise ValueError(
+                f"`run_config` should be an instance of `ray.ml.RunConfig`, "
+                f"found {type(self.run_config)} with value `{self.run_config}`."
+            )
+        # Scaling config
+        # Todo: move to ray.ml.ScalingConfig
+        if not isinstance(self.scaling_config, dict):
+            raise ValueError(
+                f"`scaling_config` should be an instance of `dict`, "
+                f"found {type(self.scaling_config)} with value `{self.scaling_config}`."
+            )
+        # Datasets
+        if not isinstance(self.datasets, dict):
+            raise ValueError(
+                f"`datasets` should be a dict mapping from a string to "
+                f"`ray.data.Dataset` objects, "
+                f"found {type(self.datasets)} with value `{self.datasets}`."
+            )
+        elif any(
+            not isinstance(ds, ray.data.Dataset) and not callable(ds)
+            for ds in self.datasets.values()
+        ):
+            raise ValueError(
+                f"At least one value in the `datasets` dict is not a "
+                f"`ray.data.Dataset`: {self.datasets}"
+            )
+        # Preprocessor
+        if self.preprocessor is not None and not isinstance(
+            self.preprocessor, ray.ml.preprocessor.Preprocessor
+        ):
+            raise ValueError(
+                f"`preprocessor` should be an instance of `ray.ml.Preprocessor`, "
+                f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
+            )
+
+        if self.resume_from_checkpoint is not None and not isinstance(
+            self.resume_from_checkpoint, ray.ml.Checkpoint
+        ):
+            raise ValueError(
+                f"`resume_from_checkpoint` should be an instance of "
+                f"`ray.ml.Checkpoint`, found {type(self.resume_from_checkpoint)} "
+                f"with value `{self.resume_from_checkpoint}`."
+            )
+
     def setup(self) -> None:
         """Called during fit() to perform initial setup on the Trainer.
 
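Taken together: `Trainer.__init__` now ends by calling `self._validate_attributes()`, and subclasses layer their own checks on top by overriding the method and delegating up with `super()._validate_attributes()` (as `DataParallelTrainer` and `GBDTTrainer` do above). A minimal sketch of this validation pattern, with hypothetical class names:

    class Base:
        def __init__(self, config=None):
            self.config = config if config is not None else {}
            self._validate_attributes()  # run all checks at construction time

        def _validate_attributes(self):
            # Base-level checks; subclasses extend rather than replace these.
            if not isinstance(self.config, dict):
                raise ValueError(f"`config` should be a dict, found {type(self.config)}.")

    class Derived(Base):
        def _validate_attributes(self):
            super()._validate_attributes()  # keep the base checks
            if "required_key" not in self.config:
                raise ValueError("`config` must contain `required_key`.")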