[AIR] Add Scaling Config validation (#23889)

Adds a generic way of validating `ScalingConfig`s by allowing only certain keys to be set or updated. Each trainer declares its permitted keys in a `_scaling_config_allowed_keys` class attribute, and `Trainer._validate_and_get_scaling_config_data_class` checks the user-provided config against that list before constructing a `ScalingConfigDataClass`.
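
For illustration, a minimal sketch of the resulting behavior (the helper name and the `ray.ml.utils.config` module come from the diff below; the configs here are made up):

from ray.ml.utils.config import ensure_only_allowed_dict_keys_set

# Keys a hypothetical trainer declares as supported.
allowed_keys = ["num_workers", "use_gpu"]

# A config that only sets allowed keys validates silently.
ensure_only_allowed_dict_keys_set({"num_workers": 2}, allowed_keys)

# A config that sets any other key raises a ValueError.
try:
    ensure_only_allowed_dict_keys_set({"trainer_resources": {"CPU": 1}}, allowed_keys)
except ValueError as exc:
    print(exc)  # Key(s) ['trainer_resources'] are not allowed to be set ...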

Co-authored-by: Kai Fricke <kai@anyscale.com>
Antoni Baum, 2022-04-19 22:05:47 +02:00 (committed by GitHub)
parent 7f3031f451
commit 1fc6db30a5
8 changed files with 174 additions and 11 deletions


@@ -1,13 +1,22 @@
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, List, Mapping, Callable, Union, TYPE_CHECKING
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from ray.tune.syncer import SyncConfig
 from ray.util import PublicAPI
 
 if TYPE_CHECKING:
-    from ray.tune.trainable import PlacementGroupFactory
     from ray.tune.callback import Callback
     from ray.tune.stopper import Stopper
+    from ray.tune.trainable import PlacementGroupFactory
 
 ScalingConfig = Dict[str, Any]


@@ -2,8 +2,13 @@ import pytest
 import ray
 
 from ray.ml import Checkpoint
+from ray.ml.config import ScalingConfigDataClass
 from ray.ml.trainer import Trainer
 from ray.ml.preprocessor import Preprocessor
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)
 
 
 class DummyTrainer(Trainer):
@@ -53,6 +58,38 @@ def test_scaling_config():
         DummyTrainer(scaling_config=None)
 
 
+def test_scaling_config_validate_config_valid_class():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dataclass_keys_updated(
+        ScalingConfigDataClass(**scaling_config), ["num_workers"]
+    )
+
+
+def test_scaling_config_validate_config_valid_dict():
+    scaling_config = {"num_workers": 2}
+    ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])
+
+
+def test_scaling_config_validate_config_prohibited_class():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dataclass_keys_updated(
+            ScalingConfigDataClass(**scaling_config),
+            ["trainer_resources"],
+        )
+
+
+def test_scaling_config_validate_config_prohibited_dict():
+    # Check for prohibited keys
+    scaling_config = {"num_workers": 2}
+    with pytest.raises(ValueError):
+        ensure_only_allowed_dict_keys_set(
+            scaling_config,
+            ["trainer_resources"],
+        )
+
+
 def test_datasets():
     with pytest.raises(ValueError):
         DummyTrainer(datasets="invalid")


@@ -27,6 +27,15 @@ class DummyPreprocessor(Preprocessor):
 
 class DummyTrainer(Trainer):
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+        "trainer_resources",
+    ]
+
     def __init__(self, train_loop, custom_arg=None, **kwargs):
         self.custom_arg = custom_arg
         self.train_loop = train_loop


@@ -7,7 +7,7 @@ import ray
 from ray import tune
 from ray.ml.constants import TRAIN_DATASET_KEY, PREPROCESSOR_KEY
 from ray.ml.trainer import Trainer
-from ray.ml.config import ScalingConfig, RunConfig, ScalingConfigDataClass
+from ray.ml.config import ScalingConfig, RunConfig
 from ray.ml.trainer import GenDataset
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint
@@ -181,6 +181,14 @@ class DataParallelTrainer(Trainer):
         resume_from_checkpoint: A checkpoint to resume training from.
     """
 
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
+
     def __init__(
         self,
         *,
@@ -250,7 +258,9 @@ class DataParallelTrainer(Trainer):
         )
 
     def training_loop(self) -> None:
-        scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config)
+        scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
+            self.scaling_config
+        )
 
         train_loop_per_worker = construct_train_func(
             self.train_loop_per_worker,


@@ -65,6 +65,13 @@ class GBDTTrainer(Trainer):
         **train_kwargs: Additional kwargs passed to framework ``train()`` function.
     """
 
+    _scaling_config_allowed_keys = [
+        "num_workers",
+        "num_cpus_per_worker",
+        "num_gpus_per_worker",
+        "additional_resources_per_worker",
+        "use_gpu",
+    ]
+
     _dmatrix_cls: type
     _ray_params_cls: type
     _tune_callback_cls: type


@@ -1,19 +1,22 @@
 import abc
 import inspect
 import logging
-from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
 import ray
-from ray.ml.preprocessor import Preprocessor
-from ray.util import PublicAPI
 from ray.ml.checkpoint import Checkpoint
-from ray.ml.result import Result
 from ray.ml.config import RunConfig, ScalingConfig, ScalingConfigDataClass
 from ray.ml.constants import TRAIN_DATASET_KEY
+from ray.ml.preprocessor import Preprocessor
+from ray.ml.result import Result
+from ray.ml.utils.config import (
+    ensure_only_allowed_dataclass_keys_updated,
+    ensure_only_allowed_dict_keys_set,
+)
 from ray.tune import Trainable
 from ray.tune.error import TuneError
 from ray.tune.function_runner import wrap_function
+from ray.util import PublicAPI
 from ray.util.annotations import DeveloperAPI
 from ray.util.ml_utils.dict import merge_dicts
@@ -133,6 +136,8 @@ class Trainer(abc.ABC):
         resume_from_checkpoint: A checkpoint to resume training from.
     """
 
+    _scaling_config_allowed_keys: List[str] = []
+
     def __init__(
         self,
         *,
@@ -210,6 +215,25 @@ class Trainer(abc.ABC):
                 f"with value `{self.resume_from_checkpoint}`."
             )
 
+    @classmethod
+    def _validate_and_get_scaling_config_data_class(
+        cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
+    ) -> ScalingConfigDataClass:
+        """Return scaling config dataclass after validating updated keys."""
+        if isinstance(dataclass_or_dict, dict):
+            ensure_only_allowed_dict_keys_set(
+                dataclass_or_dict, cls._scaling_config_allowed_keys
+            )
+            scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
+
+            return scaling_config_dataclass
+
+        ensure_only_allowed_dataclass_keys_updated(
+            dataclass=dataclass_or_dict,
+            allowed_keys=cls._scaling_config_allowed_keys,
+        )
+        return dataclass_or_dict
+
     def setup(self) -> None:
         """Called during fit() to perform initial setup on the Trainer.
@@ -359,8 +383,10 @@ class Trainer(abc.ABC):
             @classmethod
             def default_resource_request(cls, config):
                 updated_scaling_config = config.get("scaling_config", scaling_config)
-                scaling_config_dataclass = ScalingConfigDataClass(
-                    **updated_scaling_config
+                scaling_config_dataclass = (
+                    trainer_cls._validate_and_get_scaling_config_data_class(
+                        updated_scaling_config
+                    )
                 )
                 return scaling_config_dataclass.as_placement_group_factory()


@@ -0,0 +1,56 @@
+import dataclasses
+from typing import Iterable
+
+
+def ensure_only_allowed_dict_keys_set(
+    data: dict,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dict by raising an exception if any key not included in
+    ``allowed_keys`` is set.
+
+    Args:
+        data: Dict to check.
+        allowed_keys: Iterable of keys that can be contained in dict keys.
+    """
+    allowed_keys_set = set(allowed_keys)
+    bad_keys = [key for key in data.keys() if key not in allowed_keys_set]
+
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be set in the current context. "
+            "Remove them from the dict."
+        )
+
+
+def ensure_only_allowed_dataclass_keys_updated(
+    dataclass: dataclasses.dataclass,
+    allowed_keys: Iterable[str],
+):
+    """
+    Validate dataclass by raising an exception if any key not included in
+    ``allowed_keys`` differs from the default value.
+
+    Args:
+        dataclass: Dict or dataclass to check.
+        allowed_keys: dataclass attribute keys that can have a value different than
+            the default one.
+    """
+    default_data = dataclass.__class__()
+
+    allowed_keys = set(allowed_keys)
+
+    # These keys should not have been updated in the `dataclass` object
+    prohibited_keys = set(default_data.__dict__) - allowed_keys
+
+    bad_keys = [
+        key
+        for key in prohibited_keys
+        if dataclass.__dict__[key] != default_data.__dict__[key]
+    ]
+    if bad_keys:
+        raise ValueError(
+            f"Key(s) {bad_keys} are not allowed to be updated in the current context. "
+            "Remove them from the dataclass."
+        )
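
A usage sketch of the dataclass helper above (behavior mirrors the tests added in this commit; it assumes `ScalingConfigDataClass` defines a `num_workers` field):

from ray.ml.config import ScalingConfigDataClass
from ray.ml.utils.config import ensure_only_allowed_dataclass_keys_updated

# Only fields that differ from their dataclass defaults are checked,
# so untouched fields never trip the validation.
config = ScalingConfigDataClass(num_workers=2)

# Passes: num_workers is the only updated field, and it is allowed.
ensure_only_allowed_dataclass_keys_updated(config, allowed_keys=["num_workers"])

# Raises ValueError: num_workers was updated but is not an allowed key.
ensure_only_allowed_dataclass_keys_updated(config, allowed_keys=["trainer_resources"])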


@@ -175,6 +175,15 @@ class TunerTest(unittest.TestCase):
     def test_tuner_trainer_fail(self):
         class DummyTrainer(Trainer):
+            _scaling_config_allowed_keys = [
+                "num_workers",
+                "num_cpus_per_worker",
+                "num_gpus_per_worker",
+                "additional_resources_per_worker",
+                "use_gpu",
+                "trainer_resources",
+            ]
+
             def training_loop(self) -> None:
                 raise RuntimeError("There is an error in trainer!")