mirror of
https://github.com/vale981/ray
synced 2025-03-08 11:31:40 -05:00
[tune] Report trials by state fairly (#6395)
* Fairly represented trial states. * filter test * Indent * Add test to BUILD * Address Eric's comments (show truncation by state). * Sort trials, only show 20. * Fix lint
This commit is contained in:
parent
16be483af7
commit
4e1d1ed00d
5 changed files with 163 additions and 20 deletions
|
@ -72,6 +72,13 @@ py_test(
|
|||
tags = ["jenkins_only"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_progress_reporter",
|
||||
size = "small",
|
||||
srcs = ["tests/test_progress_reporter.py"],
|
||||
deps = [":tune_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_ray_trial_executor",
|
||||
size = "medium",
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX,
|
||||
EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
|
||||
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
|
||||
|
@ -25,6 +27,8 @@ REPORTED_REPRESENTATIONS = {
|
|||
|
||||
|
||||
class ProgressReporter(object):
|
||||
# TODO(ujvl): Expose ProgressReporter in tune.run for custom reporting.
|
||||
|
||||
def report(self, trial_runner):
|
||||
"""Reports progress across all trials of the trial runner.
|
||||
|
||||
|
@ -49,7 +53,8 @@ class JupyterNotebookReporter(ProgressReporter):
|
|||
"== Status ==",
|
||||
memory_debug_str(),
|
||||
trial_runner.debug_string(delim=delim),
|
||||
trial_progress_str(trial_runner.get_trials(), fmt="html")
|
||||
trial_progress_str(trial_runner.get_trials(), fmt="html"),
|
||||
trial_errors_str(trial_runner.get_trials(), fmt="html"),
|
||||
]
|
||||
from IPython.display import clear_output
|
||||
from IPython.core.display import display, HTML
|
||||
|
@ -64,7 +69,8 @@ class CLIReporter(ProgressReporter):
|
|||
"== Status ==",
|
||||
memory_debug_str(),
|
||||
trial_runner.debug_string(),
|
||||
trial_progress_str(trial_runner.get_trials())
|
||||
trial_progress_str(trial_runner.get_trials()),
|
||||
trial_errors_str(trial_runner.get_trials()),
|
||||
]
|
||||
print("\n".join(messages) + "\n")
|
||||
|
||||
|
@ -90,7 +96,7 @@ def memory_debug_str():
|
|||
"(or ray[debug]) to resolve)")
|
||||
|
||||
|
||||
def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
|
||||
def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20):
|
||||
"""Returns a human readable message for printing to the console.
|
||||
|
||||
This contains a table where each row represents a trial, its parameters
|
||||
|
@ -109,52 +115,116 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
|
|||
return delim.join(messages)
|
||||
|
||||
num_trials = len(trials)
|
||||
trials_per_state = {}
|
||||
trials_by_state = collections.defaultdict(list)
|
||||
for t in trials:
|
||||
trials_per_state[t.status] = trials_per_state.get(t.status, 0) + 1
|
||||
messages.append("Number of trials: {} ({})".format(num_trials,
|
||||
trials_per_state))
|
||||
trials_by_state[t.status].append(t)
|
||||
|
||||
for local_dir in sorted({t.local_dir for t in trials}):
|
||||
messages.append("Result logdir: {}".format(local_dir))
|
||||
|
||||
num_trials_strs = [
|
||||
"{} {}".format(len(trials_by_state[state]), state)
|
||||
for state in trials_by_state
|
||||
]
|
||||
messages.append("Number of trials: {} ({})".format(
|
||||
num_trials, ", ".join(num_trials_strs)))
|
||||
|
||||
if num_trials > max_rows:
|
||||
overflow = num_trials - max_rows
|
||||
# TODO(ujvl): suggestion for users to view more rows.
|
||||
messages.append("Table truncated to {} rows ({} overflow).".format(
|
||||
max_rows, overflow))
|
||||
trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
|
||||
trials = []
|
||||
overflow_strs = []
|
||||
for state in trials_by_state:
|
||||
trials += trials_by_state_trunc[state]
|
||||
overflow = len(trials_by_state[state]) - len(
|
||||
trials_by_state_trunc[state])
|
||||
overflow_strs.append("{} {}".format(overflow, state))
|
||||
# Build overflow string.
|
||||
overflow = num_trials - max_rows
|
||||
overflow_str = ", ".join(overflow_strs)
|
||||
messages.append("Table truncated to {} rows. {} trials ({}) not "
|
||||
"shown.".format(max_rows, overflow, overflow_str))
|
||||
|
||||
# Pre-process trials to figure out what columns to show.
|
||||
keys = list(metrics or DEFAULT_PROGRESS_KEYS)
|
||||
keys = [k for k in keys if any(t.last_result.get(k) for t in trials)]
|
||||
|
||||
# Build trial rows.
|
||||
trial_table = []
|
||||
params = list(set().union(*[t.evaluated_params for t in trials]))
|
||||
for trial in trials[:min(num_trials, max_rows)]:
|
||||
trial_table.append(_get_trial_info(trial, params, keys))
|
||||
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
|
||||
# Parse columns.
|
||||
parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
|
||||
columns = ["Trial name", "status", "loc"]
|
||||
columns += params + parsed_columns
|
||||
messages.append(
|
||||
tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
|
||||
return delim.join(messages)
|
||||
|
||||
# Build trial error rows.
|
||||
|
||||
def trial_errors_str(trials, fmt="psql", max_rows=20):
|
||||
"""Returns a readable message regarding trial errors.
|
||||
|
||||
Args:
|
||||
trials (List[Trial]): List of trials to get progress string for.
|
||||
fmt (str): Output format (see tablefmt in tabulate API).
|
||||
max_rows (int): Maximum number of rows in the error table.
|
||||
"""
|
||||
messages = []
|
||||
failed = [t for t in trials if t.error_file]
|
||||
if len(failed) > 0:
|
||||
messages.append("Number of errored trials: {}".format(len(failed)))
|
||||
num_failed = len(failed)
|
||||
if num_failed > 0:
|
||||
messages.append("Number of errored trials: {}".format(num_failed))
|
||||
if num_failed > max_rows:
|
||||
messages.append("Table truncated to {} rows ({} overflow)".format(
|
||||
max_rows, num_failed - max_rows))
|
||||
error_table = []
|
||||
for trial in failed:
|
||||
for trial in failed[:max_rows]:
|
||||
row = [str(trial), trial.num_failures, trial.error_file]
|
||||
error_table.append(row)
|
||||
columns = ["Trial name", "# failures", "error file"]
|
||||
messages.append(
|
||||
tabulate(
|
||||
error_table, headers=columns, tablefmt=fmt, showindex=False))
|
||||
|
||||
delim = "<br>" if fmt == "html" else "\n"
|
||||
return delim.join(messages)
|
||||
|
||||
|
||||
def _fair_filter_trials(trials_by_state, max_trials):
    """Filters trials such that each state is represented fairly.

    Display slots are handed out round-robin, one per state per pass, until
    either ``max_trials`` slots have been used or every trial has a slot.
    Within each state the most recently started trials are kept; trials
    without a start time sort as the oldest and are truncated first.

    Args:
        trials_by_state (Dict[str, List[Trial]]): Trials by state.
        max_trials (int): Maximum number of trials to return.

    Returns:
        Dict mapping state to List of fairly represented trials, sorted
        newest-first within each state. The total number of trials across
        all states never exceeds ``max_trials``.
    """
    num_trials_by_state = collections.defaultdict(int)
    remaining = max_trials
    no_change = False
    # Determine number of trials to keep per state.
    while remaining > 0 and not no_change:
        no_change = True
        for state in trials_by_state:
            # Stop mid-pass once the budget is spent; otherwise a single
            # pass over many states could exceed max_trials.
            if remaining <= 0:
                break
            if num_trials_by_state[state] < len(trials_by_state[state]):
                no_change = False
                remaining -= 1
                num_trials_by_state[state] += 1

    def _start_time_key(trial):
        # Trials that have not started yet are treated as the oldest.
        return trial.start_time if trial.start_time else float("-inf")

    # Sort by start time, descending, then truncate the oldest trials.
    filtered_trials = {}
    for state in trials_by_state:
        newest_first = sorted(
            trials_by_state[state], reverse=True, key=_start_time_key)
        filtered_trials[state] = newest_first[:num_trials_by_state[state]]
    return filtered_trials
|
||||
|
||||
|
||||
def _get_trial_info(trial, parameters, metrics):
|
||||
"""Returns the following information about a trial:
|
||||
|
||||
|
|
59
python/ray/tune/tests/test_progress_reporter.py
Normal file
59
python/ray/tune/tests/test_progress_reporter.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.progress_reporter import _fair_filter_trials
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
from unittest.mock import MagicMock
|
||||
else:
|
||||
from mock import MagicMock
|
||||
|
||||
|
||||
class ProgressReporterTest(unittest.TestCase):
    """Unit tests for the trial filtering helper of the progress reporter."""

    def mock_trial(self, status, start_time):
        """Returns a mock trial with only status and start time populated."""
        mock = MagicMock()
        mock.status = status
        mock.start_time = start_time
        return mock

    def testFairFilterTrials(self):
        """Tests that trials are represented fairly."""
        trials_by_state = collections.defaultdict(list)
        # States for which trials are under and overrepresented
        states_under = (Trial.PAUSED, Trial.ERROR)
        states_over = (Trial.PENDING, Trial.RUNNING, Trial.TERMINATED)

        max_trials = 13
        num_trials_under = 2  # num of trials for each underrepresented state
        num_trials_over = 10  # num of trials for each overrepresented state

        for state in states_under:
            for _ in range(num_trials_under):
                trials_by_state[state].append(
                    self.mock_trial(state, time.time()))
        for state in states_over:
            for _ in range(num_trials_over):
                trials_by_state[state].append(
                    self.mock_trial(state, time.time()))

        filtered_trials_by_state = _fair_filter_trials(
            trials_by_state, max_trials=max_trials)

        # Underrepresented states keep all of their trials; the remaining
        # slots are split evenly among the overrepresented states. Floor
        # division keeps the expected count an int even though this module
        # enables true division.
        slots_for_over = max_trials - num_trials_under * len(states_under)
        expected_num_trials_over = slots_for_over // len(states_over)

        total_shown = 0
        for state in trials_by_state:
            if state in states_under:
                expected_num_trials = num_trials_under
            else:
                expected_num_trials = expected_num_trials_over
            state_trials = filtered_trials_by_state[state]
            total_shown += len(state_trials)
            self.assertEqual(len(state_trials), expected_num_trials)
            # Make sure trials are sorted newest-first within state.
            for i in range(len(state_trials) - 1):
                self.assertGreaterEqual(state_trials[i].start_time,
                                        state_trials[i + 1].start_time)
        # The filter must never return more trials than requested.
        self.assertLessEqual(total_shown, max_trials)
|
|
@ -162,6 +162,7 @@ class Trial(object):
|
|||
|
||||
self.export_formats = export_formats
|
||||
self.status = Trial.PENDING
|
||||
self.start_time = None
|
||||
self.logdir = None
|
||||
self.runner = None
|
||||
self.result_logger = None
|
||||
|
@ -251,6 +252,12 @@ class Trial(object):
|
|||
"""Sets the location of the trial."""
|
||||
self.address = location
|
||||
|
||||
def set_status(self, status):
    """Sets the status of the trial.

    The first time the trial is moved to ``Trial.RUNNING`` the current
    wall-clock time is stored in ``self.start_time``; later transitions
    leave an already-set start time untouched.

    Args:
        status: New status value to assign to ``self.status``.
    """
    if status == Trial.RUNNING:
        # Only stamp the start time once, on the first run.
        if self.start_time is None:
            self.start_time = time.time()
    self.status = status
|
||||
|
||||
def close_logger(self):
|
||||
"""Closes logger."""
|
||||
if self.result_logger:
|
||||
|
|
|
@ -41,7 +41,7 @@ class TrialExecutor(object):
|
|||
"""
|
||||
logger.debug("Trial %s: Changing status from %s to %s.", trial,
|
||||
trial.status, status)
|
||||
trial.status = status
|
||||
trial.set_status(status)
|
||||
if status in [Trial.TERMINATED, Trial.ERROR]:
|
||||
self.try_checkpoint_metadata(trial)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue