Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[rfc] [tune/rllib] Fetch _progress_metrics from trainable for verbose=2 display (#26967)
RLlib trainables produce a large number of metrics, which makes the log output with verbose=2 illegible. This PR introduces a private `_progress_metrics` property for trainables. If set, the trial progress callback will only print these metrics by default, unless overridden, e.g. with a custom `TrialProgressCallback`.
Parent: 87b164c84b
Commit: a5ea99cf95
4 changed files with 77 additions and 16 deletions
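For illustration, here is a minimal sketch of the intended usage with a class trainable. Only the private `_progress_metrics` attribute comes from this PR; the trainable, its metric keys, and the stopping criterion are made-up examples.

from ray import tune
from ray.tune import Trainable


class MyTrainable(Trainable):
    # Keys the default TrialProgressCallback should print at verbose=2.
    _progress_metrics = ["score", "aux/steps"]

    def step(self):
        # step() may return many more metrics; per default only the declared
        # ones (plus the optimization metric, if any) are printed.
        return {
            "score": 1.0 / (self.iteration + 1),
            "aux/steps": self.iteration,
            "debug_dump": list(range(1000)),
        }


tune.run(MyTrainable, stop={"training_iteration": 3}, verbose=2)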
File 1 of 4:

@@ -7,7 +7,7 @@ import os
 import sys
 import time
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 
@@ -16,7 +16,9 @@ from ray.tune.logger import logger, pretty_print
 from ray.tune.result import (
     AUTO_RESULT_KEYS,
     DEFAULT_METRIC,
+    DONE,
     EPISODE_REWARD_MEAN,
+    EXPERIMENT_TAG,
     MEAN_ACCURACY,
     MEAN_LOSS,
     NODE_IP,
@@ -24,11 +26,14 @@ from ray.tune.result import (
     TIME_TOTAL_S,
     TIMESTEPS_TOTAL,
     TRAINING_ITERATION,
+    TRIAL_ID,
 )
 from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
+from ray.tune.trainable import Trainable
 from ray.tune.utils import unflattened_lookup
 from ray.tune.utils.log import Verbosity, has_verbosity
 from ray.util.annotations import DeveloperAPI, PublicAPI
+from ray.util.ml_utils.dict import flatten_dict
 from ray.util.queue import Queue
 
 try:
@@ -52,6 +57,9 @@ except NameError:
     IS_NOTEBOOK = False
 
 
+SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
+
+
 @PublicAPI
 class ProgressReporter:
     """Abstract class for experiment progress reporting.
@@ -1093,11 +1101,18 @@ class TrialProgressCallback(Callback):
 
     """
 
-    def __init__(self, metric: Optional[str] = None):
+    def __init__(
+        self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
+    ):
         self._last_print = collections.defaultdict(float)
         self._completed_trials = set()
         self._last_result_str = {}
         self._metric = metric
+        self._progress_metrics = set(progress_metrics or [])
+
+        # Only use progress metrics if at least two metrics are in there
+        if self._metric and self._progress_metrics:
+            self._progress_metrics.add(self._metric)
 
     def on_trial_result(
         self,
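A hedged construction example for the new signature above; the metric names are illustrative.

from ray.tune.progress_reporter import TrialProgressCallback

callback = TrialProgressCallback(
    metric="episode_reward_mean",
    progress_metrics=["num_env_steps_sampled", "num_env_steps_trained"],
)
# Per the constructor above, the optimization metric is folded into the set,
# so callback._progress_metrics also contains "episode_reward_mean".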
@@ -1179,16 +1194,27 @@ class TrialProgressCallback(Callback):
         self._last_print[trial] = time.time()
 
     def _print_result(self, result: Dict):
-        print_result = result.copy()
-        print_result.pop("config", None)
-        print_result.pop("hist_stats", None)
-        print_result.pop("trial_id", None)
-        print_result.pop("experiment_tag", None)
-        print_result.pop("done", None)
-        for auto_result in AUTO_RESULT_KEYS:
-            print_result.pop(auto_result, None)
+        if self._progress_metrics:
+            # If progress metrics are given, only report these
+            flat_result = flatten_dict(result)
+
+            print_result = {}
+            for metric in self._progress_metrics:
+                print_result[metric] = flat_result.get(metric)
+        else:
+            # Else, skip auto populated results
+            print_result = result.copy()
+
+            for skip_result in SKIP_RESULTS_IN_REPORT:
+                print_result.pop(skip_result, None)
+
+            for auto_result in AUTO_RESULT_KEYS:
+                print_result.pop(auto_result, None)
 
-        print_result_str = ",".join([f"{k}={v}" for k, v in print_result.items()])
+        print_result_str = ",".join(
+            [f"{k}={v}" for k, v in print_result.items() if v is not None]
+        )
         return print_result_str
 
 
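To show the effect of the branch above without touching Ray internals, here is a self-contained re-implementation of the filtering; _flatten stands in for flatten_dict, which joins nested keys with "/".

def _flatten(d, prefix=""):
    # Flatten nested dicts into "a/b/c"-style keys.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            out.update(_flatten(v, prefix=f"{key}/"))
        else:
            out[key] = v
    return out


result = {
    "episode_reward_mean": 12.5,
    "evaluation": {"episode_reward_mean": 10.1},
    "info": {"learner": {"default_policy": {"cur_lr": 5e-5}}},
}
progress_metrics = ["episode_reward_mean", "evaluation/episode_reward_mean"]

flat = _flatten(result)
printed = {m: flat.get(m) for m in progress_metrics}
print(",".join(f"{k}={v}" for k, v in printed.items() if v is not None))
# -> episode_reward_mean=12.5,evaluation/episode_reward_mean=10.1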
@@ -1206,3 +1232,13 @@ def detect_reporter(**kwargs) -> TuneReporterBase:
     else:
         progress_reporter = CLIReporter(**kwargs)
     return progress_reporter
+
+
+def detect_progress_metrics(
+    trainable: Optional[Union["Trainable", Callable]]
+) -> Optional[List[str]]:
+    """Detect progress metrics to report."""
+    if not trainable:
+        return None
+
+    return getattr(trainable, "_progress_metrics", None)
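A hedged sketch of the detection helper's behaviour, assuming it is importable from ray.tune.progress_reporter once this change lands; DummyTrainable is hypothetical.

from ray.tune.progress_reporter import detect_progress_metrics


class DummyTrainable:
    _progress_metrics = ["loss", "mean_accuracy"]


assert detect_progress_metrics(DummyTrainable) == ["loss", "mean_accuracy"]
assert detect_progress_metrics(None) is None
# A function trainable without the attribute also yields None.
assert detect_progress_metrics(lambda config: None) is None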
File 2 of 4:

@@ -17,6 +17,7 @@ from ray.tune.progress_reporter import (
     ProgressReporter,
     RemoteReporterMixin,
     detect_reporter,
+    detect_progress_metrics,
 )
 from ray.tune.execution.ray_trial_executor import RayTrialExecutor
 from ray.tune.registry import get_trainable_cls, is_function_trainable
@@ -59,22 +60,32 @@ from ray.util.queue import Empty, Queue
 logger = logging.getLogger(__name__)
 
 
-def _check_default_resources_override(
+def _get_trainable(
     run_identifier: Union[Experiment, str, Type, Callable]
-) -> bool:
+) -> Optional[Type[Trainable]]:
     if isinstance(run_identifier, Experiment):
         run_identifier = run_identifier.run_identifier
 
     if isinstance(run_identifier, type):
         if not issubclass(run_identifier, Trainable):
             # If obscure dtype, assume it is overridden.
-            return True
+            return None
         trainable_cls = run_identifier
     elif callable(run_identifier):
         trainable_cls = run_identifier
     elif isinstance(run_identifier, str):
         trainable_cls = get_trainable_cls(run_identifier)
     else:
+        return None
+
+    return trainable_cls
+
+
+def _check_default_resources_override(
+    run_identifier: Union[Experiment, str, Type, Callable]
+) -> bool:
+    trainable_cls = _get_trainable(run_identifier)
+    if not trainable_cls:
         # Default to True
         return True
 
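The string branch above defers to the Tune registry. A small hedged example, assuming RLlib is installed so that "PPO" is registered:

from ray.tune.registry import get_trainable_cls

ppo_cls = get_trainable_cls("PPO")  # same lookup _get_trainable performs
print(ppo_cls.__name__)  # -> PPO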
@@ -610,8 +621,12 @@ def run(
             "from your scheduler or from your call to `tune.run()`"
         )
 
+    progress_metrics = detect_progress_metrics(_get_trainable(run_or_experiment))
+
     # Create syncer callbacks
-    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)
+    callbacks = create_default_callbacks(
+        callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
+    )
 
     runner = TrialRunner(
         search_alg=search_alg,
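A hedged end-to-end example of the wiring above, assuming RLlib and Gym's CartPole are available; with verbose=2, the per-trial result lines should then contain only the Algorithm's declared progress metrics.

from ray import tune

tune.run(
    "PPO",
    config={"env": "CartPole-v1", "framework": "torch"},
    stop={"training_iteration": 2},
    verbose=2,
)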
File 3 of 4:

@@ -25,6 +25,7 @@ def create_default_callbacks(
     callbacks: Optional[List[Callback]],
     sync_config: SyncConfig,
     metric: Optional[str] = None,
+    progress_metrics: Optional[List[str]] = None,
 ):
     """Create default callbacks for `Tuner.fit()`.
 
@@ -60,7 +61,9 @@ def create_default_callbacks(
         )
 
     if not has_trial_progress_callback:
-        trial_progress_callback = TrialProgressCallback(metric=metric)
+        trial_progress_callback = TrialProgressCallback(
+            metric=metric, progress_metrics=progress_metrics
+        )
         callbacks.append(trial_progress_callback)
 
     # Track syncer obj/index to move callback after loggers
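As the commit message notes, the filtering can be bypassed by supplying your own TrialProgressCallback; a hedged sketch (the subclass name is made up):

from ray.tune.progress_reporter import TrialProgressCallback


class FullResultCallback(TrialProgressCallback):
    # Deliberately drop progress_metrics to keep the old, unfiltered
    # verbose=2 output.
    def __init__(self, metric=None, progress_metrics=None):
        super().__init__(metric=metric)


# Passing callbacks=[FullResultCallback()] to tune.run() means
# create_default_callbacks() above will not add the filtered default.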
File 4 of 4:

@@ -200,6 +200,13 @@ class Algorithm(Trainable):
     # List of keys that are always fully overridden if present in any dict or sub-dict
     _override_all_key_list = ["off_policy_estimation_methods"]
 
+    _progress_metrics = [
+        "episode_reward_mean",
+        "evaluation/episode_reward_mean",
+        "num_env_steps_sampled",
+        "num_env_steps_trained",
+    ]
+
     @PublicAPI
     def __init__(
         self,
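A hedged sketch of extending the declared metrics in a user-defined subclass; the attribute name and base values come from the hunk above, while the subclass and the extra key are illustrative, and _progress_metrics remains a private interface.

from ray.rllib.algorithms.ppo import PPO


class MyPPO(PPO):
    # Also show the mean episode length at verbose=2.
    _progress_metrics = PPO._progress_metrics + ["episode_len_mean"]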