[tune] Report trials by state fairly (#6395)

* Fairly represented trial states.

* filter test

* Indent

* Add test to BUILD

* Address Eric's comments (show truncation by state).

* Sort trials, only show 20.

* Fix lint
This commit is contained in:
Ujval Misra 2019-12-10 14:56:54 -08:00 committed by Eric Liang
parent 16be483af7
commit 4e1d1ed00d
5 changed files with 163 additions and 20 deletions

View file

@ -72,6 +72,13 @@ py_test(
tags = ["jenkins_only"],
)
py_test(
name = "test_progress_reporter",
size = "small",
srcs = ["tests/test_progress_reporter.py"],
deps = [":tune_lib"],
)
py_test(
name = "test_ray_trial_executor",
size = "medium",

View file

@ -1,5 +1,7 @@
from __future__ import print_function
import collections
from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX,
EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
@ -25,6 +27,8 @@ REPORTED_REPRESENTATIONS = {
class ProgressReporter(object):
# TODO(ujvl): Expose ProgressReporter in tune.run for custom reporting.
def report(self, trial_runner):
"""Reports progress across all trials of the trial runner.
@ -49,7 +53,8 @@ class JupyterNotebookReporter(ProgressReporter):
"== Status ==",
memory_debug_str(),
trial_runner.debug_string(delim=delim),
trial_progress_str(trial_runner.get_trials(), fmt="html")
trial_progress_str(trial_runner.get_trials(), fmt="html"),
trial_errors_str(trial_runner.get_trials(), fmt="html"),
]
from IPython.display import clear_output
from IPython.core.display import display, HTML
@ -64,7 +69,8 @@ class CLIReporter(ProgressReporter):
"== Status ==",
memory_debug_str(),
trial_runner.debug_string(),
trial_progress_str(trial_runner.get_trials())
trial_progress_str(trial_runner.get_trials()),
trial_errors_str(trial_runner.get_trials()),
]
print("\n".join(messages) + "\n")
@ -90,7 +96,7 @@ def memory_debug_str():
"(or ray[debug]) to resolve)")
def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=20):
"""Returns a human readable message for printing to the console.
This contains a table where each row represents a trial, its parameters
@ -109,52 +115,116 @@ def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
return delim.join(messages)
num_trials = len(trials)
trials_per_state = {}
trials_by_state = collections.defaultdict(list)
for t in trials:
trials_per_state[t.status] = trials_per_state.get(t.status, 0) + 1
messages.append("Number of trials: {} ({})".format(num_trials,
trials_per_state))
trials_by_state[t.status].append(t)
for local_dir in sorted({t.local_dir for t in trials}):
messages.append("Result logdir: {}".format(local_dir))
num_trials_strs = [
"{} {}".format(len(trials_by_state[state]), state)
for state in trials_by_state
]
messages.append("Number of trials: {} ({})".format(
num_trials, ", ".join(num_trials_strs)))
if num_trials > max_rows:
overflow = num_trials - max_rows
# TODO(ujvl): suggestion for users to view more rows.
messages.append("Table truncated to {} rows ({} overflow).".format(
max_rows, overflow))
trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
trials = []
overflow_strs = []
for state in trials_by_state:
trials += trials_by_state_trunc[state]
overflow = len(trials_by_state[state]) - len(
trials_by_state_trunc[state])
overflow_strs.append("{} {}".format(overflow, state))
# Build overflow string.
overflow = num_trials - max_rows
overflow_str = ", ".join(overflow_strs)
messages.append("Table truncated to {} rows. {} trials ({}) not "
"shown.".format(max_rows, overflow, overflow_str))
# Pre-process trials to figure out what columns to show.
keys = list(metrics or DEFAULT_PROGRESS_KEYS)
keys = [k for k in keys if any(t.last_result.get(k) for t in trials)]
# Build trial rows.
trial_table = []
params = list(set().union(*[t.evaluated_params for t in trials]))
for trial in trials[:min(num_trials, max_rows)]:
trial_table.append(_get_trial_info(trial, params, keys))
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
# Parse columns.
parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
columns = ["Trial name", "status", "loc"]
columns += params + parsed_columns
messages.append(
tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
return delim.join(messages)
# Build trial error rows.
def trial_errors_str(trials, fmt="psql", max_rows=20):
"""Returns a readable message regarding trial errors.
Args:
trials (List[Trial]): List of trials to get progress string for.
fmt (str): Output format (see tablefmt in tabulate API).
max_rows (int): Maximum number of rows in the error table.
"""
messages = []
failed = [t for t in trials if t.error_file]
if len(failed) > 0:
messages.append("Number of errored trials: {}".format(len(failed)))
num_failed = len(failed)
if num_failed > 0:
messages.append("Number of errored trials: {}".format(num_failed))
if num_failed > max_rows:
messages.append("Table truncated to {} rows ({} overflow)".format(
max_rows, num_failed - max_rows))
error_table = []
for trial in failed:
for trial in failed[:max_rows]:
row = [str(trial), trial.num_failures, trial.error_file]
error_table.append(row)
columns = ["Trial name", "# failures", "error file"]
messages.append(
tabulate(
error_table, headers=columns, tablefmt=fmt, showindex=False))
delim = "<br>" if fmt == "html" else "\n"
return delim.join(messages)
def _fair_filter_trials(trials_by_state, max_trials):
"""Filters trials such that each state is represented fairly.
The oldest trials are truncated if necessary.
Args:
trials_by_state (Dict[str, List[Trial]]: Trials by state.
max_trials (int): Maximum number of trials to return.
Returns:
Dict mapping state to List of fairly represented trials.
"""
num_trials_by_state = collections.defaultdict(int)
no_change = False
# Determine number of trials to keep per state.
while max_trials > 0 and not no_change:
no_change = True
for state in trials_by_state:
if num_trials_by_state[state] < len(trials_by_state[state]):
no_change = False
max_trials -= 1
num_trials_by_state[state] += 1
# Sort by start time, descending.
sorted_trials_by_state = {
state: sorted(
trials_by_state[state],
reverse=True,
key=lambda t: t.start_time if t.start_time else float("-inf"))
for state in trials_by_state
}
# Truncate oldest trials.
filtered_trials = {
state: sorted_trials_by_state[state][:num_trials_by_state[state]]
for state in trials_by_state
}
return filtered_trials
def _get_trial_info(trial, parameters, metrics):
"""Returns the following information about a trial:

View file

@ -0,0 +1,59 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import sys
import time
import unittest
from ray.tune.trial import Trial
from ray.tune.progress_reporter import _fair_filter_trials
if sys.version_info >= (3, 3):
from unittest.mock import MagicMock
else:
from mock import MagicMock
class ProgressReporterTest(unittest.TestCase):
    """Tests for the trial filtering helper of the progress reporter."""

    def mock_trial(self, status, start_time):
        """Returns a MagicMock with `status` and `start_time` set."""
        mock = MagicMock()
        mock.status = status
        mock.start_time = start_time
        return mock

    def testFairFilterTrials(self):
        """Tests that trials are represented fairly."""
        trials_by_state = collections.defaultdict(list)
        # States for which trials are under and overrepresented
        states_under = (Trial.PAUSED, Trial.ERROR)
        states_over = (Trial.PENDING, Trial.RUNNING, Trial.TERMINATED)

        max_trials = 13
        num_trials_under = 2  # num of trials for each underrepresented state
        num_trials_over = 10  # num of trials for each overrepresented state

        # Give every trial a distinct, strictly increasing start time so the
        # ordering assertions below cannot pass vacuously on clocks that are
        # too coarse to tell back-to-back time.time() calls apart.
        base_time = time.time()
        num_created = 0
        groups = ((states_under, num_trials_under), (states_over,
                                                     num_trials_over))
        for states, num_trials in groups:
            for state in states:
                for _ in range(num_trials):
                    trials_by_state[state].append(
                        self.mock_trial(state, base_time + num_created))
                    num_created += 1

        filtered_trials_by_state = _fair_filter_trials(
            trials_by_state, max_trials=max_trials)

        # Underrepresented states are kept in full; the remaining slots are
        # split evenly between the overrepresented states. Floor division
        # keeps the expected count an int.
        num_slots_over = max_trials - num_trials_under * len(states_under)
        expected_num_trials_over = num_slots_over // len(states_over)

        for state in trials_by_state:
            if state in states_under:
                expected_num_trials = num_trials_under
            else:
                expected_num_trials = expected_num_trials_over
            state_trials = filtered_trials_by_state[state]
            self.assertEqual(len(state_trials), expected_num_trials)

            # Make sure trials are sorted newest-first within state.
            start_times = [t.start_time for t in state_trials]
            self.assertEqual(start_times, sorted(start_times, reverse=True))

            # Make sure the oldest trials are the ones that were dropped.
            all_start_times = sorted(
                (t.start_time for t in trials_by_state[state]), reverse=True)
            self.assertEqual(start_times,
                             all_start_times[:expected_num_trials])

View file

@ -162,6 +162,7 @@ class Trial(object):
self.export_formats = export_formats
self.status = Trial.PENDING
self.start_time = None
self.logdir = None
self.runner = None
self.result_logger = None
@ -251,6 +252,12 @@ class Trial(object):
"""Sets the location of the trial."""
self.address = location
def set_status(self, status):
    """Updates the trial's status.

    The first time the trial is moved to ``Trial.RUNNING`` the current
    wall-clock time is stored in ``self.start_time``; later transitions
    leave an already recorded start time untouched.
    """
    starting_first_time = (status == Trial.RUNNING
                           and self.start_time is None)
    if starting_first_time:
        self.start_time = time.time()
    self.status = status
def close_logger(self):
"""Closes logger."""
if self.result_logger:

View file

@ -41,7 +41,7 @@ class TrialExecutor(object):
"""
logger.debug("Trial %s: Changing status from %s to %s.", trial,
trial.status, status)
trial.status = status
trial.set_status(status)
if status in [Trial.TERMINATED, Trial.ERROR]:
self.try_checkpoint_metadata(trial)