diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py
index 47c42247a..652f8c1fa 100644
--- a/python/ray/tune/progress_reporter.py
+++ b/python/ray/tune/progress_reporter.py
@@ -7,7 +7,7 @@
 import os
 import sys
 import time
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
@@ -16,7 +16,9 @@ from ray.tune.logger import logger, pretty_print
 from ray.tune.result import (
     AUTO_RESULT_KEYS,
     DEFAULT_METRIC,
+    DONE,
     EPISODE_REWARD_MEAN,
+    EXPERIMENT_TAG,
     MEAN_ACCURACY,
     MEAN_LOSS,
     NODE_IP,
@@ -24,11 +26,14 @@ from ray.tune.result import (
     TIME_TOTAL_S,
     TIMESTEPS_TOTAL,
     TRAINING_ITERATION,
+    TRIAL_ID,
 )
 from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
+from ray.tune.trainable import Trainable
 from ray.tune.utils import unflattened_lookup
 from ray.tune.utils.log import Verbosity, has_verbosity
 from ray.util.annotations import DeveloperAPI, PublicAPI
+from ray.util.ml_utils.dict import flatten_dict
 from ray.util.queue import Queue

 try:
@@ -52,6 +57,9 @@
 except NameError:
     IS_NOTEBOOK = False

+SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
+
+
 @PublicAPI
 class ProgressReporter:
     """Abstract class for experiment progress reporting.
@@ -1093,11 +1101,18 @@ class TrialProgressCallback(Callback):
     """

-    def __init__(self, metric: Optional[str] = None):
+    def __init__(
+        self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
+    ):
         self._last_print = collections.defaultdict(float)
         self._completed_trials = set()
         self._last_result_str = {}
         self._metric = metric
+        self._progress_metrics = set(progress_metrics or [])
+
+        # If a main metric is set, always report it alongside the progress metrics
+        if self._metric and self._progress_metrics:
+            self._progress_metrics.add(self._metric)

     def on_trial_result(
         self,
@@ -1179,16 +1194,27 @@
         self._last_print[trial] = time.time()

     def _print_result(self, result: Dict):
-        print_result = result.copy()
-        print_result.pop("config", None)
-        print_result.pop("hist_stats", None)
-        print_result.pop("trial_id", None)
-        print_result.pop("experiment_tag", None)
-        print_result.pop("done", None)
-        for auto_result in AUTO_RESULT_KEYS:
-            print_result.pop(auto_result, None)
+        if self._progress_metrics:
+            # If progress metrics are given, only report these
+            flat_result = flatten_dict(result)

-        print_result_str = ",".join([f"{k}={v}" for k, v in print_result.items()])
+            print_result = {}
+            for metric in self._progress_metrics:
+                print_result[metric] = flat_result.get(metric)
+
+        else:
+            # Else, skip auto-populated results
+            print_result = result.copy()
+
+            for skip_result in SKIP_RESULTS_IN_REPORT:
+                print_result.pop(skip_result, None)
+
+            for auto_result in AUTO_RESULT_KEYS:
+                print_result.pop(auto_result, None)
+
+        print_result_str = ",".join(
+            [f"{k}={v}" for k, v in print_result.items() if v is not None]
+        )

         return print_result_str
@@ -1206,3 +1232,13 @@
     else:
         progress_reporter = CLIReporter(**kwargs)
     return progress_reporter
+
+
+def detect_progress_metrics(
+    trainable: Optional[Union["Trainable", Callable]]
+) -> Optional[List[str]]:
+    """Detect progress metrics to report."""
+    if not trainable:
+        return None
+
+    return getattr(trainable, "_progress_metrics", None)
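A quick sketch of the `progress_metrics` behavior added above (the metric names and the result dict are made up for illustration and are not part of the patch):

    from ray.tune.progress_reporter import TrialProgressCallback

    callback = TrialProgressCallback(
        metric="loss", progress_metrics=["loss", "eval/accuracy"]
    )

    # _print_result() now flattens the nested result dict (flatten_dict
    # joins nested keys with "/") and keeps only the requested metrics;
    # metrics missing from the result are dropped by the `if v is not None`
    # filter instead of being printed as "None".
    result = {"loss": 0.1, "eval": {"accuracy": 0.9}, "done": False}
    print(callback._print_result(result))
    # e.g. "loss=0.1,eval/accuracy=0.9" (ordering may vary, since the
    # metrics are kept in a set)
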
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index c563326a5..f30142a1a 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -17,6 +17,7 @@ from ray.tune.progress_reporter import (
     ProgressReporter,
     RemoteReporterMixin,
     detect_reporter,
+    detect_progress_metrics,
 )
 from ray.tune.execution.ray_trial_executor import RayTrialExecutor
 from ray.tune.registry import get_trainable_cls, is_function_trainable
@@ -59,22 +60,32 @@ from ray.util.queue import Empty, Queue

 logger = logging.getLogger(__name__)


-def _check_default_resources_override(
+def _get_trainable(
     run_identifier: Union[Experiment, str, Type, Callable]
-) -> bool:
+) -> Optional[Type[Trainable]]:
     if isinstance(run_identifier, Experiment):
         run_identifier = run_identifier.run_identifier

     if isinstance(run_identifier, type):
         if not issubclass(run_identifier, Trainable):
             # If obscure dtype, assume it is overridden.
-            return True
+            return None
         trainable_cls = run_identifier
     elif callable(run_identifier):
         trainable_cls = run_identifier
     elif isinstance(run_identifier, str):
         trainable_cls = get_trainable_cls(run_identifier)
     else:
+        return None
+
+    return trainable_cls
+
+
+def _check_default_resources_override(
+    run_identifier: Union[Experiment, str, Type, Callable]
+) -> bool:
+    trainable_cls = _get_trainable(run_identifier)
+    if not trainable_cls:
         # Default to True
         return True
@@ -610,8 +621,12 @@ def run(
             "from your scheduler or from your call to `tune.run()`"
         )

+    progress_metrics = detect_progress_metrics(_get_trainable(run_or_experiment))
+
     # Create syncer callbacks
-    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)
+    callbacks = create_default_callbacks(
+        callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
+    )

     runner = TrialRunner(
         search_alg=search_alg,
diff --git a/python/ray/tune/utils/callback.py b/python/ray/tune/utils/callback.py
index 9459f13ff..d5b3f2e63 100644
--- a/python/ray/tune/utils/callback.py
+++ b/python/ray/tune/utils/callback.py
@@ -25,6 +25,7 @@ def create_default_callbacks(
     callbacks: Optional[List[Callback]],
     sync_config: SyncConfig,
     metric: Optional[str] = None,
+    progress_metrics: Optional[List[str]] = None,
 ):
     """Create default callbacks for `Tuner.fit()`.

@@ -60,7 +61,9 @@
     )

     if not has_trial_progress_callback:
-        trial_progress_callback = TrialProgressCallback(metric=metric)
+        trial_progress_callback = TrialProgressCallback(
+            metric=metric, progress_metrics=progress_metrics
+        )
         callbacks.append(trial_progress_callback)

     # Track syncer obj/index to move callback after loggers
diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 7911cf976..c4f05b044 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -200,6 +200,13 @@ class Algorithm(Trainable):
     # List of keys that are always fully overridden if present in any dict or sub-dict
     _override_all_key_list = ["off_policy_estimation_methods"]

+    _progress_metrics = [
+        "episode_reward_mean",
+        "evaluation/episode_reward_mean",
+        "num_env_steps_sampled",
+        "num_env_steps_trained",
+    ]
+
     @PublicAPI
     def __init__(
         self,
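End to end, a trainable opts into the reduced report by exposing a `_progress_metrics` attribute, exactly as `Algorithm` does above; `tune.run()` picks it up via `detect_progress_metrics(_get_trainable(...))` and forwards it through `create_default_callbacks()` into the `TrialProgressCallback`. A minimal sketch with a function trainable (`train_fn` and `mean_loss` are hypothetical names):

    from ray import tune

    def train_fn(config):
        for step in range(3):
            # Report a made-up metric; the result dict Tune builds around
            # it also contains many auto-populated keys.
            tune.report(mean_loss=1.0 / (step + 1))

    # detect_progress_metrics() reads the attribute via getattr(), so it
    # works on Trainable subclasses and plain callables alike.
    train_fn._progress_metrics = ["mean_loss"]

    # No new arguments are needed at the call site: the per-result console
    # output now shows only mean_loss (plus the main metric, if one is set).
    tune.run(train_fn)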