mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[AIR] switch to a common RunConfig. (#23076)
This commit is contained in:
parent
2b38fe89e2
commit
f270d84094
4 changed files with 48 additions and 54 deletions
|
@ -1,8 +1,49 @@
|
|||
from typing import Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Right now, RunConfig is just an arbitrary dict that specifies tune.run
|
||||
# kwargs.
|
||||
# TODO(xwjiang): After Tuner is implemented, define the schema
|
||||
RunConfig = Dict[str, Any]
|
||||
from ray.tune.callback import Callback
|
||||
from ray.util import PublicAPI
|
||||
|
||||
ScalingConfig = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
@PublicAPI(stability="alpha")
class FailureConfig:
    """Settings that control how failed runs/trials are recovered.

    Args:
        max_failures: How many times a run is retried after a failure,
            resuming from the latest checkpoint when one is available.
            A value of -1 retries indefinitely, while 0 (the default)
            disables retries entirely.
    """

    # 0 = never retry; -1 = retry forever.
    max_failures: int = 0
|
||||
|
||||
|
||||
@dataclass
@PublicAPI(stability="alpha")
class RunConfig:
    """Runtime configuration for individual trials that are run.

    This contains information that applies to individual runs of Trainable classes.
    This includes both running a Trainable by itself or running a hyperparameter
    tuning job on top of a Trainable (applies to each trial).

    Args:
        name: Name of the trial or experiment. If not provided, will be deduced
            from the Trainable.
        local_dir: Local dir to save training results to.
            Defaults to ``~/ray_results``.
        callbacks: Callbacks to invoke.
            Refer to ray.tune.callback.Callback for more info.
    """

    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
    #  reapply some of the args here? For callbacks, how do we enforce only stateless
    #  callbacks?
    # TODO(xwjiang): Add more.

    # Experiment/trial name; deduced from the Trainable when left as None.
    name: Optional[str] = None
    # Output directory for results; None means the ``~/ray_results`` default.
    local_dir: Optional[str] = None
    # Optional list of tune callbacks invoked during the run.
    callbacks: Optional[List[Callback]] = None
    # Failure-handling settings applied to each run/trial.
    failure: Optional[FailureConfig] = None
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from ray.tune.callback import Callback
|
||||
from ray.util import PublicAPI
|
||||
|
||||
|
||||
@dataclass
@PublicAPI(stability="alpha")
class FailureConfig:
    """Configuration for handling failures of each run/trial.

    Args:
        max_failures: Tries to recover a run at least this many times.
            Will recover from the latest checkpoint if present.
            Setting to -1 will lead to infinite recovery retries.
            Setting to 0 will disable retries. Defaults to 0.
    """

    # Retry budget: 0 disables retries, -1 is unbounded.
    max_failures: int = 0
|
||||
|
||||
|
||||
@dataclass
@PublicAPI(stability="alpha")
class RunConfig:
    """Per-run configuration applied to each executed trial.

    This contains information that applies to individual runs of Trainable
    classes: both a Trainable run by itself and each trial of a
    hyperparameter tuning job built on top of a Trainable.

    Args:
        name: Name of the trial or experiment. If not provided, will be deduced
            from the Trainable.
        local_dir: Local dir to save training results to.
            Defaults to ``~/ray_results``.
        callbacks: Callbacks to invoke.
            Refer to ray.tune.callback.Callback for more info.
    """

    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
    #  reapply some of the args here? For callbacks, how do we enforce only stateless
    #  callbacks?
    # TODO(xwjiang): Add more.

    # Experiment/trial name; deduced from the Trainable when None.
    name: Optional[str] = None
    # Directory for training results; defaults to ``~/ray_results`` when None.
    local_dir: Optional[str] = None
    # Callbacks invoked during the run (see ray.tune.callback.Callback).
    callbacks: Optional[List[Callback]] = None
    # Per-trial failure handling configuration.
    failure: Optional[FailureConfig] = None
|
|
@ -5,7 +5,7 @@ from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
|
|||
from ray.ml.preprocessor import Preprocessor
|
||||
from ray.ml.checkpoint import Checkpoint
|
||||
from ray.ml.result import Result
|
||||
from ray.ml.config import ScalingConfig, RunConfig
|
||||
from ray.ml.config import RunConfig, ScalingConfig
|
||||
from ray.tune import Trainable
|
||||
from ray.util import PublicAPI
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ class Tuner:
|
|||
tune_config: Tuning algorithm specific configs.
|
||||
Refer to ray.tune.tune_config.TuneConfig for more info.
|
||||
run_config: Runtime configuration that is specific to individual trials.
|
||||
Refer to ray.ml.run_config.RunConfig for more info.
|
||||
Refer to ray.ml.config.RunConfig for more info.
|
||||
|
||||
Returns:
|
||||
``ResultGrid`` object.
|
||||
|
|
Loading…
Add table
Reference in a new issue