[rfc] [tune/rllib] Fetch _progress_metrics from trainable for verbose=2 display (#26967)
RLlib's trainables produce a large number of metrics, which makes the log output with verbose=2 illegible. This PR introduces a private `_progress_metrics` property for trainables. If set, the trial progress callback will only print these metrics by default, unless overridden, e.g. with a custom `TrialProgressCallback`.
parent 87b164c84b
commit a5ea99cf95

4 changed files with 77 additions and 16 deletions
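
To illustrate the mechanism, here is a minimal, self-contained sketch of the detection logic this commit adds (a standalone copy of the helper rather than an import from Ray; the `MyTrainable` class and its metric keys are hypothetical):

    from typing import Callable, List, Optional, Union

    def detect_progress_metrics(
        trainable: Optional[Union[type, Callable]]
    ) -> Optional[List[str]]:
        # Same logic as the helper added in this commit: read the private
        # attribute off the trainable if it exists, else return None so the
        # callback falls back to printing all (non-auto-populated) metrics.
        if not trainable:
            return None
        return getattr(trainable, "_progress_metrics", None)

    class MyTrainable:  # hypothetical trainable
        # Only these result keys would be printed at verbose=2.
        _progress_metrics = ["episode_reward_mean", "num_env_steps_sampled"]

    assert detect_progress_metrics(MyTrainable) == [
        "episode_reward_mean",
        "num_env_steps_sampled",
    ]
    assert detect_progress_metrics(None) is None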
python/ray/tune/progress_reporter.py:

@@ -7,7 +7,7 @@ import os
 import sys
 import time
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np

@@ -16,7 +16,9 @@ from ray.tune.logger import logger, pretty_print
 from ray.tune.result import (
+    AUTO_RESULT_KEYS,
     DEFAULT_METRIC,
     DONE,
     EPISODE_REWARD_MEAN,
+    EXPERIMENT_TAG,
     MEAN_ACCURACY,
     MEAN_LOSS,
     NODE_IP,

@@ -24,11 +26,14 @@ from ray.tune.result import (
     TIME_TOTAL_S,
     TIMESTEPS_TOTAL,
     TRAINING_ITERATION,
+    TRIAL_ID,
 )
 from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
+from ray.tune.trainable import Trainable
 from ray.tune.utils import unflattened_lookup
 from ray.tune.utils.log import Verbosity, has_verbosity
 from ray.util.annotations import DeveloperAPI, PublicAPI
+from ray.util.ml_utils.dict import flatten_dict
 from ray.util.queue import Queue
 
 try:

@@ -52,6 +57,9 @@ except NameError:
     IS_NOTEBOOK = False
 
 
+SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
+
+
 @PublicAPI
 class ProgressReporter:
     """Abstract class for experiment progress reporting.

@@ -1093,11 +1101,18 @@ class TrialProgressCallback(Callback):
 
     """
 
-    def __init__(self, metric: Optional[str] = None):
+    def __init__(
+        self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
+    ):
         self._last_print = collections.defaultdict(float)
         self._completed_trials = set()
         self._last_result_str = {}
         self._metric = metric
+        self._progress_metrics = set(progress_metrics or [])
+
+        # Only use progress metrics if at least two metrics are in there
+        if self._metric and self._progress_metrics:
+            self._progress_metrics.add(self._metric)
 
     def on_trial_result(
         self,

@@ -1179,16 +1194,27 @@ class TrialProgressCallback(Callback):
         self._last_print[trial] = time.time()
 
     def _print_result(self, result: Dict):
-        print_result = result.copy()
-        print_result.pop("config", None)
-        print_result.pop("hist_stats", None)
-        print_result.pop("trial_id", None)
-        print_result.pop("experiment_tag", None)
-        print_result.pop("done", None)
-        for auto_result in AUTO_RESULT_KEYS:
-            print_result.pop(auto_result, None)
+        if self._progress_metrics:
+            # If progress metrics are given, only report these
+            flat_result = flatten_dict(result)
 
-        print_result_str = ",".join([f"{k}={v}" for k, v in print_result.items()])
+            print_result = {}
+            for metric in self._progress_metrics:
+                print_result[metric] = flat_result.get(metric)
+
+        else:
+            # Else, skip auto populated results
+            print_result = result.copy()
+
+            for skip_result in SKIP_RESULTS_IN_REPORT:
+                print_result.pop(skip_result, None)
+
+            for auto_result in AUTO_RESULT_KEYS:
+                print_result.pop(auto_result, None)
+
+        print_result_str = ",".join(
+            [f"{k}={v}" for k, v in print_result.items() if v is not None]
+        )
         return print_result_str

@@ -1206,3 +1232,13 @@ def detect_reporter(**kwargs) -> TuneReporterBase:
     else:
         progress_reporter = CLIReporter(**kwargs)
     return progress_reporter
+
+
+def detect_progress_metrics(
+    trainable: Optional[Union["Trainable", Callable]]
+) -> Optional[List[str]]:
+    """Detect progress metrics to report."""
+    if not trainable:
+        return None
+
+    return getattr(trainable, "_progress_metrics", None)
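The `_print_result` change above flattens nested results before selecting the progress metrics, so nested keys like `evaluation/episode_reward_mean` can be matched. A rough standalone sketch of that filtering, with a minimal stand-in for `ray.util.ml_utils.dict.flatten_dict` (which joins nested keys with "/" by default) and made-up result values:

    from typing import Dict

    def flatten_dict(d: Dict, delimiter: str = "/", _parent: str = "") -> Dict:
        # Minimal stand-in: nested keys are joined with the delimiter, so
        # result["evaluation"]["episode_reward_mean"] becomes
        # "evaluation/episode_reward_mean".
        flat = {}
        for k, v in d.items():
            key = f"{_parent}{delimiter}{k}" if _parent else str(k)
            if isinstance(v, dict):
                flat.update(flatten_dict(v, delimiter, key))
            else:
                flat[key] = v
        return flat

    result = {
        "episode_reward_mean": 42.0,
        "evaluation": {"episode_reward_mean": 40.5},
        "hist_stats": {"episode_lengths": [1, 2, 3]},  # dropped from output
    }
    progress_metrics = {"episode_reward_mean", "evaluation/episode_reward_mean"}

    flat_result = flatten_dict(result)
    print_result = {m: flat_result.get(m) for m in progress_metrics}
    print(",".join(f"{k}={v}" for k, v in print_result.items() if v is not None))
    # e.g.: episode_reward_mean=42.0,evaluation/episode_reward_mean=40.5
    # (key order may vary, since progress metrics are stored in a set)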
python/ray/tune/tune.py:

@@ -17,6 +17,7 @@ from ray.tune.progress_reporter import (
     ProgressReporter,
     RemoteReporterMixin,
     detect_reporter,
+    detect_progress_metrics,
 )
 from ray.tune.execution.ray_trial_executor import RayTrialExecutor
 from ray.tune.registry import get_trainable_cls, is_function_trainable

@@ -59,22 +60,32 @@ from ray.util.queue import Empty, Queue
 logger = logging.getLogger(__name__)
 
 
-def _check_default_resources_override(
+def _get_trainable(
     run_identifier: Union[Experiment, str, Type, Callable]
-) -> bool:
+) -> Optional[Type[Trainable]]:
     if isinstance(run_identifier, Experiment):
         run_identifier = run_identifier.run_identifier
 
     if isinstance(run_identifier, type):
         if not issubclass(run_identifier, Trainable):
             # If obscure dtype, assume it is overridden.
-            return True
+            return None
         trainable_cls = run_identifier
     elif callable(run_identifier):
         trainable_cls = run_identifier
     elif isinstance(run_identifier, str):
         trainable_cls = get_trainable_cls(run_identifier)
     else:
-        return True
+        return None
+
+    return trainable_cls
+
+
+def _check_default_resources_override(
+    run_identifier: Union[Experiment, str, Type, Callable]
+) -> bool:
+    trainable_cls = _get_trainable(run_identifier)
+    if not trainable_cls:
+        # Default to True
+        return True

@@ -610,8 +621,12 @@ def run(
             "from your scheduler or from your call to `tune.run()`"
         )
 
+    progress_metrics = detect_progress_metrics(_get_trainable(run_or_experiment))
+
     # Create syncer callbacks
-    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)
+    callbacks = create_default_callbacks(
+        callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
+    )
 
     runner = TrialRunner(
         search_alg=search_alg,
python/ray/tune/utils/callback.py:

@@ -25,6 +25,7 @@ def create_default_callbacks(
     callbacks: Optional[List[Callback]],
     sync_config: SyncConfig,
     metric: Optional[str] = None,
+    progress_metrics: Optional[List[str]] = None,
 ):
     """Create default callbacks for `Tuner.fit()`.
 

@@ -60,7 +61,9 @@ def create_default_callbacks(
     )
 
     if not has_trial_progress_callback:
-        trial_progress_callback = TrialProgressCallback(metric=metric)
+        trial_progress_callback = TrialProgressCallback(
+            metric=metric, progress_metrics=progress_metrics
+        )
         callbacks.append(trial_progress_callback)
 
     # Track syncer obj/index to move callback after loggers
rllib/algorithms/algorithm.py:

@@ -200,6 +200,13 @@ class Algorithm(Trainable):
     # List of keys that are always fully overridden if present in any dict or sub-dict
     _override_all_key_list = ["off_policy_estimation_methods"]
 
+    _progress_metrics = [
+        "episode_reward_mean",
+        "evaluation/episode_reward_mean",
+        "num_env_steps_sampled",
+        "num_env_steps_trained",
+    ]
+
     @PublicAPI
     def __init__(
         self,
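Taken together: because RLlib's `Algorithm` now declares `_progress_metrics`, verbose=2 output for RLlib runs is trimmed to those four keys by default. A hedged usage sketch against the Ray version this commit targets (the exact `tune.run` arguments and config keys are illustrative):

    from ray import tune
    from ray.tune.progress_reporter import TrialProgressCallback

    # Default behavior: trimmed verbose=2 output, because "PPO" resolves to
    # an Algorithm subclass that defines _progress_metrics.
    tune.run(
        "PPO",
        config={"env": "CartPole-v1", "framework": "torch"},
        stop={"training_iteration": 2},
        verbose=2,
    )

    # Opting out: supplying your own TrialProgressCallback suppresses the
    # default one, so all non-auto-populated metrics are printed again.
    tune.run(
        "PPO",
        config={"env": "CartPole-v1", "framework": "torch"},
        stop={"training_iteration": 2},
        verbose=2,
        callbacks=[TrialProgressCallback(metric="episode_reward_mean")],
    )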