[AIR] Refactor ScalingConfig key validation (#25549)

Follows an alternative approach mentioned in #25350.

The scaling config is now converted to a dataclass, which lets us use a single function to validate both user-supplied dicts and dataclasses. This PR also fixes the scaling config not being validated in the GBDT Trainer, and validates that the allowed keys declared by Trainers are actually present in the dataclass.
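As a rough usage sketch (names are from this PR; the concrete values are only illustrative), the flow inside a Trainer now looks like this:

from ray.air.config import ScalingConfigDataClass
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated

# A user-supplied dict is first converted into the dataclass...
scaling_config = ScalingConfigDataClass(**{"num_workers": 2})

# ...and the single helper then validates it: only the listed keys may
# differ from the dataclass defaults, and every listed key must actually
# be a field of the dataclass.
ensure_only_allowed_dataclass_keys_updated(
    dataclass=scaling_config,
    allowed_keys=["num_workers", "use_gpu"],
)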

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Jimmy Yao 2022-06-13 09:43:24 -07:00 committed by GitHub
parent feb8c29063
commit 7bb142e3e4
8 changed files with 54 additions and 73 deletions


@@ -2,28 +2,6 @@ import dataclasses
from typing import Iterable
def ensure_only_allowed_dict_keys_set(
data: dict,
allowed_keys: Iterable[str],
):
"""
Validate dict by raising an exception if any key not included in
``allowed_keys`` is set.
Args:
data: Dict to check.
allowed_keys: Iterable of keys that can be contained in dict keys.
"""
allowed_keys_set = set(allowed_keys)
bad_keys = [key for key in data.keys() if key not in allowed_keys_set]
if bad_keys:
raise ValueError(
f"Key(s) {bad_keys} are not allowed to be set in the current context. "
"Remove them from the dict."
)
def ensure_only_allowed_dataclass_keys_updated(
dataclass: dataclasses.dataclass,
allowed_keys: Iterable[str],
@@ -32,6 +10,9 @@ def ensure_only_allowed_dataclass_keys_updated(
Validate dataclass by raising an exception if any key not included in
``allowed_keys`` differs from the default value.
A ``ValueError`` will also be raised if any of the ``allowed_keys``
is not present in ``dataclass.__dict__``.
Args:
dataclass: Dict or dataclass to check.
allowed_keys: dataclass attribute keys that can have a value different than
@@ -41,6 +22,16 @@ def ensure_only_allowed_dataclass_keys_updated(
allowed_keys = set(allowed_keys)
# TODO: split keys_not_in_dict validation to a separate function.
keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__]
if keys_not_in_dict:
raise ValueError(
f"Key(s) {keys_not_in_dict} are not present in "
f"{dataclass.__class__.__name__}. "
"Remove them from `allowed_keys`. "
f"Valid keys: {list(default_data.__dict__.keys())}"
)
# These keys should not have been updated in the `dataclass` object
prohibited_keys = set(default_data.__dict__) - allowed_keys
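For readers outside the Ray codebase, here is a self-contained toy that mimics the two checks above (a simplified re-implementation for illustration, not the actual helper):

import dataclasses

@dataclasses.dataclass
class ToyConfig:
    num_workers: int = 0
    use_gpu: bool = False

def validate(config, allowed_keys):
    default = type(config)()
    # Check 1: every allowed key must be a real field of the dataclass.
    missing = [key for key in allowed_keys if key not in default.__dict__]
    if missing:
        raise ValueError(f"Key(s) {missing} are not present in {type(config).__name__}.")
    # Check 2: fields outside allowed_keys must keep their default values.
    prohibited = set(default.__dict__) - set(allowed_keys)
    bad = [key for key in prohibited if getattr(config, key) != getattr(default, key)]
    if bad:
        raise ValueError(f"Key(s) {bad} are not allowed to be updated here.")

validate(ToyConfig(num_workers=2), ["num_workers"])   # passes
# validate(ToyConfig(num_workers=2), ["BAD_KEY"])     # raises: BAD_KEY not present
# validate(ToyConfig(use_gpu=True), ["num_workers"])  # raises: use_gpu may not change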


@@ -5,10 +5,7 @@ from ray.air import Checkpoint
from ray.air.config import ScalingConfigDataClass
from ray.train import BaseTrainer
from ray.air.preprocessor import Preprocessor
from ray.air._internal.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
class DummyTrainer(BaseTrainer):
@@ -65,29 +62,28 @@ def test_scaling_config_validate_config_valid_class():
)
def test_scaling_config_validate_config_valid_dict():
scaling_config = {"num_workers": 2}
ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])
def test_scaling_config_validate_config_prohibited_class():
# Check for prohibited keys
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc_info:
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config),
["trainer_resources"],
)
assert "num_workers" in str(exc_info.value)
assert "to be updated" in str(exc_info.value)
def test_scaling_config_validate_config_prohibited_dict():
# Check for prohibited keys
def test_scaling_config_validate_config_bad_allowed_keys():
# Check for keys not present in dict
scaling_config = {"num_workers": 2}
with pytest.raises(ValueError):
ensure_only_allowed_dict_keys_set(
scaling_config,
["trainer_resources"],
with pytest.raises(ValueError) as exc_info:
ensure_only_allowed_dataclass_keys_updated(
ScalingConfigDataClass(**scaling_config),
["BAD_KEY"],
)
assert "BAD_KEY" in str(exc_info.value)
assert "are not present in" in str(exc_info.value)
def test_datasets():


@@ -13,10 +13,7 @@ from ray.air.config import (
ScalingConfigDataClass,
)
from ray.air.result import Result
from ray.air._internal.config import (
ensure_only_allowed_dataclass_keys_updated,
ensure_only_allowed_dict_keys_set,
)
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.tune import Trainable
from ray.tune.error import TuneError
from ray.tune.function_runner import wrap_function
@@ -225,12 +222,7 @@ class BaseTrainer(abc.ABC):
) -> ScalingConfigDataClass:
"""Return scaling config dataclass after validating updated keys."""
if isinstance(dataclass_or_dict, dict):
ensure_only_allowed_dict_keys_set(
dataclass_or_dict, cls._scaling_config_allowed_keys
)
scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
return scaling_config_dataclass
dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict)
ensure_only_allowed_dataclass_keys_updated(
dataclass=dataclass_or_dict,
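Roughly, callers now look like this (a minimal sketch; MyTrainer is a hypothetical subclass modeled on the DummyTrainer used in the tests, and it assumes ray is importable):

from ray.train import BaseTrainer
from ray.air.config import ScalingConfigDataClass

class MyTrainer(BaseTrainer):
    # Hypothetical subclass for illustration only.
    _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
        "num_workers",
        "use_gpu",
    ]

    def training_loop(self) -> None:
        pass

# Dict input is converted to the dataclass first and then validated;
# dataclass input goes straight to validation. Both return the dataclass.
from_dict = MyTrainer._validate_and_get_scaling_config_data_class({"num_workers": 2})
from_dataclass = MyTrainer._validate_and_get_scaling_config_data_class(
    ScalingConfigDataClass(num_workers=2)
)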


@@ -228,11 +228,9 @@ class DataParallelTrainer(BaseTrainer):
_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"resources_per_worker",
"additional_resources_per_worker",
"use_gpu",
"placement_strategy",
]
_dataset_config = {


@@ -17,16 +17,15 @@ if TYPE_CHECKING:
def _convert_scaling_config_to_ray_params(
scaling_config: ScalingConfig,
scaling_config: ScalingConfigDataClass,
ray_params_cls: Type["xgboost_ray.RayParams"],
default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
default_ray_params = default_ray_params or {}
scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)
resources_per_worker = scaling_config_dataclass.additional_resources_per_worker
num_workers = scaling_config_dataclass.num_workers
cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker
gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker
resources_per_worker = scaling_config.additional_resources_per_worker
num_workers = scaling_config.num_workers
cpus_per_worker = scaling_config.num_cpus_per_worker
gpus_per_worker = scaling_config.num_gpus_per_worker
ray_params = ray_params_cls(
num_actors=int(num_workers),
@@ -67,11 +66,9 @@ class GBDTTrainer(BaseTrainer):
_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"resources_per_worker",
"additional_resources_per_worker",
"use_gpu",
"placement_strategy",
]
_dmatrix_cls: type
_ray_params_cls: type
@@ -143,8 +140,11 @@ class GBDTTrainer(BaseTrainer):
@property
def _ray_params(self) -> "xgboost_ray.RayParams":
scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
self.scaling_config
)
return _convert_scaling_config_to_ray_params(
self.scaling_config, self._ray_params_cls, self._default_ray_params
scaling_config_dataclass, self._ray_params_cls, self._default_ray_params
)
def preprocess_datasets(self) -> None:
@@ -197,6 +197,7 @@ class GBDTTrainer(BaseTrainer):
def as_trainable(self) -> Type[Trainable]:
trainable_cls = super().as_trainable()
trainer_cls = self.__class__
scaling_config = self.scaling_config
ray_params_cls = self._ray_params_cls
default_ray_params = self._default_ray_params
@@ -214,8 +215,13 @@
@classmethod
def default_resource_request(cls, config):
updated_scaling_config = config.get("scaling_config", scaling_config)
scaling_config_dataclass = (
trainer_cls._validate_and_get_scaling_config_data_class(
updated_scaling_config
)
)
return _convert_scaling_config_to_ray_params(
updated_scaling_config, ray_params_cls, default_ray_params
scaling_config_dataclass, ray_params_cls, default_ray_params
).get_tune_resources()
return GBDTTrainable
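Condensed, the GBDT path now validates before converting; previously the raw dict was handed straight to _convert_scaling_config_to_ray_params, so disallowed keys were never caught. In the sketch below, trainer stands for any concrete GBDTTrainer instance:

# Inside _ray_params (and, via the captured trainer_cls, inside
# default_resource_request): validate/convert first, then build RayParams.
scaling_config_dataclass = trainer._validate_and_get_scaling_config_data_class(
    trainer.scaling_config
)
ray_params = _convert_scaling_config_to_ray_params(
    scaling_config_dataclass, trainer._ray_params_cls, trainer._default_ray_params
)
# default_resource_request then returns ray_params.get_tune_resources().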


@@ -29,12 +29,11 @@ class DummyPreprocessor(Preprocessor):
class DummyTrainer(BaseTrainer):
_scaling_config_allowed_keys = [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"additional_resources_per_worker",
"use_gpu",
"trainer_resources",
"num_workers",
"use_gpu",
"resources_per_worker",
"placement_strategy",
]
def __init__(self, train_loop, custom_arg=None, **kwargs):


@@ -172,7 +172,7 @@ def test_validation(ray_start_4_cpus):
label_column="target",
datasets={TRAIN_DATASET_KEY: train_dataset, "cv": valid_dataset},
)
with pytest.raises(ValueError, match="are not allowed to be set"):
with pytest.raises(ValueError, match="are not allowed to be updated"):
SklearnTrainer(
estimator=RandomForestClassifier(),
scaling_config={"num_workers": 2},


@@ -25,12 +25,11 @@ from ray.tune.tuner import Tuner
class DummyTrainer(BaseTrainer):
_scaling_config_allowed_keys = [
"num_workers",
"num_cpus_per_worker",
"num_gpus_per_worker",
"additional_resources_per_worker",
"use_gpu",
"trainer_resources",
"num_workers",
"use_gpu",
"resources_per_worker",
"placement_strategy",
]
def training_loop(self) -> None: