[rfc] [tune/rllib] Fetch _progress_metrics from trainable for verbose=2 display (#26967)

RLlib trainables produce a large number of metrics, which makes the log output with verbose=2 illegible. This PR introduces a private `_progress_metrics` property for trainables. If set, the trial progress callback will only print these metrics by default, unless overridden, e.g. with a custom `TrialProgressCallback`.
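A minimal sketch of the override path (assuming `TrialProgressCallback` is imported from the file touched below, ray.tune.progress_reporter): passing your own `TrialProgressCallback` in the callbacks list prevents Tune from installing the default, filtered one, so the full result dict is printed again.

    from ray import tune
    from ray.tune.progress_reporter import TrialProgressCallback

    # A user-supplied TrialProgressCallback takes precedence over the default one,
    # so no _progress_metrics filtering is applied and every metric is printed.
    tune.run(
        "PPO",
        config={"env": "CartPole-v1", "framework": "torch"},
        callbacks=[TrialProgressCallback(metric="episode_reward_mean")],
        verbose=2,
    )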
Kai Fricke 2022-07-27 16:04:23 +01:00 committed by GitHub
parent 87b164c84b
commit a5ea99cf95
4 changed files with 77 additions and 16 deletions


@@ -7,7 +7,7 @@ import os
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
@@ -16,7 +16,9 @@ from ray.tune.logger import logger, pretty_print
from ray.tune.result import (
AUTO_RESULT_KEYS,
DEFAULT_METRIC,
DONE,
EPISODE_REWARD_MEAN,
EXPERIMENT_TAG,
MEAN_ACCURACY,
MEAN_LOSS,
NODE_IP,
@@ -24,11 +26,14 @@ from ray.tune.result import (
TIME_TOTAL_S,
TIMESTEPS_TOTAL,
TRAINING_ITERATION,
TRIAL_ID,
)
from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
from ray.tune.trainable import Trainable
from ray.tune.utils import unflattened_lookup
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.ml_utils.dict import flatten_dict
from ray.util.queue import Queue
try:
@@ -52,6 +57,9 @@ except NameError:
IS_NOTEBOOK = False
SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
@PublicAPI
class ProgressReporter:
"""Abstract class for experiment progress reporting.
@@ -1093,11 +1101,18 @@ class TrialProgressCallback(Callback):
"""
def __init__(self, metric: Optional[str] = None):
def __init__(
self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
):
self._last_print = collections.defaultdict(float)
self._completed_trials = set()
self._last_result_str = {}
self._metric = metric
self._progress_metrics = set(progress_metrics or [])
# Only use progress metrics if at least two metrics are in there
if self._metric and self._progress_metrics:
self._progress_metrics.add(self._metric)
def on_trial_result(
self,
@@ -1179,16 +1194,27 @@
self._last_print[trial] = time.time()
def _print_result(self, result: Dict):
print_result = result.copy()
print_result.pop("config", None)
print_result.pop("hist_stats", None)
print_result.pop("trial_id", None)
print_result.pop("experiment_tag", None)
print_result.pop("done", None)
for auto_result in AUTO_RESULT_KEYS:
print_result.pop(auto_result, None)
if self._progress_metrics:
# If progress metrics are given, only report these
flat_result = flatten_dict(result)
print_result_str = ",".join([f"{k}={v}" for k, v in print_result.items()])
print_result = {}
for metric in self._progress_metrics:
print_result[metric] = flat_result.get(metric)
else:
# Else, skip auto populated results
print_result = result.copy()
for skip_result in SKIP_RESULTS_IN_REPORT:
print_result.pop(skip_result, None)
for auto_result in AUTO_RESULT_KEYS:
print_result.pop(auto_result, None)
print_result_str = ",".join(
[f"{k}={v}" for k, v in print_result.items() if v is not None]
)
return print_result_str
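A standalone illustration of the new branch (it does not import flatten_dict; a small helper mimics its "/"-delimited flattening): nested results are flattened and only the requested keys survive in the printed string.

    from typing import Any, Dict

    def _flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        # Mimic a "/"-delimited flatten of nested result dicts.
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            if isinstance(value, dict):
                flat.update(_flatten(value, prefix=f"{prefix}{key}/"))
            else:
                flat[f"{prefix}{key}"] = value
        return flat

    result = {
        "episode_reward_mean": 85.3,
        "evaluation": {"episode_reward_mean": 90.1},
        "info": {"learner": {"default_policy": {"cur_lr": 5e-05}}},
    }
    progress_metrics = ["episode_reward_mean", "evaluation/episode_reward_mean"]

    flat_result = _flatten(result)
    print_result = {m: flat_result.get(m) for m in progress_metrics}
    print(",".join(f"{k}={v}" for k, v in print_result.items() if v is not None))
    # episode_reward_mean=85.3,evaluation/episode_reward_mean=90.1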
@@ -1206,3 +1232,13 @@ def detect_reporter(**kwargs) -> TuneReporterBase:
else:
progress_reporter = CLIReporter(**kwargs)
return progress_reporter
def detect_progress_metrics(
trainable: Optional[Union["Trainable", Callable]]
) -> Optional[List[str]]:
"""Detect progress metrics to report."""
if not trainable:
return None
return getattr(trainable, "_progress_metrics", None)
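Usage sketch for the new helper (the class and function names are hypothetical): a class trainable opts in via the private attribute, while a function trainable without it simply yields None.

    from ray.tune import Trainable
    from ray.tune.progress_reporter import detect_progress_metrics

    class MyTrainable(Trainable):
        # Opting in: only these (flattened) keys are printed at verbose=2.
        _progress_metrics = ["mean_accuracy", "eval/mean_accuracy"]

        def step(self):
            return {"mean_accuracy": 0.9, "eval": {"mean_accuracy": 0.88}}

    def my_func_trainable(config):
        ...

    assert detect_progress_metrics(MyTrainable) == ["mean_accuracy", "eval/mean_accuracy"]
    assert detect_progress_metrics(my_func_trainable) is None
    assert detect_progress_metrics(None) is None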


@@ -17,6 +17,7 @@ from ray.tune.progress_reporter import (
ProgressReporter,
RemoteReporterMixin,
detect_reporter,
detect_progress_metrics,
)
from ray.tune.execution.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls, is_function_trainable
@@ -59,22 +60,32 @@ from ray.util.queue import Empty, Queue
logger = logging.getLogger(__name__)
def _check_default_resources_override(
def _get_trainable(
run_identifier: Union[Experiment, str, Type, Callable]
) -> bool:
) -> Optional[Type[Trainable]]:
if isinstance(run_identifier, Experiment):
run_identifier = run_identifier.run_identifier
if isinstance(run_identifier, type):
if not issubclass(run_identifier, Trainable):
# If obscure dtype, assume it is overridden.
return True
return None
trainable_cls = run_identifier
elif callable(run_identifier):
trainable_cls = run_identifier
elif isinstance(run_identifier, str):
trainable_cls = get_trainable_cls(run_identifier)
else:
return None
return trainable_cls
def _check_default_resources_override(
run_identifier: Union[Experiment, str, Type, Callable]
) -> bool:
trainable_cls = _get_trainable(run_identifier)
if not trainable_cls:
# Default to True
return True
@@ -610,8 +621,12 @@ def run(
"from your scheduler or from your call to `tune.run()`"
)
progress_metrics = detect_progress_metrics(_get_trainable(run_or_experiment))
# Create syncer callbacks
callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)
callbacks = create_default_callbacks(
callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
)
runner = TrialRunner(
search_alg=search_alg,


@@ -25,6 +25,7 @@ def create_default_callbacks(
callbacks: Optional[List[Callback]],
sync_config: SyncConfig,
metric: Optional[str] = None,
progress_metrics: Optional[List[str]] = None,
):
"""Create default callbacks for `Tuner.fit()`.
@@ -60,7 +61,9 @@
)
if not has_trial_progress_callback:
trial_progress_callback = TrialProgressCallback(metric=metric)
trial_progress_callback = TrialProgressCallback(
metric=metric, progress_metrics=progress_metrics
)
callbacks.append(trial_progress_callback)
# Track syncer obj/index to move callback after loggers


@@ -200,6 +200,13 @@ class Algorithm(Trainable):
# List of keys that are always fully overridden if present in any dict or sub-dict
_override_all_key_list = ["off_policy_estimation_methods"]
_progress_metrics = [
"episode_reward_mean",
"evaluation/episode_reward_mean",
"num_env_steps_sampled",
"num_env_steps_trained",
]
@PublicAPI
def __init__(
self,
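With these defaults, verbose=2 output for RLlib runs is condensed to the four keys above. A subclass could extend the list; this is only a hedged sketch, and the extra key is an example that assumes it exists in the flattened result.

    from ray.rllib.algorithms.ppo import PPO

    class MyPPO(PPO):
        # Inherit the condensed keys and surface one more flattened result key.
        _progress_metrics = PPO._progress_metrics + ["episode_len_mean"]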