[tune] Fix a number of reporter regressions and add end-to-end tests (#7274)

Eric Liang 2020-02-25 14:31:56 -08:00 committed by GitHub
parent 75f683eec6
commit 1ea05a2c08
9 changed files with 241 additions and 71 deletions

python/ray/tune/experiment.py

@@ -269,8 +269,9 @@ def convert_to_experiment_list(experiments):
     if (type(exp_list) is list
             and all(isinstance(exp, Experiment) for exp in exp_list)):
         if len(exp_list) > 1:
-            logger.warning("All experiments will be "
-                           "using the same SearchAlgorithm.")
+            logger.info(
+                "Running with multiple concurrent experiments. "
+                "All experiments will be using the same SearchAlgorithm.")
     else:
         raise TuneError("Invalid argument: {}".format(experiments))
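This message fires when several experiment specs are converted at once. A minimal sketch of the code path, modeled on the end-to-end test added below (experiment names illustrative):

    import ray
    from ray import tune

    def f(config):
        return {"done": True}

    ray.init(num_cpus=1)
    # convert_to_experiment_list() turns this dict into a list of Experiment
    # objects; with more than one, the INFO message above is logged once and
    # all experiments share the same SearchAlgorithm.
    tune.run_experiments({
        "one": {"run": f, "config": {"a": tune.grid_search([1, 2])}},
        "two": {"run": f, "config": {"b": tune.grid_search([3, 4])}},
    })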

python/ray/tune/function_runner.py

@@ -247,7 +247,7 @@ def wrap_function(train_func):
         func_args = inspect.getfullargspec(train_func).args
         use_track = ("reporter" not in func_args and len(func_args) == 1)
         if use_track:
-            logger.info("tune.track signature detected.")
+            logger.debug("tune.track signature detected.")
     except Exception:
         logger.info(
             "Function inspection failed - assuming reporter signature.")

python/ray/tune/progress_reporter.py

@@ -3,9 +3,8 @@ from __future__ import print_function
 import collections
 import time

-from ray.tune.result import (CONFIG_PREFIX, EPISODE_REWARD_MEAN, MEAN_ACCURACY,
-                             MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S,
-                             TIMESTEPS_TOTAL)
+from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
+                             TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
 from ray.tune.utils import flatten_dict

 try:
@@ -33,11 +32,12 @@ class ProgressReporter:
         """
         raise NotImplementedError

-    def report(self, trials, *sys_info):
+    def report(self, trials, done, *sys_info):
         """Reports progress across trials.

         Args:
             trials (list[Trial]): Trials to report on.
+            done (bool): Whether this is the last progress report attempt.
             sys_info: System info.
         """
         raise NotImplementedError
@@ -113,7 +113,7 @@ class TuneReporterBase(ProgressReporter):
                 "of metric columns.")
         self._metric_columns.append(metric)

-    def _progress_str(self, trials, *sys_info, fmt="psql", delim="\n"):
+    def _progress_str(self, trials, done, *sys_info, fmt="psql", delim="\n"):
         """Returns full progress string.

         This string contains a progress table and error table. The progress
@@ -123,21 +123,24 @@ class TuneReporterBase(ProgressReporter):
         Args:
             trials (list[Trial]): Trials to report on.
+            done (bool): Whether this is the last progress report attempt.
             fmt (str): Table format. See `tablefmt` in tabulate API.
             delim (str): Delimiter between messages.
         """
         messages = ["== Status ==", memory_debug_str(), *sys_info]
-        if self._max_progress_rows > 0:
+        if done:
+            max_progress = None
+            max_error = None
+        else:
+            max_progress = self._max_progress_rows
+            max_error = self._max_error_rows
         messages.append(
             trial_progress_str(
                 trials,
                 metric_columns=self._metric_columns,
                 fmt=fmt,
-                max_rows=self._max_progress_rows))
-        if self._max_error_rows > 0:
-            messages.append(
-                trial_errors_str(
-                    trials, fmt=fmt, max_rows=self._max_error_rows))
+                max_rows=max_progress))
+        messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
         return delim.join(messages) + delim
@@ -172,13 +175,13 @@ class JupyterNotebookReporter(TuneReporterBase):
                                                        max_report_frequency)
         self._overwrite = overwrite

-    def report(self, trials, *sys_info):
+    def report(self, trials, done, *sys_info):
         from IPython.display import clear_output
         from IPython.core.display import display, HTML
         if self._overwrite:
             clear_output(wait=True)
         progress_str = self._progress_str(
-            trials, *sys_info, fmt="html", delim="<br>")
+            trials, done, *sys_info, fmt="html", delim="<br>")
         display(HTML(progress_str))
@@ -209,8 +212,8 @@ class CLIReporter(TuneReporterBase):
         super(CLIReporter, self).__init__(metric_columns, max_progress_rows,
                                           max_error_rows, max_report_frequency)

-    def report(self, trials, *sys_info):
-        print(self._progress_str(trials, *sys_info))
+    def report(self, trials, done, *sys_info):
+        print(self._progress_str(trials, done, *sys_info))


 def memory_debug_str():
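Custom reporters must now accept the `done` flag. A hypothetical subclass sketch (the class is illustrative; the method signatures follow this diff and the `should_report` call in tune.py below):

    from ray.tune.progress_reporter import ProgressReporter

    class FileReporter(ProgressReporter):
        """Illustrative reporter that appends progress lines to a file."""

        def __init__(self, path):
            self._path = path

        def should_report(self, trials, done=False):
            return True

        def report(self, trials, done, *sys_info):
            # `done` is True only for the final report, which is the cue
            # to emit a full, untruncated summary.
            with open(self._path, "a") as out:
                out.write("{} trials (done={})\n".format(len(trials), done))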
@@ -266,10 +269,8 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
     num_trials_strs = [
         "{} {}".format(len(trials_by_state[state]), state)
-        for state in trials_by_state
+        for state in sorted(trials_by_state)
     ]
-    messages.append("Number of trials: {} ({})".format(
-        num_trials, ", ".join(num_trials_strs)))

     max_rows = max_rows or float("inf")
     if num_trials > max_rows:
@@ -277,16 +278,19 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
         trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
         trials = []
         overflow_strs = []
-        for state in trials_by_state:
+        for state in sorted(trials_by_state):
             trials += trials_by_state_trunc[state]
-            overflow = len(trials_by_state[state]) - len(
+            num = len(trials_by_state[state]) - len(
                 trials_by_state_trunc[state])
-            overflow_strs.append("{} {}".format(overflow, state))
+            if num > 0:
+                overflow_strs.append("{} {}".format(num, state))
+        # Build overflow string.
+        overflow = num_trials - max_rows
         overflow_str = ", ".join(overflow_strs)
-        messages.append("Table truncated to {} rows. {} trials ({}) not "
-                        "shown.".format(max_rows, overflow, overflow_str))
+    else:
+        overflow = False
+    messages.append("Number of trials: {} ({})".format(
+        num_trials, ", ".join(num_trials_strs)))

     # Pre-process trials to figure out what columns to show.
     if isinstance(metric_columns, collections.Mapping):
@@ -297,18 +301,22 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
             k for k in keys if any(
                 t.last_result.get(k) is not None for t in trials)
         ]
+    keys = sorted(keys)
     # Build trial rows.
-    params = list(set().union(*[t.evaluated_params for t in trials]))
+    params = sorted(set().union(*[t.evaluated_params for t in trials]))
     trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
     # Format column headings
     if isinstance(metric_columns, collections.Mapping):
         formatted_columns = [metric_columns[k] for k in keys]
     else:
         formatted_columns = keys
-    columns = ["Trial name", "status", "loc"] + params + formatted_columns
+    columns = (["Trial name", "status", "loc"] + params + formatted_columns)
     # Tabulate.
     messages.append(
         tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
+    if overflow:
+        messages.append("... {} more trials not shown ({})".format(
+            overflow, overflow_str))
     return delim.join(messages)
@@ -326,8 +334,7 @@ def trial_errors_str(trials, fmt="psql", max_rows=None):
     num_failed = len(failed)
     if num_failed > 0:
         messages.append("Number of errored trials: {}".format(num_failed))
-        max_rows = max_rows or float("inf")
-        if num_failed > max_rows:
+        if num_failed > (max_rows or float("inf")):
             messages.append("Table truncated to {} rows ({} overflow)".format(
                 max_rows, num_failed - max_rows))
         error_table = []
@@ -358,7 +365,7 @@ def _fair_filter_trials(trials_by_state, max_trials):
     # Determine number of trials to keep per state.
     while max_trials > 0 and not no_change:
         no_change = True
-        for state in trials_by_state:
+        for state in sorted(trials_by_state):
             if num_trials_by_state[state] < len(trials_by_state[state]):
                 no_change = False
                 max_trials -= 1
@@ -366,15 +373,13 @@ def _fair_filter_trials(trials_by_state, max_trials):
     # Sort by start time, descending.
     sorted_trials_by_state = {
         state: sorted(
-            trials_by_state[state],
-            reverse=True,
-            key=lambda t: t.start_time if t.start_time else float("-inf"))
-        for state in trials_by_state
+            trials_by_state[state], reverse=False, key=lambda t: t.trial_id)
+        for state in sorted(trials_by_state)
     }
     # Truncate oldest trials.
     filtered_trials = {
         state: sorted_trials_by_state[state][:num_trials_by_state[state]]
-        for state in trials_by_state
+        for state in sorted(trials_by_state)
     }
     return filtered_trials
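A worked example of the filtering and the new deterministic ordering, assuming the helper behaves as this diff and the updated test indicate (mock trials stand in for real ones):

    from types import SimpleNamespace
    from ray.tune.progress_reporter import _fair_filter_trials

    def mock_trial(state, i):
        return SimpleNamespace(status=state, trial_id="%05d" % i)

    trials_by_state = {
        "RUNNING": [mock_trial("RUNNING", i) for i in range(10)],
        "PENDING": [mock_trial("PENDING", i) for i in range(10, 12)],
    }
    filtered = _fair_filter_trials(trials_by_state, max_trials=5)
    # The budget is spread round-robin: the underrepresented state keeps both
    # of its trials, and within a state the lowest trial IDs survive.
    assert len(filtered["PENDING"]) == 2
    assert [t.trial_id for t in filtered["RUNNING"]] == ["00000", "00001", "00002"]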
@@ -390,7 +395,8 @@ def _get_trial_info(trial, parameters, metrics):
         metrics (list[str]): Names of metrics to include.
     """
     result = flatten_dict(trial.last_result)
+    config = flatten_dict(trial.config)
     trial_info = [str(trial), trial.status, str(trial.location)]
-    trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
+    trial_info += [config.get(param) for param in parameters]
     trial_info += [result.get(metric) for metric in metrics]
     return trial_info
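Parameter columns are now resolved from a flattened copy of `trial.config` instead of `config/`-prefixed keys in the last result, so they render even before a trial reports anything. A short sketch of what `flatten_dict` produces (the nested config is illustrative; Tune joins keys with "/"):

    from ray.tune.utils import flatten_dict

    config = {"optimizer": {"lr": 0.01}, "batch_size": 64}
    flat = flatten_dict(config)
    assert flat == {"optimizer/lr": 0.01, "batch_size": 64}
    # A parameter column named "optimizer/lr" is now looked up via
    # config.get(param) rather than result.get(CONFIG_PREFIX + param).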

python/ray/tune/schedulers/pbt.py

@@ -1,5 +1,4 @@
 import copy
-import itertools
 import logging
 import json
 import math
@@ -270,9 +269,8 @@ class PopulationBasedTraining(FIFOScheduler):
         """
         trial_name, trial_to_clone_name = (trial_state.orig_tag,
                                            new_state.orig_tag)
-        trial_id = "".join(itertools.takewhile(str.isdigit, trial_name))
-        trial_to_clone_id = "".join(
-            itertools.takewhile(str.isdigit, trial_to_clone_name))
+        trial_id = trial.trial_id
+        trial_to_clone_id = trial_to_clone.trial_id
         trial_path = os.path.join(trial.local_dir,
                                   "pbt_policy_" + trial_id + ".txt")
         trial_to_clone_path = os.path.join(
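The old code recovered an ID by parsing leading digits out of the experiment tag, which stops matching once IDs are zero-padded counters (see basic_variant.py below). A small sketch of the failure mode:

    import itertools

    tag = "3_lr=0.01"  # illustrative experiment tag
    parsed = "".join(itertools.takewhile(str.isdigit, tag))
    assert parsed == "3"
    # With trial IDs now formatted as "%05d" (e.g. "00003"), the parsed tag
    # prefix no longer equals trial.trial_id, so PBT reads trial.trial_id
    # directly when naming its pbt_policy_<trial_id>.txt files.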

python/ray/tune/suggest/basic_variant.py

@@ -73,6 +73,7 @@ class BasicVariantGenerator(SearchAlgorithm):
             raise TuneError("Must specify `run` in {}".format(unresolved_spec))
         for _ in range(unresolved_spec.get("num_samples", 1)):
             for resolved_vars, spec in generate_variants(unresolved_spec):
+                trial_id = "%05d" % self._counter
                 experiment_tag = str(self._counter)
                 if resolved_vars:
                     experiment_tag += "_{}".format(format_vars(resolved_vars))
@@ -82,6 +83,7 @@ class BasicVariantGenerator(SearchAlgorithm):
                     output_path,
                     self._parser,
                     evaluated_params=flatten_resolved_vars(resolved_vars),
+                    trial_id=trial_id,
                     experiment_tag=experiment_tag)

     def is_finished(self):
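Trial IDs now come from the generator's counter rather than the random Trial.generate_id(), so they are zero-padded and sort in creation order; a tiny sketch:

    trial_ids = ["%05d" % counter for counter in range(3)]
    assert trial_ids == ["00000", "00001", "00002"]
    # Deterministic IDs are what make the sorted progress tables and the
    # expected strings in the new end-to-end test reproducible across runs.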

python/ray/tune/tests/test_progress_reporter.py

@@ -1,17 +1,134 @@
 import collections
-import time
+import subprocess
+import tempfile
 import unittest

-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock

 from ray.tune.trial import Trial
-from ray.tune.progress_reporter import CLIReporter, _fair_filter_trials
+from ray.tune.progress_reporter import (CLIReporter, _fair_filter_trials,
+                                        trial_progress_str)
+
+EXPECTED_RESULT_1 = """Result logdir: /foo
+Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
++--------------+------------+-------+-----+-----+
+| Trial name   | status     | loc   |   a |   b |
+|--------------+------------+-------+-----+-----|
+| 00001        | PENDING    | here  |   1 |   2 |
+| 00002        | RUNNING    | here  |   2 |   4 |
+| 00000        | TERMINATED | here  |   0 |   0 |
++--------------+------------+-------+-----+-----+
+... 2 more trials not shown (2 RUNNING)"""
+
+EXPECTED_RESULT_2 = """Result logdir: /foo
+Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
++--------------+------------+-------+-----+-----+
+| Trial name   | status     | loc   |   a |   b |
+|--------------+------------+-------+-----+-----|
+| 00000        | TERMINATED | here  |   0 |   0 |
+| 00001        | PENDING    | here  |   1 |   2 |
+| 00002        | RUNNING    | here  |   2 |   4 |
+| 00003        | RUNNING    | here  |   3 |   6 |
+| 00004        | RUNNING    | here  |   4 |   8 |
++--------------+------------+-------+-----+-----+"""
+
+END_TO_END_COMMAND = """
+import ray
+from ray import tune
+
+def f(config):
+    return {"done": True}
+
+ray.init(num_cpus=1)
+tune.run_experiments({
+    "one": {
+        "run": f,
+        "config": {
+            "a": tune.grid_search(list(range(10))),
+        },
+    },
+    "two": {
+        "run": f,
+        "config": {
+            "b": tune.grid_search(list(range(10))),
+        },
+    },
+    "three": {
+        "run": f,
+        "config": {
+            "c": tune.grid_search(list(range(10))),
+        },
+    },
+}, reuse_actors=True, verbose=1)"""
+
+EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
++--------------+----------+-------+-----+-----+
+| Trial name   | status   | loc   |   a |   b |
+|--------------+----------+-------+-----+-----|
+| f_00001      | PENDING  |       |   1 |     |
+| f_00002      | PENDING  |       |   2 |     |
+| f_00003      | PENDING  |       |   3 |     |
+| f_00004      | PENDING  |       |   4 |     |
+| f_00005      | PENDING  |       |   5 |     |
+| f_00006      | PENDING  |       |   6 |     |
+| f_00007      | PENDING  |       |   7 |     |
+| f_00008      | PENDING  |       |   8 |     |
+| f_00009      | PENDING  |       |   9 |     |
+| f_00010      | PENDING  |       |     |   0 |
+| f_00011      | PENDING  |       |     |   1 |
+| f_00012      | PENDING  |       |     |   2 |
+| f_00013      | PENDING  |       |     |   3 |
+| f_00014      | PENDING  |       |     |   4 |
+| f_00015      | PENDING  |       |     |   5 |
+| f_00016      | PENDING  |       |     |   6 |
+| f_00017      | PENDING  |       |     |   7 |
+| f_00018      | PENDING  |       |     |   8 |
+| f_00019      | PENDING  |       |     |   9 |
+| f_00000      | RUNNING  |       |   0 |     |
++--------------+----------+-------+-----+-----+
+... 10 more trials not shown (10 PENDING)"""
+
+EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
++--------------+------------+-------+-----+-----+-----+
+| Trial name   | status     | loc   |   a |   b |   c |
+|--------------+------------+-------+-----+-----+-----|
+| f_00000      | TERMINATED |       |   0 |     |     |
+| f_00001      | TERMINATED |       |   1 |     |     |
+| f_00002      | TERMINATED |       |   2 |     |     |
+| f_00003      | TERMINATED |       |   3 |     |     |
+| f_00004      | TERMINATED |       |   4 |     |     |
+| f_00005      | TERMINATED |       |   5 |     |     |
+| f_00006      | TERMINATED |       |   6 |     |     |
+| f_00007      | TERMINATED |       |   7 |     |     |
+| f_00008      | TERMINATED |       |   8 |     |     |
+| f_00009      | TERMINATED |       |   9 |     |     |
+| f_00010      | TERMINATED |       |     |   0 |     |
+| f_00011      | TERMINATED |       |     |   1 |     |
+| f_00012      | TERMINATED |       |     |   2 |     |
+| f_00013      | TERMINATED |       |     |   3 |     |
+| f_00014      | TERMINATED |       |     |   4 |     |
+| f_00015      | TERMINATED |       |     |   5 |     |
+| f_00016      | TERMINATED |       |     |   6 |     |
+| f_00017      | TERMINATED |       |     |   7 |     |
+| f_00018      | TERMINATED |       |     |   8 |     |
+| f_00019      | TERMINATED |       |     |   9 |     |
+| f_00020      | TERMINATED |       |     |     |   0 |
+| f_00021      | TERMINATED |       |     |     |   1 |
+| f_00022      | TERMINATED |       |     |     |   2 |
+| f_00023      | TERMINATED |       |     |     |   3 |
+| f_00024      | TERMINATED |       |     |     |   4 |
+| f_00025      | TERMINATED |       |     |     |   5 |
+| f_00026      | TERMINATED |       |     |     |   6 |
+| f_00027      | TERMINATED |       |     |     |   7 |
+| f_00028      | TERMINATED |       |     |     |   8 |
+| f_00029      | TERMINATED |       |     |     |   9 |
++--------------+------------+-------+-----+-----+-----+"""
+
+
 class ProgressReporterTest(unittest.TestCase):
-    def mock_trial(self, status, start_time):
+    def mock_trial(self, status, i):
         mock = MagicMock()
         mock.status = status
-        mock.start_time = start_time
+        mock.trial_id = "%05d" % i
         return mock

     def testFairFilterTrials(self):
@@ -25,14 +142,15 @@ class ProgressReporterTest(unittest.TestCase):
         num_trials_under = 2  # num of trials for each underrepresented state
         num_trials_over = 10  # num of trials for each overrepresented state

+        i = 0
         for state in states_under:
             for _ in range(num_trials_under):
-                trials_by_state[state].append(
-                    self.mock_trial(state, time.time()))
+                trials_by_state[state].append(self.mock_trial(state, i))
+                i += 1
         for state in states_over:
             for _ in range(num_trials_over):
-                trials_by_state[state].append(
-                    self.mock_trial(state, time.time()))
+                trials_by_state[state].append(self.mock_trial(state, i))
+                i += 1

         filtered_trials_by_state = _fair_filter_trials(
             trials_by_state, max_trials=max_trials)
@@ -46,8 +164,7 @@ class ProgressReporterTest(unittest.TestCase):
             self.assertEqual(len(state_trials), expected_num_trials)
             # Make sure trials are sorted newest-first within state.
             for i in range(len(state_trials) - 1):
-                self.assertGreaterEqual(state_trials[i].start_time,
-                                        state_trials[i + 1].start_time)
+                assert state_trials[i].trial_id < state_trials[i + 1].trial_id

     def testAddMetricColumn(self):
         """Tests edge cases of add_metric_column."""
@@ -67,3 +184,50 @@ class ProgressReporterTest(unittest.TestCase):
         reporter = CLIReporter()
         reporter.add_metric_column("foo", "bar")
         self.assertIn("foo", reporter._metric_columns)
+
+    def testProgressStr(self):
+        trials = []
+        for i in range(5):
+            t = Mock()
+            if i == 0:
+                t.status = "TERMINATED"
+            elif i == 1:
+                t.status = "PENDING"
+            else:
+                t.status = "RUNNING"
+            t.trial_id = "%05d" % i
+            t.local_dir = "/foo"
+            t.location = "here"
+            t.config = {"a": i, "b": i * 2}
+            t.evaluated_params = t.config
+            t.last_result = {"config": {"a": i, "b": i * 2}}
+            t.__str__ = lambda self: self.trial_id
+            trials.append(t)
+        prog1 = trial_progress_str(trials, ["a", "b"], fmt="psql", max_rows=3)
+        print(prog1)
+        assert prog1 == EXPECTED_RESULT_1
+        prog2 = trial_progress_str(
+            trials, ["a", "b"], fmt="psql", max_rows=None)
+        print(prog2)
+        assert prog2 == EXPECTED_RESULT_2
+
+    def testEndToEndReporting(self):
+        with tempfile.NamedTemporaryFile(suffix=".py") as f:
+            f.write(END_TO_END_COMMAND.encode("utf-8"))
+            f.flush()
+            output = subprocess.check_output(["python3", f.name])
+            output = output.decode("utf-8")
+            try:
+                assert EXPECTED_END_TO_END_START in output
+                assert EXPECTED_END_TO_END_END in output
+            except Exception:
+                print("*** BEGIN OUTPUT ***")
+                print(output)
+                print("*** END OUTPUT ***")
+                raise
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))

python/ray/tune/tests/test_trial_scheduler.py

@@ -688,7 +688,7 @@ class BOHBSuite(unittest.TestCase):
 class _MockTrial(Trial):
     def __init__(self, i, config):
         self.trainable_name = "trial_{}".format(i)
-        self.trial_id = Trial.generate_id()
+        self.trial_id = str(i)
         self.config = config
         self.experiment_tag = "{}tag".format(i)
         self.trial_name_creator = None

python/ray/tune/trial.py

@@ -204,7 +204,7 @@ class Trial:
         self.sync_on_checkpoint = sync_on_checkpoint
         self.checkpoint_manager = CheckpointManager(
             keep_checkpoints_num, checkpoint_score_attr,
-            checkpoint_deleter(str(self), self.runner))
+            checkpoint_deleter(self._trainable_name(), self.runner))
         checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
         self.checkpoint_manager.newest_persistent_checkpoint = checkpoint
@@ -271,7 +271,8 @@ class Trial:
         if not self.result_logger:
             if not self.logdir:
                 self.logdir = Trial.create_logdir(
-                    str(self) + "_" + self.experiment_tag, self.local_dir)
+                    self._trainable_name() + "_" + self.experiment_tag,
+                    self.local_dir)
             else:
                 os.makedirs(self.logdir, exist_ok=True)
@@ -296,7 +297,8 @@ class Trial:
     def set_runner(self, runner):
         self.runner = runner
-        self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)
+        self.checkpoint_manager.delete = checkpoint_deleter(
+            self._trainable_name(), runner)

     def set_location(self, location):
         """Sets the location of the trial."""
@@ -465,9 +467,12 @@ class Trial:
         return self.saving_to is not None

     def __repr__(self):
-        return str(self)
+        return self._trainable_name(include_trial_id=True)

     def __str__(self):
+        return self._trainable_name(include_trial_id=True)
+
+    def _trainable_name(self, include_trial_id=False):
         """Combines ``env`` with ``trainable_name`` and ``trial_id``.

         Can be overridden with a custom string creator.
@@ -482,6 +487,7 @@ class Trial:
             identifier = "{}_{}".format(self.trainable_name, env)
         else:
             identifier = self.trainable_name
-        identifier += "_" + self.trial_id
+        if include_trial_id:
+            identifier += "_" + self.trial_id
         return identifier.replace("/", "_")
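How the refactored naming composes, sketched as a standalone mirror of `_trainable_name` (simplified; in the real method `env` comes from the trial's config):

    def trainable_name(name, trial_id, env=None, include_trial_id=False):
        identifier = "{}_{}".format(name, env) if env else name
        if include_trial_id:
            identifier += "_" + trial_id
        return identifier.replace("/", "_")

    # str(trial) and repr(trial) include the ID ("f_00001" in the new tests);
    # checkpoint deleters and logdir prefixes use the bare trainable name.
    assert trainable_name("f", "00001", include_trial_id=True) == "f_00001"
    assert trainable_name("f", "00001") == "f"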

python/ray/tune/tune.py

@@ -62,7 +62,7 @@ def _report_progress(runner, reporter, done=False):
     if reporter.should_report(trials, done=done):
         sched_debug_str = runner.scheduler_alg.debug_string()
         executor_debug_str = runner.trial_executor.debug_string()
-        reporter.report(trials, sched_debug_str, executor_debug_str)
+        reporter.report(trials, done, sched_debug_str, executor_debug_str)


 def run(run_or_experiment,
@@ -242,10 +242,6 @@ def run(run_or_experiment,
         experiments = run_or_experiment
     else:
         experiments = [run_or_experiment]
-    if len(experiments) > 1:
-        logger.info(
-            "Running multiple concurrent experiments is experimental and may "
-            "not work with certain features.")

     for i, exp in enumerate(experiments):
         if not isinstance(exp, Experiment):
@@ -349,9 +345,6 @@ def run(run_or_experiment,
     trials = runner.get_trials()
     if return_trials:
         return trials
-    logger.info("Returning an analysis object by default. You can call "
-                "`analysis.trials` to retrieve a list of trials. "
-                "This message will be removed in future versions of Tune.")
     return ExperimentAnalysis(runner.checkpoint_file, trials=trials)
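The removed INFO message described behavior that still holds: unless `return_trials` is set, run() returns an ExperimentAnalysis whose `trials` attribute is the trial list. A minimal usage sketch:

    from ray import tune

    def f(config):
        return {"done": True}

    analysis = tune.run(f, config={"a": tune.grid_search([0, 1])})
    for trial in analysis.trials:
        print(trial, trial.last_result)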