[rfc] [tune/rllib] Fetch _progress_metrics from trainable for verbose=2 display (#26967)

RLlib's trainables produce a large number of metrics which make the log output with verbose=2 illegible. This PR introduces a private `_progress_metrics` property for trainables. If set, the trial progress callback will only print these metrics by default, unless overridden e.g. with a custom `TrialProgressCallback`.
This commit is contained in:
Kai Fricke 2022-07-27 16:04:23 +01:00 committed by GitHub
parent 87b164c84b
commit a5ea99cf95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 16 deletions

View file

@ -7,7 +7,7 @@ import os
import sys import sys
import time import time
import warnings import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
@ -16,7 +16,9 @@ from ray.tune.logger import logger, pretty_print
from ray.tune.result import ( from ray.tune.result import (
AUTO_RESULT_KEYS, AUTO_RESULT_KEYS,
DEFAULT_METRIC, DEFAULT_METRIC,
DONE,
EPISODE_REWARD_MEAN, EPISODE_REWARD_MEAN,
EXPERIMENT_TAG,
MEAN_ACCURACY, MEAN_ACCURACY,
MEAN_LOSS, MEAN_LOSS,
NODE_IP, NODE_IP,
@ -24,11 +26,14 @@ from ray.tune.result import (
TIME_TOTAL_S, TIME_TOTAL_S,
TIMESTEPS_TOTAL, TIMESTEPS_TOTAL,
TRAINING_ITERATION, TRAINING_ITERATION,
TRIAL_ID,
) )
from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
from ray.tune.trainable import Trainable
from ray.tune.utils import unflattened_lookup from ray.tune.utils import unflattened_lookup
from ray.tune.utils.log import Verbosity, has_verbosity from ray.tune.utils.log import Verbosity, has_verbosity
from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.ml_utils.dict import flatten_dict
from ray.util.queue import Queue from ray.util.queue import Queue
try: try:
@ -52,6 +57,9 @@ except NameError:
IS_NOTEBOOK = False IS_NOTEBOOK = False
SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
@PublicAPI @PublicAPI
class ProgressReporter: class ProgressReporter:
"""Abstract class for experiment progress reporting. """Abstract class for experiment progress reporting.
@ -1093,11 +1101,18 @@ class TrialProgressCallback(Callback):
""" """
def __init__(self, metric: Optional[str] = None): def __init__(
self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
):
self._last_print = collections.defaultdict(float) self._last_print = collections.defaultdict(float)
self._completed_trials = set() self._completed_trials = set()
self._last_result_str = {} self._last_result_str = {}
self._metric = metric self._metric = metric
self._progress_metrics = set(progress_metrics or [])
# Only use progress metrics if at least two metrics are in there
if self._metric and self._progress_metrics:
self._progress_metrics.add(self._metric)
def on_trial_result( def on_trial_result(
self, self,
@ -1179,16 +1194,27 @@ class TrialProgressCallback(Callback):
self._last_print[trial] = time.time() self._last_print[trial] = time.time()
def _print_result(self, result: Dict): def _print_result(self, result: Dict):
print_result = result.copy() if self._progress_metrics:
print_result.pop("config", None) # If progress metrics are given, only report these
print_result.pop("hist_stats", None) flat_result = flatten_dict(result)
print_result.pop("trial_id", None)
print_result.pop("experiment_tag", None)
print_result.pop("done", None)
for auto_result in AUTO_RESULT_KEYS:
print_result.pop(auto_result, None)
print_result_str = ",".join([f"{k}={v}" for k, v in print_result.items()]) print_result = {}
for metric in self._progress_metrics:
print_result[metric] = flat_result.get(metric)
else:
# Else, skip auto populated results
print_result = result.copy()
for skip_result in SKIP_RESULTS_IN_REPORT:
print_result.pop(skip_result, None)
for auto_result in AUTO_RESULT_KEYS:
print_result.pop(auto_result, None)
print_result_str = ",".join(
[f"{k}={v}" for k, v in print_result.items() if v is not None]
)
return print_result_str return print_result_str
@ -1206,3 +1232,13 @@ def detect_reporter(**kwargs) -> TuneReporterBase:
else: else:
progress_reporter = CLIReporter(**kwargs) progress_reporter = CLIReporter(**kwargs)
return progress_reporter return progress_reporter
def detect_progress_metrics(
trainable: Optional[Union["Trainable", Callable]]
) -> Optional[List[str]]:
"""Detect progress metrics to report."""
if not trainable:
return None
return getattr(trainable, "_progress_metrics", None)

View file

@ -17,6 +17,7 @@ from ray.tune.progress_reporter import (
ProgressReporter, ProgressReporter,
RemoteReporterMixin, RemoteReporterMixin,
detect_reporter, detect_reporter,
detect_progress_metrics,
) )
from ray.tune.execution.ray_trial_executor import RayTrialExecutor from ray.tune.execution.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls, is_function_trainable from ray.tune.registry import get_trainable_cls, is_function_trainable
@ -59,22 +60,32 @@ from ray.util.queue import Empty, Queue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _check_default_resources_override( def _get_trainable(
run_identifier: Union[Experiment, str, Type, Callable] run_identifier: Union[Experiment, str, Type, Callable]
) -> bool: ) -> Optional[Type[Trainable]]:
if isinstance(run_identifier, Experiment): if isinstance(run_identifier, Experiment):
run_identifier = run_identifier.run_identifier run_identifier = run_identifier.run_identifier
if isinstance(run_identifier, type): if isinstance(run_identifier, type):
if not issubclass(run_identifier, Trainable): if not issubclass(run_identifier, Trainable):
# If obscure dtype, assume it is overridden. # If obscure dtype, assume it is overridden.
return True return None
trainable_cls = run_identifier trainable_cls = run_identifier
elif callable(run_identifier): elif callable(run_identifier):
trainable_cls = run_identifier trainable_cls = run_identifier
elif isinstance(run_identifier, str): elif isinstance(run_identifier, str):
trainable_cls = get_trainable_cls(run_identifier) trainable_cls = get_trainable_cls(run_identifier)
else: else:
return None
return trainable_cls
def _check_default_resources_override(
run_identifier: Union[Experiment, str, Type, Callable]
) -> bool:
trainable_cls = _get_trainable(run_identifier)
if not trainable_cls:
# Default to True # Default to True
return True return True
@ -610,8 +621,12 @@ def run(
"from your scheduler or from your call to `tune.run()`" "from your scheduler or from your call to `tune.run()`"
) )
progress_metrics = detect_progress_metrics(_get_trainable(run_or_experiment))
# Create syncer callbacks # Create syncer callbacks
callbacks = create_default_callbacks(callbacks, sync_config, metric=metric) callbacks = create_default_callbacks(
callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
)
runner = TrialRunner( runner = TrialRunner(
search_alg=search_alg, search_alg=search_alg,

View file

@ -25,6 +25,7 @@ def create_default_callbacks(
callbacks: Optional[List[Callback]], callbacks: Optional[List[Callback]],
sync_config: SyncConfig, sync_config: SyncConfig,
metric: Optional[str] = None, metric: Optional[str] = None,
progress_metrics: Optional[List[str]] = None,
): ):
"""Create default callbacks for `Tuner.fit()`. """Create default callbacks for `Tuner.fit()`.
@ -60,7 +61,9 @@ def create_default_callbacks(
) )
if not has_trial_progress_callback: if not has_trial_progress_callback:
trial_progress_callback = TrialProgressCallback(metric=metric) trial_progress_callback = TrialProgressCallback(
metric=metric, progress_metrics=progress_metrics
)
callbacks.append(trial_progress_callback) callbacks.append(trial_progress_callback)
# Track syncer obj/index to move callback after loggers # Track syncer obj/index to move callback after loggers

View file

@ -200,6 +200,13 @@ class Algorithm(Trainable):
# List of keys that are always fully overridden if present in any dict or sub-dict # List of keys that are always fully overridden if present in any dict or sub-dict
_override_all_key_list = ["off_policy_estimation_methods"] _override_all_key_list = ["off_policy_estimation_methods"]
_progress_metrics = [
"episode_reward_mean",
"evaluation/episode_reward_mean",
"num_env_steps_sampled",
"num_env_steps_trained",
]
@PublicAPI @PublicAPI
def __init__( def __init__(
self, self,