[air] Better exception handling (#23695)

What: Raise meaningful exceptions when invalid parameters are passed.
Why: We want to catch invalid parameters early and guide users toward correct use of the API.
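For illustration, a minimal sketch of the resulting behavior (the `MyTrainer` subclass is hypothetical, mirroring the `DummyTrainer` added in the new test file below): a wrongly-typed argument now fails immediately at construction with a descriptive `ValueError` rather than failing in a less obvious way later.

from ray.ml import RunConfig
from ray.ml.trainer import Trainer


# Hypothetical minimal subclass, mirroring DummyTrainer in the new test file.
class MyTrainer(Trainer):
    def training_loop(self) -> None:
        pass


# Invalid: run_config must be a ray.ml.RunConfig (or None), not a string.
try:
    MyTrainer(run_config="invalid")
except ValueError as err:
    print(err)  # `run_config` should be an instance of `ray.ml.RunConfig`, ...

# Valid: a RunConfig instance and a plain dict scaling config pass validation.
MyTrainer(run_config=RunConfig(), scaling_config={"num_workers": 2})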
Kai Fricke 2022-04-05 19:11:55 -07:00 committed by GitHub
parent 252596af58
commit fb50e0a70b
10 changed files with 218 additions and 10 deletions

View file

@@ -113,6 +113,14 @@ py_test(
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------

py_test(
    name = "test_api",
    size = "small",
    srcs = ["tests/test_api.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"]
)

py_test(
    name = "test_checkpoints",
    size = "small",

View file

@@ -0,0 +1,5 @@
from ray.ml.checkpoint import Checkpoint
from ray.ml.config import RunConfig, ScalingConfig
from ray.ml.preprocessor import Preprocessor

__all__ = ["Checkpoint", "Preprocessor", "RunConfig", "ScalingConfig"]
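The new module above exposes the `ray.ml` public classes at the package root, which matches how the new validation errors refer to them (`ray.ml.RunConfig`, `ray.ml.Preprocessor`, `ray.ml.Checkpoint`). A one-line usage sketch:

from ray.ml import Checkpoint, Preprocessor, RunConfig, ScalingConfig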

View file

@@ -1,10 +1,11 @@
from dataclasses import dataclass
from typing import Dict, Any, Optional, List
from typing import Dict, Any, Optional, List, TYPE_CHECKING

from ray.util import PublicAPI
from ray.tune.trainable import PlacementGroupFactory
from ray.tune.callback import Callback
if TYPE_CHECKING:
    from ray.tune.trainable import PlacementGroupFactory
    from ray.tune.callback import Callback

ScalingConfig = Dict[str, Any]
@@ -81,8 +82,10 @@ class ScalingConfigDataClass:
            if k not in ["CPU", "GPU"]
        }

    def as_placement_group_factory(self) -> PlacementGroupFactory:
    def as_placement_group_factory(self) -> "PlacementGroupFactory":
        """Returns a PlacementGroupFactory to specify resources for Tune."""
        from ray.tune.trainable import PlacementGroupFactory

        trainer_resources = (
            self.trainer_resources if self.trainer_resources else {"CPU": 1}
        )
@@ -144,5 +147,5 @@ class RunConfig:
    # TODO(xwjiang): Add more.
    name: Optional[str] = None
    local_dir: Optional[str] = None
    callbacks: Optional[List[Callback]] = None
    callbacks: Optional[List["Callback"]] = None
    failure: Optional[FailureConfig] = None
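A note on the idiom used in this file's hunks: the Tune imports move under `typing.TYPE_CHECKING` so they are only evaluated by static type checkers, the annotations become string forward references, and the runtime import is deferred into the method body. A minimal, self-contained sketch of the same pattern (using `argparse` as a stand-in for the Tune modules):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; nothing is imported at runtime here,
    # which keeps module import cheap and avoids import cycles.
    from argparse import ArgumentParser


def build_parser() -> "ArgumentParser":
    # Deferred runtime import: paid only when the function is actually called.
    from argparse import ArgumentParser

    return ArgumentParser()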

View file

@@ -0,0 +1,119 @@
import pytest

import ray
from ray.ml import Checkpoint
from ray.ml.trainer import Trainer
from ray.ml.preprocessor import Preprocessor


class DummyTrainer(Trainer):
    def training_loop(self) -> None:
        pass


class DummyDataset(ray.data.Dataset):
    def __init__(self):
        pass


def test_run_config():
    with pytest.raises(ValueError):
        DummyTrainer(run_config="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(run_config=False)

    with pytest.raises(ValueError):
        DummyTrainer(run_config=True)

    with pytest.raises(ValueError):
        DummyTrainer(run_config={})

    # Succeed
    DummyTrainer(run_config=None)

    # Succeed
    DummyTrainer(run_config=ray.ml.RunConfig())


def test_scaling_config():
    with pytest.raises(ValueError):
        DummyTrainer(scaling_config="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(scaling_config=False)

    with pytest.raises(ValueError):
        DummyTrainer(scaling_config=True)

    # Succeed
    DummyTrainer(scaling_config={})

    # Succeed
    DummyTrainer(scaling_config=None)


def test_datasets():
    with pytest.raises(ValueError):
        DummyTrainer(datasets="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(datasets=False)

    with pytest.raises(ValueError):
        DummyTrainer(datasets=True)

    with pytest.raises(ValueError):
        DummyTrainer(datasets={"test": "invalid"})

    # Succeed
    DummyTrainer(datasets=None)

    # Succeed
    DummyTrainer(datasets={"test": DummyDataset()})


def test_preprocessor():
    with pytest.raises(ValueError):
        DummyTrainer(preprocessor="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(preprocessor=False)

    with pytest.raises(ValueError):
        DummyTrainer(preprocessor=True)

    with pytest.raises(ValueError):
        DummyTrainer(preprocessor={})

    # Succeed
    DummyTrainer(preprocessor=None)

    # Succeed
    DummyTrainer(preprocessor=Preprocessor())


def test_resume_from_checkpoint():
    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=False)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=True)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint={})

    # Succeed
    DummyTrainer(resume_from_checkpoint=None)

    # Succeed
    DummyTrainer(resume_from_checkpoint=Checkpoint.from_dict({"empty": ""}))


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-x", __file__]))

View file

@@ -183,6 +183,7 @@ class DataParallelTrainer(Trainer):
    def __init__(
        self,
        *,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        train_loop_config: Optional[Dict] = None,
        backend_config: Optional[BackendConfig] = None,
@@ -198,6 +199,11 @@ class DataParallelTrainer(Trainer):
        self.train_loop_per_worker = train_loop_per_worker
        self.train_loop_config = train_loop_config

        backend_config = (
            backend_config if backend_config is not None else BackendConfig()
        )
        self.backend_config = backend_config

        super(DataParallelTrainer, self).__init__(
            scaling_config=scaling_config,
            run_config=run_config,
@@ -206,6 +212,9 @@ class DataParallelTrainer(Trainer):
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _validate_attributes(self):
        super()._validate_attributes()

        if (
            not self.scaling_config.get("use_gpu", False)
            and "GPU" in ray.available_resources()

View file

@@ -72,6 +72,7 @@ class GBDTTrainer(Trainer):
    def __init__(
        self,
        *,
        datasets: Dict[str, GenDataset],
        label_column: str,
        params: Dict[str, Any],
@@ -87,8 +88,15 @@ class GBDTTrainer(Trainer):
        self.dmatrix_params = dmatrix_params or {}
        self.train_kwargs = train_kwargs

        super().__init__(
            scaling_config, run_config, datasets, preprocessor, resume_from_checkpoint
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _validate_attributes(self):
        super()._validate_attributes()
        self._validate_config_and_datasets()

    def _validate_config_and_datasets(self) -> None:

View file

@@ -164,6 +164,7 @@ class HorovodTrainer(DataParallelTrainer):
    def __init__(
        self,
        *,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        train_loop_config: Optional[Dict] = None,
        horovod_config: Optional[HorovodConfig] = None,
@@ -174,7 +175,7 @@ class HorovodTrainer(DataParallelTrainer):
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        super().__init__(
            train_loop_per_worker,
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=horovod_config or HorovodConfig(),
            scaling_config=scaling_config,

View file

@@ -156,6 +156,7 @@ class TensorflowTrainer(DataParallelTrainer):
    def __init__(
        self,
        *,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        train_loop_config: Optional[Dict] = None,
        tensorflow_config: Optional[TensorflowConfig] = None,

View file

@@ -165,6 +165,7 @@ class TorchTrainer(DataParallelTrainer):
    def __init__(
        self,
        *,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        train_loop_config: Optional[Dict] = None,
        torch_config: Optional[TorchConfig] = None,

View file

@@ -3,6 +3,8 @@ import inspect
import logging
from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type

import ray

from ray.ml.preprocessor import Preprocessor
from ray.ml.checkpoint import Checkpoint
from ray.ml.result import Result
@@ -133,6 +135,7 @@ class Trainer(abc.ABC):
    def __init__(
        self,
        *,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
@@ -140,12 +143,14 @@ class Trainer(abc.ABC):
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        self.scaling_config = scaling_config if scaling_config else {}
        self.run_config = run_config if run_config else RunConfig()
        self.datasets = datasets if datasets else {}
        self.scaling_config = scaling_config if scaling_config is not None else {}
        self.run_config = run_config if run_config is not None else RunConfig()
        self.datasets = datasets if datasets is not None else {}
        self.preprocessor = preprocessor
        self.resume_from_checkpoint = resume_from_checkpoint

        self._validate_attributes()

    def __new__(cls, *args, **kwargs):
        """Store the init args as attributes so this can be merged with Tune hparams."""
        trainer = super(Trainer, cls).__new__(cls)
@@ -157,6 +162,54 @@
        trainer._param_dict = {**arg_dict, **kwargs}
        return trainer
    def _validate_attributes(self):
        """Called on __init__() to validate trainer attributes."""
        # Run config
        if not isinstance(self.run_config, RunConfig):
            raise ValueError(
                f"`run_config` should be an instance of `ray.ml.RunConfig`, "
                f"found {type(self.run_config)} with value `{self.run_config}`."
            )
        # Scaling config
        # Todo: move to ray.ml.ScalingConfig
        if not isinstance(self.scaling_config, dict):
            raise ValueError(
                f"`scaling_config` should be an instance of `dict`, "
                f"found {type(self.scaling_config)} with value `{self.scaling_config}`."
            )
        # Datasets
        if not isinstance(self.datasets, dict):
            raise ValueError(
                f"`datasets` should be a dict mapping from a string to "
                f"`ray.data.Dataset` objects, "
                f"found {type(self.datasets)} with value `{self.datasets}`."
            )
        elif any(
            not isinstance(ds, ray.data.Dataset) and not callable(ds)
            for ds in self.datasets.values()
        ):
            raise ValueError(
                f"At least one value in the `datasets` dict is not a "
                f"`ray.data.Dataset`: {self.datasets}"
            )
        # Preprocessor
        if self.preprocessor is not None and not isinstance(
            self.preprocessor, ray.ml.preprocessor.Preprocessor
        ):
            raise ValueError(
                f"`preprocessor` should be an instance of `ray.ml.Preprocessor`, "
                f"found {type(self.preprocessor)} with value `{self.preprocessor}`."
            )
        if self.resume_from_checkpoint is not None and not isinstance(
            self.resume_from_checkpoint, ray.ml.Checkpoint
        ):
            raise ValueError(
                f"`resume_from_checkpoint` should be an instance of "
                f"`ray.ml.Checkpoint`, found {type(self.resume_from_checkpoint)} "
                f"with value `{self.resume_from_checkpoint}`."
            )

    def setup(self) -> None:
        """Called during fit() to perform initial setup on the Trainer.