Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[AIR] Refactor ScalingConfig key validation (#25549)
Follow another approach mentioned in #25350. The scaling config is now converted to a dataclass, letting us use a single function to validate both user-supplied dicts and dataclasses. This PR also fixes the fact that the scaling config wasn't validated in the GBDT trainer, and validates that the allowed keys set in Trainers are present in the dataclass.

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
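As context for the diff below, here is a minimal, self-contained sketch of the validation pattern this change converges on. It is not the verbatim Ray implementation: the final default-value comparison is inferred from the error messages asserted in the updated tests, and `default_data` is assumed to be a default-constructed instance.

import dataclasses
from typing import Iterable


def ensure_only_allowed_dataclass_keys_updated(
    dataclass: dataclasses.dataclass,
    allowed_keys: Iterable[str],
):
    # Build a reference instance holding only default values.
    default_data = dataclass.__class__()
    allowed_keys = set(allowed_keys)

    # New in this PR: every allowed key must exist on the dataclass.
    keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__]
    if keys_not_in_dict:
        raise ValueError(
            f"Key(s) {keys_not_in_dict} are not present in "
            f"{dataclass.__class__.__name__}."
        )

    # Keys outside allowed_keys must still hold their default values
    # (inferred step; the diff below truncates before this comparison).
    prohibited_keys = set(default_data.__dict__) - allowed_keys
    bad_keys = [
        key for key in prohibited_keys
        if getattr(dataclass, key) != getattr(default_data, key)
    ]
    if bad_keys:
        raise ValueError(f"Key(s) {bad_keys} are not allowed to be updated.")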
This commit is contained in:
parent
feb8c29063
commit
7bb142e3e4
8 changed files with 54 additions and 73 deletions
@@ -2,28 +2,6 @@ import dataclasses
 from typing import Iterable
 
 
-def ensure_only_allowed_dict_keys_set(
-    data: dict,
-    allowed_keys: Iterable[str],
-):
-    """
-    Validate dict by raising an exception if any key not included in
-    ``allowed_keys`` is set.
-
-    Args:
-        data: Dict to check.
-        allowed_keys: Iterable of keys that can be contained in dict keys.
-    """
-    allowed_keys_set = set(allowed_keys)
-    bad_keys = [key for key in data.keys() if key not in allowed_keys_set]
-
-    if bad_keys:
-        raise ValueError(
-            f"Key(s) {bad_keys} are not allowed to be set in the current context. "
-            "Remove them from the dict."
-        )
-
-
 def ensure_only_allowed_dataclass_keys_updated(
     dataclass: dataclasses.dataclass,
     allowed_keys: Iterable[str],
@@ -32,6 +10,9 @@ def ensure_only_allowed_dataclass_keys_updated(
     Validate dataclass by raising an exception if any key not included in
     ``allowed_keys`` differs from the default value.
 
+    A ``ValueError`` will also be raised if any of the ``allowed_keys``
+    is not present in ``dataclass.__dict__``.
+
     Args:
         dataclass: Dict or dataclass to check.
         allowed_keys: dataclass attribute keys that can have a value different than
@@ -41,6 +22,16 @@ def ensure_only_allowed_dataclass_keys_updated(
 
     allowed_keys = set(allowed_keys)
 
+    # TODO: split keys_not_in_dict validation to a separate function.
+    keys_not_in_dict = [key for key in allowed_keys if key not in default_data.__dict__]
+    if keys_not_in_dict:
+        raise ValueError(
+            f"Key(s) {keys_not_in_dict} are not present in "
+            f"{dataclass.__class__.__name__}. "
+            "Remove them from `allowed_keys`. "
+            f"Valid keys: {list(default_data.__dict__.keys())}"
+        )
+
     # These keys should not have been updated in the `dataclass` object
     prohibited_keys = set(default_data.__dict__) - allowed_keys
 
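A usage sketch of the two failure modes the function now distinguishes, mirroring the tests below (`ScalingConfigDataClass(num_workers=2)` is taken from those tests):

import pytest

from ray.air.config import ScalingConfigDataClass
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated

cfg = ScalingConfigDataClass(num_workers=2)

# num_workers differs from its default but is not in allowed_keys.
with pytest.raises(ValueError, match="are not allowed to be updated"):
    ensure_only_allowed_dataclass_keys_updated(cfg, ["trainer_resources"])

# BAD_KEY is not a field of the dataclass at all.
with pytest.raises(ValueError, match="are not present in"):
    ensure_only_allowed_dataclass_keys_updated(cfg, ["BAD_KEY"])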
@@ -5,10 +5,7 @@ from ray.air import Checkpoint
 from ray.air.config import ScalingConfigDataClass
 from ray.train import BaseTrainer
 from ray.air.preprocessor import Preprocessor
-from ray.air._internal.config import (
-    ensure_only_allowed_dataclass_keys_updated,
-    ensure_only_allowed_dict_keys_set,
-)
+from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
 
 
 class DummyTrainer(BaseTrainer):
@@ -65,29 +62,28 @@ def test_scaling_config_validate_config_valid_class():
     )
 
 
-def test_scaling_config_validate_config_valid_dict():
-    scaling_config = {"num_workers": 2}
-    ensure_only_allowed_dict_keys_set(scaling_config, ["num_workers"])
-
-
 def test_scaling_config_validate_config_prohibited_class():
     # Check for prohibited keys
     scaling_config = {"num_workers": 2}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError) as exc_info:
         ensure_only_allowed_dataclass_keys_updated(
             ScalingConfigDataClass(**scaling_config),
             ["trainer_resources"],
         )
+    assert "num_workers" in str(exc_info.value)
+    assert "to be updated" in str(exc_info.value)
 
 
-def test_scaling_config_validate_config_prohibited_dict():
-    # Check for prohibited keys
+def test_scaling_config_validate_config_bad_allowed_keys():
+    # Check for keys not present in dict
     scaling_config = {"num_workers": 2}
-    with pytest.raises(ValueError):
-        ensure_only_allowed_dict_keys_set(
-            scaling_config,
-            ["trainer_resources"],
+    with pytest.raises(ValueError) as exc_info:
+        ensure_only_allowed_dataclass_keys_updated(
+            ScalingConfigDataClass(**scaling_config),
+            ["BAD_KEY"],
         )
+    assert "BAD_KEY" in str(exc_info.value)
+    assert "are not present in" in str(exc_info.value)
 
 
 def test_datasets():
@@ -13,10 +13,7 @@ from ray.air.config import (
     ScalingConfigDataClass,
 )
 from ray.air.result import Result
-from ray.air._internal.config import (
-    ensure_only_allowed_dataclass_keys_updated,
-    ensure_only_allowed_dict_keys_set,
-)
+from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
 from ray.tune import Trainable
 from ray.tune.error import TuneError
 from ray.tune.function_runner import wrap_function
@@ -225,12 +222,7 @@ class BaseTrainer(abc.ABC):
     ) -> ScalingConfigDataClass:
         """Return scaling config dataclass after validating updated keys."""
         if isinstance(dataclass_or_dict, dict):
-            ensure_only_allowed_dict_keys_set(
-                dataclass_or_dict, cls._scaling_config_allowed_keys
-            )
-            scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
-
-            return scaling_config_dataclass
+            dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict)
 
         ensure_only_allowed_dataclass_keys_updated(
             dataclass=dataclass_or_dict,
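The hunk above is truncated by the diff view; a sketch of the whole method after this change (the `allowed_keys=` argument and the `return` are assumptions based on the surrounding code, not shown verbatim here):

@classmethod
def _validate_and_get_scaling_config_data_class(
    cls, dataclass_or_dict
) -> ScalingConfigDataClass:
    """Return scaling config dataclass after validating updated keys."""
    if isinstance(dataclass_or_dict, dict):
        # A user-supplied dict is first converted into the dataclass,
        # so one validation path covers both input types.
        dataclass_or_dict = ScalingConfigDataClass(**dataclass_or_dict)

    ensure_only_allowed_dataclass_keys_updated(
        dataclass=dataclass_or_dict,
        allowed_keys=cls._scaling_config_allowed_keys,  # assumed completion
    )
    return dataclass_or_dict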
@@ -228,11 +228,9 @@ class DataParallelTrainer(BaseTrainer):
 
     _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
         "num_workers",
-        "num_cpus_per_worker",
-        "num_gpus_per_worker",
         "resources_per_worker",
-        "additional_resources_per_worker",
         "use_gpu",
+        "placement_strategy",
     ]
 
     _dataset_config = {
@@ -17,16 +17,15 @@ if TYPE_CHECKING:
 
 
 def _convert_scaling_config_to_ray_params(
-    scaling_config: ScalingConfig,
+    scaling_config: ScalingConfigDataClass,
     ray_params_cls: Type["xgboost_ray.RayParams"],
     default_ray_params: Optional[Dict[str, Any]] = None,
 ) -> "xgboost_ray.RayParams":
     default_ray_params = default_ray_params or {}
-    scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)
-    resources_per_worker = scaling_config_dataclass.additional_resources_per_worker
-    num_workers = scaling_config_dataclass.num_workers
-    cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker
-    gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker
+    resources_per_worker = scaling_config.additional_resources_per_worker
+    num_workers = scaling_config.num_workers
+    cpus_per_worker = scaling_config.num_cpus_per_worker
+    gpus_per_worker = scaling_config.num_gpus_per_worker
 
     ray_params = ray_params_cls(
         num_actors=int(num_workers),
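A hedged calling example for the new signature: callers now construct the dataclass themselves instead of passing a dict (`num_actors` is a real `xgboost_ray.RayParams` field; the config values are illustrative):

from xgboost_ray import RayParams

from ray.air.config import ScalingConfigDataClass

scaling_config = ScalingConfigDataClass(num_workers=2, use_gpu=False)
ray_params = _convert_scaling_config_to_ray_params(scaling_config, RayParams)
assert ray_params.num_actors == 2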
@@ -67,11 +66,9 @@ class GBDTTrainer(BaseTrainer):
 
     _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
         "num_workers",
-        "num_cpus_per_worker",
-        "num_gpus_per_worker",
         "resources_per_worker",
-        "additional_resources_per_worker",
         "use_gpu",
+        "placement_strategy",
     ]
     _dmatrix_cls: type
     _ray_params_cls: type
@@ -143,8 +140,11 @@ class GBDTTrainer(BaseTrainer):
 
     @property
     def _ray_params(self) -> "xgboost_ray.RayParams":
+        scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
+            self.scaling_config
+        )
         return _convert_scaling_config_to_ray_params(
-            self.scaling_config, self._ray_params_cls, self._default_ray_params
+            scaling_config_dataclass, self._ray_params_cls, self._default_ray_params
         )
 
     def preprocess_datasets(self) -> None:
@@ -197,6 +197,7 @@ class GBDTTrainer(BaseTrainer):
 
     def as_trainable(self) -> Type[Trainable]:
         trainable_cls = super().as_trainable()
+        trainer_cls = self.__class__
         scaling_config = self.scaling_config
         ray_params_cls = self._ray_params_cls
         default_ray_params = self._default_ray_params
@@ -214,8 +215,13 @@ class GBDTTrainer(BaseTrainer):
             @classmethod
             def default_resource_request(cls, config):
                 updated_scaling_config = config.get("scaling_config", scaling_config)
+                scaling_config_dataclass = (
+                    trainer_cls._validate_and_get_scaling_config_data_class(
+                        updated_scaling_config
+                    )
+                )
                 return _convert_scaling_config_to_ray_params(
-                    updated_scaling_config, ray_params_cls, default_ray_params
+                    scaling_config_dataclass, ray_params_cls, default_ray_params
                 ).get_tune_resources()
 
         return GBDTTrainable
@@ -29,12 +29,11 @@ class DummyPreprocessor(Preprocessor):
 
 class DummyTrainer(BaseTrainer):
     _scaling_config_allowed_keys = [
-        "num_workers",
-        "num_cpus_per_worker",
-        "num_gpus_per_worker",
-        "additional_resources_per_worker",
-        "use_gpu",
         "trainer_resources",
+        "num_workers",
+        "use_gpu",
+        "resources_per_worker",
+        "placement_strategy",
     ]
 
     def __init__(self, train_loop, custom_arg=None, **kwargs):
@@ -172,7 +172,7 @@ def test_validation(ray_start_4_cpus):
         label_column="target",
         datasets={TRAIN_DATASET_KEY: train_dataset, "cv": valid_dataset},
     )
-    with pytest.raises(ValueError, match="are not allowed to be set"):
+    with pytest.raises(ValueError, match="are not allowed to be updated"):
         SklearnTrainer(
             estimator=RandomForestClassifier(),
             scaling_config={"num_workers": 2},
@@ -25,12 +25,11 @@ from ray.tune.tuner import Tuner
 
 class DummyTrainer(BaseTrainer):
     _scaling_config_allowed_keys = [
-        "num_workers",
-        "num_cpus_per_worker",
-        "num_gpus_per_worker",
-        "additional_resources_per_worker",
-        "use_gpu",
         "trainer_resources",
+        "num_workers",
+        "use_gpu",
+        "resources_per_worker",
+        "placement_strategy",
     ]
 
     def training_loop(self) -> None: