[AIR] switch to a common RunConfig. (#23076)

xwjiang2010 2022-03-11 10:55:36 -08:00 committed by GitHub
parent 2b38fe89e2
commit f270d84094
4 changed files with 48 additions and 54 deletions

View file

@@ -1,8 +1,49 @@
-from typing import Dict, Any
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
 
-# Right now, RunConfig is just an arbitrary dict that specifies tune.run
-# kwargs.
-# TODO(xwjiang): After Tuner is implemented, define the schema
-RunConfig = Dict[str, Any]
+from ray.tune.callback import Callback
+from ray.util import PublicAPI
 
 ScalingConfig = Dict[str, Any]
+
+
+@dataclass
+@PublicAPI(stability="alpha")
+class FailureConfig:
+    """Configuration related to failure handling of each run/trial.
+
+    Args:
+        max_failures: Tries to recover a run at least this many times.
+            Will recover from the latest checkpoint if present.
+            Setting to -1 will lead to infinite recovery retries.
+            Setting to 0 will disable retries. Defaults to 0.
+    """
+    max_failures: int = 0
+
+
+@dataclass
+@PublicAPI(stability="alpha")
+class RunConfig:
+    """Runtime configuration for individual trials that are run.
+
+    This contains information that applies to individual runs of Trainable classes.
+    This includes both running a Trainable by itself or running a hyperparameter
+    tuning job on top of a Trainable (applies to each trial).
+
+    Args:
+        name: Name of the trial or experiment. If not provided, will be deduced
+            from the Trainable.
+        local_dir: Local dir to save training results to.
+            Defaults to ``~/ray_results``.
+        callbacks: Callbacks to invoke.
+            Refer to ray.tune.callback.Callback for more info.
+    """
+
+    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
+    # reapply some of the args here? For callbacks, how do we enforce only stateless
+    # callbacks?
+    # TODO(xwjiang): Add more.
+    name: Optional[str] = None
+    local_dir: Optional[str] = None
+    callbacks: Optional[List[Callback]] = None
+    failure: Optional[FailureConfig] = None

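For reference, a minimal sketch of how the consolidated dataclasses above might be instantiated once this lands. The field names come straight from the diff; the experiment name and retry count are illustrative values, not part of this commit:

from ray.ml.config import FailureConfig, RunConfig

# Retry a failed trial up to three times, resuming from its latest checkpoint.
failure_config = FailureConfig(max_failures=3)

# Common run-level settings shared by a standalone Trainable run or a tuning job.
run_config = RunConfig(
    name="common_run_config_demo",  # illustrative name
    local_dir="~/ray_results",      # matches the documented default
    callbacks=None,                 # stateless Tune callbacks, per the TODO above
    failure=failure_config,
)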
View file

@@ -1,47 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional
-
-from ray.tune.callback import Callback
-from ray.util import PublicAPI
-
-
-@dataclass
-@PublicAPI(stability="alpha")
-class FailureConfig:
-    """Configuration related to failure handling of each run/trial.
-
-    Args:
-        max_failures: Tries to recover a run at least this many times.
-            Will recover from the latest checkpoint if present.
-            Setting to -1 will lead to infinite recovery retries.
-            Setting to 0 will disable retries. Defaults to 0.
-    """
-    max_failures: int = 0
-
-
-@dataclass
-@PublicAPI(stability="alpha")
-class RunConfig:
-    """Runtime configuration for individual trials that are run.
-
-    This contains information that applies to individual runs of Trainable classes.
-    This includes both running a Trainable by itself or running a hyperparameter
-    tuning job on top of a Trainable (applies to each trial).
-
-    Args:
-        name: Name of the trial or experiment. If not provided, will be deduced
-            from the Trainable.
-        local_dir: Local dir to save training results to.
-            Defaults to ``~/ray_results``.
-        callbacks: Callbacks to invoke.
-            Refer to ray.tune.callback.Callback for more info.
-    """
-
-    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
-    # reapply some of the args here? For callbacks, how do we enforce only stateless
-    # callbacks?
-    # TODO(xwjiang): Add more.
-    name: Optional[str] = None
-    local_dir: Optional[str] = None
-    callbacks: Optional[List[Callback]] = None
-    failure: Optional[FailureConfig] = None

View file

@@ -5,7 +5,7 @@ from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint
 from ray.ml.result import Result
-from ray.ml.config import ScalingConfig, RunConfig
+from ray.ml.config import RunConfig, ScalingConfig
 from ray.tune import Trainable
 from ray.util import PublicAPI

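The trainer module now pulls both configs from a single module. As a sketch of what that consolidation means for user code (the ScalingConfig keys shown are illustrative; per the diff, ScalingConfig is still a plain Dict[str, Any] at this point):

# Before this commit, ray.ml.config exposed RunConfig only as a Dict alias,
# while the dataclass version lived in the now-deleted ray.ml.run_config.
# After it, one import covers the common dataclass and ScalingConfig:
from ray.ml.config import RunConfig, ScalingConfig

scaling_config: ScalingConfig = {"num_workers": 2, "use_gpu": False}  # illustrative keys
run_config = RunConfig(name="trainer_demo")  # illustrative name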
View file

@@ -13,7 +13,7 @@ class Tuner:
         tune_config: Tuning algorithm specific configs.
             Refer to ray.tune.tune_config.TuneConfig for more info.
         run_config: Runtime configuration that is specific to individual trials.
-            Refer to ray.ml.run_config.RunConfig for more info.
+            Refer to ray.ml.config.RunConfig for more info.
 
     Returns:
         ``ResultGrid`` object.
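Tying the pieces together, a hedged sketch of how the Tuner documented above might consume the common RunConfig. The Tuner constructor shape, module path, and the fit() call returning the ``ResultGrid`` are assumptions inferred from this docstring, not shown in the diff:

from ray.ml.config import FailureConfig, RunConfig
from ray.tune.tuner import Tuner  # assumed module path for the alpha Tuner


def trainable(config):
    ...  # user-defined training logic


tuner = Tuner(
    trainable,  # assumed positional argument
    run_config=RunConfig(
        name="tuner_demo",                    # illustrative name
        failure=FailureConfig(max_failures=2),
    ),
)
result_grid = tuner.fit()  # assumed to return the ``ResultGrid`` mentioned above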