diff --git a/python/ray/ml/config.py b/python/ray/ml/config.py
index 522a75013..b64810538 100644
--- a/python/ray/ml/config.py
+++ b/python/ray/ml/config.py
@@ -1,8 +1,49 @@
-from typing import Dict, Any
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
 
-# Right now, RunConfig is just an arbitrary dict that specifies tune.run
-# kwargs.
-# TODO(xwjiang): After Tuner is implemented, define the schema
-RunConfig = Dict[str, Any]
+from ray.tune.callback import Callback
+from ray.util import PublicAPI
 
 ScalingConfig = Dict[str, Any]
+
+
+@dataclass
+@PublicAPI(stability="alpha")
+class FailureConfig:
+    """Configuration related to failure handling of each run/trial.
+    Args:
+        max_failures: Tries to recover a run at least this many times.
+            Will recover from the latest checkpoint if present.
+            Setting to -1 will lead to infinite recovery retries.
+            Setting to 0 will disable retries. Defaults to 0.
+    """
+
+    max_failures: int = 0
+
+
+@dataclass
+@PublicAPI(stability="alpha")
+class RunConfig:
+    """Runtime configuration for individual trials that are run.
+
+    This contains information that applies to individual runs of Trainable classes.
+    This includes both running a Trainable by itself or running a hyperparameter
+    tuning job on top of a Trainable (applies to each trial).
+
+    Args:
+        name: Name of the trial or experiment. If not provided, will be deduced
+            from the Trainable.
+        local_dir: Local dir to save training results to.
+            Defaults to ``~/ray_results``.
+        callbacks: Callbacks to invoke.
+            Refer to ray.tune.callback.Callback for more info.
+    """
+
+    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
+    # reapply some of the args here? For callbacks, how do we enforce only stateless
+    # callbacks?
+    # TODO(xwjiang): Add more.
+    name: Optional[str] = None
+    local_dir: Optional[str] = None
+    callbacks: Optional[List[Callback]] = None
+    failure: Optional[FailureConfig] = None
diff --git a/python/ray/ml/run_config.py b/python/ray/ml/run_config.py
deleted file mode 100644
index 80f056694..000000000
--- a/python/ray/ml/run_config.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional
-
-from ray.tune.callback import Callback
-from ray.util import PublicAPI
-
-
-@dataclass
-@PublicAPI(stability="alpha")
-class FailureConfig:
-    """Configuration related to failure handling of each run/trial.
-    Args:
-        max_failures: Tries to recover a run at least this many times.
-            Will recover from the latest checkpoint if present.
-            Setting to -1 will lead to infinite recovery retries.
-            Setting to 0 will disable retries. Defaults to 0.
-    """
-
-    max_failures: int = 0
-
-
-@dataclass
-@PublicAPI(stability="alpha")
-class RunConfig:
-    """Runtime configuration for individual trials that are run.
-
-    This contains information that applies to individual runs of Trainable classes.
-    This includes both running a Trainable by itself or running a hyperparameter
-    tuning job on top of a Trainable (applies to each trial).
-
-    Args:
-        name: Name of the trial or experiment. If not provided, will be deduced
-            from the Trainable.
-        local_dir: Local dir to save training results to.
-            Defaults to ``~/ray_results``.
-        callbacks: Callbacks to invoke.
-            Refer to ray.tune.callback.Callback for more info.
-    """
-
-    # TODO(xwjiang): Clarify RunConfig behavior across resume. Is one supposed to
-    # reapply some of the args here? For callbacks, how do we enforce only stateless
-    # callbacks?
-    # TODO(xwjiang): Add more.
-    name: Optional[str] = None
-    local_dir: Optional[str] = None
-    callbacks: Optional[List[Callback]] = None
-    failure: Optional[FailureConfig] = None
diff --git a/python/ray/ml/trainer.py b/python/ray/ml/trainer.py
index 9b7742644..522ae59af 100644
--- a/python/ray/ml/trainer.py
+++ b/python/ray/ml/trainer.py
@@ -5,7 +5,7 @@ from typing import Dict, Union, Callable, Optional, TYPE_CHECKING, Type
 from ray.ml.preprocessor import Preprocessor
 from ray.ml.checkpoint import Checkpoint
 from ray.ml.result import Result
-from ray.ml.config import ScalingConfig, RunConfig
+from ray.ml.config import RunConfig, ScalingConfig
 from ray.tune import Trainable
 from ray.util import PublicAPI
 
diff --git a/python/ray/tune/tuner.py b/python/ray/tune/tuner.py
index 39821d3f2..4031bcd8c 100644
--- a/python/ray/tune/tuner.py
+++ b/python/ray/tune/tuner.py
@@ -13,7 +13,7 @@ class Tuner:
         tune_config: Tuning algorithm specific configs.
             Refer to ray.tune.tune_config.TuneConfig for more info.
         run_config: Runtime configuration that is specific to individual trials.
-            Refer to ray.ml.run_config.RunConfig for more info.
+            Refer to ray.ml.config.RunConfig for more info.
 
     Returns:
         ``ResultGrid`` object.
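
Taken together, the diff relocates FailureConfig and RunConfig from ray.ml.run_config into ray.ml.config (deleting run_config.py) and updates the two stale references. A minimal sketch of how the relocated classes compose after this change; the field values below are illustrative assumptions, not part of this diff:

    from ray.ml.config import FailureConfig, RunConfig

    # Both dataclasses now live in ray.ml.config; the old run_config module is gone.
    # FailureConfig nests under RunConfig through the `failure` field.
    run_config = RunConfig(
        name="my_experiment",       # hypothetical; deduced from the Trainable if omitted
        local_dir="~/ray_results",  # same as the documented default
        # Recover failed runs from the latest checkpoint; -1 retries forever, 0 disables.
        failure=FailureConfig(max_failures=3),
    )

Any other code importing RunConfig from the deleted module must switch to the new path, as done for trainer.py and tuner.py above.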