[AIR] Add Scaling Config validation (#23889)
Adds a generic way of validating ScalingConfigs by allowing only certain keys: each `Trainer` subclass declares a `_scaling_config_allowed_keys` whitelist, and a new `Trainer._validate_and_get_scaling_config_data_class` classmethod (backed by helpers in `ray.ml.utils.config`) raises a `ValueError` for any other key.

Co-authored-by: Kai Fricke <kai@anyscale.com>
parent 7f3031f451
commit 1fc6db30a5

8 changed files with 174 additions and 11 deletions
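The pattern this commit introduces, in miniature: a Trainer subclass whitelists the ScalingConfig keys it supports, and validation rejects everything else with a ValueError. The following is an illustrative sketch only; DemoTrainer and _validate_scaling_config are hypothetical stand-ins, not names from this diff.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ScalingConfigDataClass:
    # Stand-in with a subset of the real fields.
    trainer_resources: Optional[Dict] = None
    num_workers: Optional[int] = None
    use_gpu: bool = False


class DemoTrainer:
    # Each Trainer subclass whitelists the ScalingConfig keys it understands.
    _scaling_config_allowed_keys: List[str] = ["num_workers", "use_gpu"]

    @classmethod
    def _validate_scaling_config(cls, config: Dict) -> ScalingConfigDataClass:
        bad_keys = [k for k in config if k not in cls._scaling_config_allowed_keys]
        if bad_keys:
            raise ValueError(f"Key(s) {bad_keys} are not allowed to be set.")
        return ScalingConfigDataClass(**config)


print(DemoTrainer._validate_scaling_config({"num_workers": 2}))  # passes
# DemoTrainer._validate_scaling_config({"trainer_resources": {"CPU": 1}})  # ValueError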
@@ -1,13 +1,22 @@
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)

 from ray.tune.syncer import SyncConfig
 from ray.util import PublicAPI

 if TYPE_CHECKING:
-    from ray.tune.trainable import PlacementGroupFactory
     from ray.tune.callback import Callback
     from ray.tune.stopper import Stopper
+    from ray.tune.trainable import PlacementGroupFactory

 ScalingConfig = Dict[str, Any]

@@ -2,8 +2,13 @@ import pytest

 import ray
 from ray.ml import Checkpoint
+from ray.ml.config import ScalingConfigDataClass
 from ray.ml.trainer import Trainer
 from ray.ml.preprocessor import Preprocessor
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)


 class DummyTrainer(Trainer):

@@ -53,6 +58,38 @@ def test_scaling_config():
         DummyTrainer(scaling_config=None)


+def test_scaling_config_validate_config_valid_class():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dataclass_keys_updated(
+        ScalingConfigDataClass(**scaling_config), ["num_workers"]
+    )
+
+
+def test_scaling_config_validate_config_valid_dict():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])
+
+
+def test_scaling_config_validate_config_prohibited_class():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dataclass_keys_updated(
+            ScalingConfigDataClass(**scaling_config),
+            ["trainer_resources"],
+        )
+
+
+def test_scaling_config_validate_config_prohibited_dict():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dict_keys_set(
+            scaling_config,
+            ["trainer_resources"],
+        )
+
+
 def test_datasets():
     with pytest.raises(ValueError):
         DummyTrainer(datasets="invalid")

@@ -27,6 +27,15 @@ class DummyPreprocessor(Preprocessor):


 class DummyTrainer(Trainer):
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+        "trainer_resources",
+    ]
+
     def __init__(self, train_loop, custom_arg=None, **kwargs):
         self.custom_arg = custom_arg
         self.train_loop = train_loop

@@ -7,7 +7,7 @@ import ray
 from ray import tune
 from ray.ml.constants import TRAIN_DATASET_KEY, PREPROCESSOR_KEY
 from ray.ml.trainer import Trainer
-from ray.ml.config import ScalingConfig, RunConfig, ScalingConfigDataClass
+from ray.ml.config import ScalingConfig, RunConfig
 from ray.ml.trainer import GenDataset
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint

@@ -181,6 +181,14 @@ class DataParallelTrainer(Trainer):
         resume_from_checkpoint: A checkpoint to resume training from.
     """

+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
+
     def __init__(
         self,
         *,

@@ -250,7 +258,9 @@
         )

     def training_loop(self) -> None:
-        scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config)
+        scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
+            self.scaling_config
+        )

         train_loop_per_worker = construct_train_func(
             self.train_loop_per_worker,

@@ -65,6 +65,13 @@ class GBDTTrainer(Trainer):
         **train_kwargs: Additional kwargs passed to framework ``train()`` function.
     """

+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
     _dmatrix_cls: type
     _ray_params_cls: type
     _tune_callback_cls: type

@@ -1,19 +1,22 @@
 import abc
 import inspect
 import logging
-from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import ray
-
-from ray.ml.preprocessor import Preprocessor
-from ray.util import PublicAPI
 from ray.ml.checkpoint import Checkpoint
-from ray.ml.result import Result
 from ray.ml.config import RunConfig, ScalingConfig, ScalingConfigDataClass
 from ray.ml.constants import TRAIN_DATASET_KEY
+from ray.ml.preprocessor import Preprocessor
+from ray.ml.result import Result
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)
 from ray.tune import Trainable
 from ray.tune.error import TuneError
 from ray.tune.function_runner import wrap_function
+from ray.util import PublicAPI
 from ray.util.annotations import DeveloperAPI
 from ray.util.ml_utils.dict import merge_dicts

@@ -133,6 +136,8 @@ class Trainer(abc.ABC):
         resume_from_checkpoint: A checkpoint to resume training from.
     """

+    _scaling_config_allowed_keys: List[str] = []
+
     def __init__(
         self,
         *,

@@ -210,6 +215,25 @@
                 f"with value `{self.resume_from_checkpoint}`."
             )

+    @classmethod
+    def _validate_and_get_scaling_config_data_class(
+        cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
+    ) -> ScalingConfigDataClass:
+        """Return scaling config dataclass after validating updated keys."""
+        if isinstance(dataclass_or_dict, dict):
+            ensure_only_allowed_dict_keys_set(
+                dataclass_or_dict, cls._scaling_config_allowed_keys
+            )
+            scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
+
+            return scaling_config_dataclass
+
+        ensure_only_allowed_dataclass_keys_updated(
+            dataclass=dataclass_or_dict,
+            allowed_keys=cls._scaling_config_allowed_keys,
+        )
+        return dataclass_or_dict
+
     def setup(self) -> None:
         """Called during fit() to perform initial setup on the Trainer.

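A usage sketch of the new classmethod above, assuming a Ray build that contains this commit. The ray.ml paths were later renamed to ray.air, and the DataParallelTrainer import location is inferred from the changed files, so treat both as assumptions.

from ray.ml.config import ScalingConfigDataClass
from ray.ml.train.data_parallel_trainer import DataParallelTrainer  # path assumed

# Dict path: keys are checked against the subclass whitelist, then converted.
sc = DataParallelTrainer._validate_and_get_scaling_config_data_class(
    {"num_workers": 2, "use_gpu": False}
)
assert isinstance(sc, ScalingConfigDataClass)

# Dataclass path: any field that differs from its default must be whitelisted.
# "trainer_resources" is not in DataParallelTrainer._scaling_config_allowed_keys
# (see the hunk above), so this raises a ValueError.
try:
    DataParallelTrainer._validate_and_get_scaling_config_data_class(
        ScalingConfigDataClass(trainer_resources={"CPU": 1})
    )
except ValueError as exc:
    print(exc)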
@@ -359,8 +383,10 @@
             @classmethod
             def default_resource_request(cls, config):
                 updated_scaling_config = config.get("scaling_config", scaling_config)
-                scaling_config_dataclass = ScalingConfigDataClass(
-                    **updated_scaling_config
+                scaling_config_dataclass = (
+                    trainer_cls._validate_and_get_scaling_config_data_class(
+                        updated_scaling_config
+                    )
                 )
                 return scaling_config_dataclass.as_placement_group_factory()

python/ray/ml/utils/config.py (new file, 56 lines added)

@@ -0,0 +1,56 @@
+import dataclasses
+from typing import Iterable
+
+
+def ensure_only_allowed_dict_keys_set(
+    data: dict,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dict by raising an exception if any key not included in
+    ``allowed_keys`` is set.
+
+    Args:
+        data: Dict to check.
+        allowed_keys: Iterable of keys that can be contained in dict keys.
+    """
+    allowed_keys_set = set(allowed_keys)
+    bad_keys = [key for key in data.keys() if key not in allowed_keys_set]
+
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be set in the current context. "
+            "Remove them from the dict."
+        )
+
+
+def ensure_only_allowed_dataclass_keys_updated(
+    dataclass: dataclasses.dataclass,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dataclass by raising an exception if any key not included in
+    ``allowed_keys`` differs from the default value.
+
+    Args:
+        dataclass: Dataclass instance to check.
+        allowed_keys: Dataclass attribute keys that can have a value different than
+            the default one.
+    """
+    default_data = dataclass.__class__()
+
+    allowed_keys = set(allowed_keys)
+
+    # These keys should not have been updated in the `dataclass` object
+    prohibited_keys = set(default_data.__dict__) - allowed_keys
+
+    bad_keys = [
+        key
+        for key in prohibited_keys
+        if dataclass.__dict__[key] != default_data.__dict__[key]
+    ]
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be updated in the current context. "
+            "Remove them from the dataclass."
+        )

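For reference, a sketch of the two helpers in isolation, again assuming the ray.ml modules from this commit are importable.

from ray.ml.config import ScalingConfigDataClass
from ray.ml.utils.config import (
    ensure_only_allowed_dataclass_keys_updated,
    ensure_only_allowed_dict_keys_set,
)

# Passes: every set key is whitelisted.
ensure_only_allowed_dict_keys_set({"num_workers": 2}, ["num_workers"])

# Raises ValueError: "use_gpu" is set but not whitelisted.
try:
    ensure_only_allowed_dict_keys_set({"use_gpu": True}, ["num_workers"])
except ValueError as exc:
    print(exc)

# The dataclass variant compares each field against a default-constructed
# instance, so only fields *changed* from their defaults need whitelisting.
ensure_only_allowed_dataclass_keys_updated(
    ScalingConfigDataClass(num_workers=2), ["num_workers"]
)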
@@ -175,6 +175,15 @@ class TunerTest(unittest.TestCase):

     def test_tuner_trainer_fail(self):
         class DummyTrainer(Trainer):
+            _scaling_config_allowed_keys = [
+                "num_workers",
+                "num_cpus_per_worker",
+                "num_gpus_per_worker",
+                "additional_resources_per_worker",
+                "use_gpu",
+                "trainer_resources",
+            ]
+
             def training_loop(self) -> None:
                 raise RuntimeError("There is an error in trainer!")
