[tune] Fix a number of reporter regressions and add end-to-end tests (#7274)
parent 75f683eec6
commit 1ea05a2c08
9 changed files with 241 additions and 71 deletions
@@ -269,8 +269,9 @@ def convert_to_experiment_list(experiments):
     if (type(exp_list) is list
             and all(isinstance(exp, Experiment) for exp in exp_list)):
         if len(exp_list) > 1:
-            logger.warning("All experiments will be "
-                           "using the same SearchAlgorithm.")
+            logger.info(
+                "Running with multiple concurrent experiments. "
+                "All experiments will be using the same SearchAlgorithm.")
     else:
         raise TuneError("Invalid argument: {}".format(experiments))

@@ -247,7 +247,7 @@ def wrap_function(train_func):
         func_args = inspect.getfullargspec(train_func).args
         use_track = ("reporter" not in func_args and len(func_args) == 1)
         if use_track:
-            logger.info("tune.track signature detected.")
+            logger.debug("tune.track signature detected.")
     except Exception:
         logger.info(
             "Function inspection failed - assuming reporter signature.")

@@ -3,9 +3,8 @@ from __future__ import print_function
 import collections
 import time
 
-from ray.tune.result import (CONFIG_PREFIX, EPISODE_REWARD_MEAN, MEAN_ACCURACY,
-                             MEAN_LOSS, TRAINING_ITERATION, TIME_TOTAL_S,
-                             TIMESTEPS_TOTAL)
+from ray.tune.result import (EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
+                             TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL)
 from ray.tune.utils import flatten_dict
 
 try:

@@ -33,11 +32,12 @@ class ProgressReporter:
         """
         raise NotImplementedError
 
-    def report(self, trials, *sys_info):
+    def report(self, trials, done, *sys_info):
         """Reports progress across trials.
 
         Args:
             trials (list[Trial]): Trials to report on.
+            done (bool): Whether this is the last progress report attempt.
             sys_info: System info.
         """
         raise NotImplementedError

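Aside (not part of the diff): a minimal sketch of a custom reporter written against the new interface. The subclass and its print format are hypothetical; `done` is True only on the final report, which is when the base reporter drops row truncation.

    from ray.tune.progress_reporter import ProgressReporter

    class PrintReporter(ProgressReporter):
        """Hypothetical reporter illustrating the new `done` argument."""

        def should_report(self, trials, done=False):
            # Always fire on the last call so the final table is complete.
            return done or any(t.status == "RUNNING" for t in trials)

        def report(self, trials, done, *sys_info):
            print("{} trials, done={}".format(len(trials), done))
            for info in sys_info:
                print(info)
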
@@ -113,7 +113,7 @@ class TuneReporterBase(ProgressReporter):
                              "of metric columns.")
         self._metric_columns.append(metric)
 
-    def _progress_str(self, trials, *sys_info, fmt="psql", delim="\n"):
+    def _progress_str(self, trials, done, *sys_info, fmt="psql", delim="\n"):
         """Returns full progress string.
 
         This string contains a progress table and error table. The progress

@@ -123,21 +123,24 @@ class TuneReporterBase(ProgressReporter):
 
         Args:
             trials (list[Trial]): Trials to report on.
+            done (bool): Whether this is the last progress report attempt.
             fmt (str): Table format. See `tablefmt` in tabulate API.
             delim (str): Delimiter between messages.
         """
         messages = ["== Status ==", memory_debug_str(), *sys_info]
-        if self._max_progress_rows > 0:
+        if done:
+            max_progress = None
+            max_error = None
+        else:
+            max_progress = self._max_progress_rows
+            max_error = self._max_error_rows
         messages.append(
             trial_progress_str(
                 trials,
                 metric_columns=self._metric_columns,
                 fmt=fmt,
-                max_rows=self._max_progress_rows))
-        if self._max_error_rows > 0:
-            messages.append(
-                trial_errors_str(
-                    trials, fmt=fmt, max_rows=self._max_error_rows))
+                max_rows=max_progress))
+        messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
         return delim.join(messages) + delim

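Aside (not part of the diff): the effect of the new `done` branch, as a standalone sketch with hypothetical default values. While tuning is running the tables stay capped; on the final report both caps become None, which disables truncation downstream.

    def row_caps(done, max_progress_rows=20, max_error_rows=3):
        # Mirrors the branch above: None means "show every row".
        if done:
            return None, None
        return max_progress_rows, max_error_rows

    assert row_caps(False) == (20, 3)
    assert row_caps(True) == (None, None)
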
@@ -172,13 +175,13 @@ class JupyterNotebookReporter(TuneReporterBase):
                                                       max_report_frequency)
         self._overwrite = overwrite
 
-    def report(self, trials, *sys_info):
+    def report(self, trials, done, *sys_info):
         from IPython.display import clear_output
         from IPython.core.display import display, HTML
         if self._overwrite:
             clear_output(wait=True)
         progress_str = self._progress_str(
-            trials, *sys_info, fmt="html", delim="<br>")
+            trials, done, *sys_info, fmt="html", delim="<br>")
         display(HTML(progress_str))

@@ -209,8 +212,8 @@ class CLIReporter(TuneReporterBase):
         super(CLIReporter, self).__init__(metric_columns, max_progress_rows,
                                           max_error_rows, max_report_frequency)
 
-    def report(self, trials, *sys_info):
-        print(self._progress_str(trials, *sys_info))
+    def report(self, trials, done, *sys_info):
+        print(self._progress_str(trials, done, *sys_info))
 
 
 def memory_debug_str():

@@ -266,10 +269,8 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
 
     num_trials_strs = [
         "{} {}".format(len(trials_by_state[state]), state)
-        for state in trials_by_state
+        for state in sorted(trials_by_state)
     ]
-    messages.append("Number of trials: {} ({})".format(
-        num_trials, ", ".join(num_trials_strs)))
 
     max_rows = max_rows or float("inf")
     if num_trials > max_rows:

@@ -277,16 +278,19 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
         trials_by_state_trunc = _fair_filter_trials(trials_by_state, max_rows)
         trials = []
         overflow_strs = []
-        for state in trials_by_state:
+        for state in sorted(trials_by_state):
             trials += trials_by_state_trunc[state]
-            overflow = len(trials_by_state[state]) - len(
+            num = len(trials_by_state[state]) - len(
                 trials_by_state_trunc[state])
-            overflow_strs.append("{} {}".format(overflow, state))
+            if num > 0:
+                overflow_strs.append("{} {}".format(num, state))
+        # Build overflow string.
+        overflow = num_trials - max_rows
         overflow_str = ", ".join(overflow_strs)
-        messages.append("Table truncated to {} rows. {} trials ({}) not "
-                        "shown.".format(max_rows, overflow, overflow_str))
+    else:
+        overflow = False
+    messages.append("Number of trials: {} ({})".format(
+        num_trials, ", ".join(num_trials_strs)))
 
     # Pre-process trials to figure out what columns to show.
     if isinstance(metric_columns, collections.Mapping):

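Aside (not part of the diff): why the switch to sorted(trials_by_state) matters for the new end-to-end tests, shown standalone with toy data. Dict iteration follows insertion order, so sorting the state keys makes the trial-count summary and the overflow string reproducible across runs.

    trials_by_state = {"RUNNING": [1, 2, 3], "PENDING": [4], "TERMINATED": [5]}
    summary = ", ".join(
        "{} {}".format(len(trials_by_state[s]), s)
        for s in sorted(trials_by_state))
    assert summary == "1 PENDING, 3 RUNNING, 1 TERMINATED"
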
@@ -297,18 +301,22 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
         k for k in keys if any(
             t.last_result.get(k) is not None for t in trials)
     ]
+    keys = sorted(keys)
     # Build trial rows.
-    params = list(set().union(*[t.evaluated_params for t in trials]))
+    params = sorted(set().union(*[t.evaluated_params for t in trials]))
     trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
     # Format column headings
     if isinstance(metric_columns, collections.Mapping):
         formatted_columns = [metric_columns[k] for k in keys]
     else:
         formatted_columns = keys
-    columns = ["Trial name", "status", "loc"] + params + formatted_columns
+    columns = (["Trial name", "status", "loc"] + params + formatted_columns)
     # Tabulate.
     messages.append(
         tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
+    if overflow:
+        messages.append("... {} more trials not shown ({})".format(
+            overflow, overflow_str))
     return delim.join(messages)

@@ -326,8 +334,7 @@ def trial_errors_str(trials, fmt="psql", max_rows=None):
     num_failed = len(failed)
     if num_failed > 0:
         messages.append("Number of errored trials: {}".format(num_failed))
-        max_rows = max_rows or float("inf")
-        if num_failed > max_rows:
+        if num_failed > (max_rows or float("inf")):
             messages.append("Table truncated to {} rows ({} overflow)".format(
                 max_rows, num_failed - max_rows))
         error_table = []

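Aside (not part of the diff): the inlined guard fixes a subtle bug, shown here standalone. The old code reassigned max_rows to float("inf") before formatting it into the truncation message; the new expression tests against infinity without clobbering max_rows.

    max_rows = None  # None means "no limit"
    num_failed = 5
    assert not (num_failed > (max_rows or float("inf")))
    assert max_rows is None  # still intact for any later formatting
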
@@ -358,7 +365,7 @@ def _fair_filter_trials(trials_by_state, max_trials):
     # Determine number of trials to keep per state.
     while max_trials > 0 and not no_change:
         no_change = True
-        for state in trials_by_state:
+        for state in sorted(trials_by_state):
             if num_trials_by_state[state] < len(trials_by_state[state]):
                 no_change = False
                 max_trials -= 1

@@ -366,15 +373,13 @@ def _fair_filter_trials(trials_by_state, max_trials):
     # Sort by start time, descending.
     sorted_trials_by_state = {
         state: sorted(
-            trials_by_state[state],
-            reverse=True,
-            key=lambda t: t.start_time if t.start_time else float("-inf"))
-        for state in trials_by_state
+            trials_by_state[state], reverse=False, key=lambda t: t.trial_id)
+        for state in sorted(trials_by_state)
     }
     # Truncate oldest trials.
     filtered_trials = {
         state: sorted_trials_by_state[state][:num_trials_by_state[state]]
-        for state in trials_by_state
+        for state in sorted(trials_by_state)
     }
     return filtered_trials

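Aside (not part of the diff): the new tie-breaking rule as a standalone sketch with toy data. Within each state, trials are now ordered by trial_id ascending rather than newest-first by start_time, so truncation keeps a deterministic prefix that tests can assert on.

    trials_by_state = {"RUNNING": ["00004", "00002", "00003"]}
    num_trials_by_state = {"RUNNING": 2}
    filtered = {
        state: sorted(trials_by_state[state])[:num_trials_by_state[state]]
        for state in sorted(trials_by_state)
    }
    assert filtered == {"RUNNING": ["00002", "00003"]}
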
@@ -390,7 +395,8 @@ def _get_trial_info(trial, parameters, metrics):
         metrics (list[str]): Names of metrics to include.
     """
     result = flatten_dict(trial.last_result)
+    config = flatten_dict(trial.config)
     trial_info = [str(trial), trial.status, str(trial.location)]
-    trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
+    trial_info += [config.get(param) for param in parameters]
     trial_info += [result.get(metric) for metric in metrics]
     return trial_info

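Aside (not part of the diff): the lookup change above means parameter values now come from the trial's own flattened config rather than from result["config/<param>"], so they appear even before a trial reports its first result. Assuming flatten_dict joins nested keys with "/":

    from ray.tune.utils import flatten_dict

    flat = flatten_dict({"a": 1, "nested": {"b": 2}})
    assert flat == {"a": 1, "nested/b": 2}  # assumed "/" delimiter
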
@@ -1,5 +1,4 @@
 import copy
-import itertools
 import logging
 import json
 import math

@@ -270,9 +269,8 @@ class PopulationBasedTraining(FIFOScheduler):
         """
         trial_name, trial_to_clone_name = (trial_state.orig_tag,
                                            new_state.orig_tag)
-        trial_id = "".join(itertools.takewhile(str.isdigit, trial_name))
-        trial_to_clone_id = "".join(
-            itertools.takewhile(str.isdigit, trial_to_clone_name))
+        trial_id = trial.trial_id
+        trial_to_clone_id = trial_to_clone.trial_id
         trial_path = os.path.join(trial.local_dir,
                                   "pbt_policy_" + trial_id + ".txt")
         trial_to_clone_path = os.path.join(

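Aside (not part of the diff): why the old tag parsing was fragile, shown standalone with hypothetical tags. Taking the leading digits of an experiment tag only works while tags start with the counter; reading trial.trial_id directly avoids parsing entirely.

    import itertools

    assert "".join(itertools.takewhile(str.isdigit, "5_a=3")) == "5"
    assert "".join(itertools.takewhile(str.isdigit, "tag_5")) == ""  # breaks
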
@@ -73,6 +73,7 @@ class BasicVariantGenerator(SearchAlgorithm):
             raise TuneError("Must specify `run` in {}".format(unresolved_spec))
         for _ in range(unresolved_spec.get("num_samples", 1)):
             for resolved_vars, spec in generate_variants(unresolved_spec):
+                trial_id = "%05d" % self._counter
                 experiment_tag = str(self._counter)
                 if resolved_vars:
                     experiment_tag += "_{}".format(format_vars(resolved_vars))

@@ -82,6 +83,7 @@ class BasicVariantGenerator(SearchAlgorithm):
                     output_path,
                     self._parser,
                     evaluated_params=flatten_resolved_vars(resolved_vars),
+                    trial_id=trial_id,
                     experiment_tag=experiment_tag)
 
     def is_finished(self):

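Aside (not part of the diff): the generator now pins each trial's id to a zero-padded counter, so lexicographic order matches creation order and the progress tables sort stably. Sketch of the formatting only:

    counter = 7
    trial_id = "%05d" % counter
    assert trial_id == "00007"
    assert "00007" < "00010"  # string sort now agrees with numeric order
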
@@ -1,17 +1,134 @@
 import collections
-import time
+import subprocess
+import tempfile
 import unittest
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock
 
 from ray.tune.trial import Trial
-from ray.tune.progress_reporter import CLIReporter, _fair_filter_trials
+from ray.tune.progress_reporter import (CLIReporter, _fair_filter_trials,
+                                        trial_progress_str)
 
+EXPECTED_RESULT_1 = """Result logdir: /foo
+Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
++--------------+------------+-------+-----+-----+
+|   Trial name | status     | loc   |   a |   b |
+|--------------+------------+-------+-----+-----|
+|        00001 | PENDING    | here  |   1 |   2 |
+|        00002 | RUNNING    | here  |   2 |   4 |
+|        00000 | TERMINATED | here  |   0 |   0 |
++--------------+------------+-------+-----+-----+
+... 2 more trials not shown (2 RUNNING)"""
+
+EXPECTED_RESULT_2 = """Result logdir: /foo
+Number of trials: 5 (1 PENDING, 3 RUNNING, 1 TERMINATED)
++--------------+------------+-------+-----+-----+
+|   Trial name | status     | loc   |   a |   b |
+|--------------+------------+-------+-----+-----|
+|        00000 | TERMINATED | here  |   0 |   0 |
+|        00001 | PENDING    | here  |   1 |   2 |
+|        00002 | RUNNING    | here  |   2 |   4 |
+|        00003 | RUNNING    | here  |   3 |   6 |
+|        00004 | RUNNING    | here  |   4 |   8 |
++--------------+------------+-------+-----+-----+"""
+
+END_TO_END_COMMAND = """
+import ray
+from ray import tune
+
+def f(config):
+    return {"done": True}
+
+ray.init(num_cpus=1)
+tune.run_experiments({
+    "one": {
+        "run": f,
+        "config": {
+            "a": tune.grid_search(list(range(10))),
+        },
+    },
+    "two": {
+        "run": f,
+        "config": {
+            "b": tune.grid_search(list(range(10))),
+        },
+    },
+    "three": {
+        "run": f,
+        "config": {
+            "c": tune.grid_search(list(range(10))),
+        },
+    },
+}, reuse_actors=True, verbose=1)"""
+
+EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
++--------------+----------+-------+-----+-----+
+| Trial name   | status   | loc   |   a |   b |
+|--------------+----------+-------+-----+-----|
+| f_00001      | PENDING  |       |   1 |     |
+| f_00002      | PENDING  |       |   2 |     |
+| f_00003      | PENDING  |       |   3 |     |
+| f_00004      | PENDING  |       |   4 |     |
+| f_00005      | PENDING  |       |   5 |     |
+| f_00006      | PENDING  |       |   6 |     |
+| f_00007      | PENDING  |       |   7 |     |
+| f_00008      | PENDING  |       |   8 |     |
+| f_00009      | PENDING  |       |   9 |     |
+| f_00010      | PENDING  |       |     |   0 |
+| f_00011      | PENDING  |       |     |   1 |
+| f_00012      | PENDING  |       |     |   2 |
+| f_00013      | PENDING  |       |     |   3 |
+| f_00014      | PENDING  |       |     |   4 |
+| f_00015      | PENDING  |       |     |   5 |
+| f_00016      | PENDING  |       |     |   6 |
+| f_00017      | PENDING  |       |     |   7 |
+| f_00018      | PENDING  |       |     |   8 |
+| f_00019      | PENDING  |       |     |   9 |
+| f_00000      | RUNNING  |       |   0 |     |
++--------------+----------+-------+-----+-----+
+... 10 more trials not shown (10 PENDING)"""
+
+EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
++--------------+------------+-------+-----+-----+-----+
+| Trial name   | status     | loc   |   a |   b |   c |
+|--------------+------------+-------+-----+-----+-----|
+| f_00000      | TERMINATED |       |   0 |     |     |
+| f_00001      | TERMINATED |       |   1 |     |     |
+| f_00002      | TERMINATED |       |   2 |     |     |
+| f_00003      | TERMINATED |       |   3 |     |     |
+| f_00004      | TERMINATED |       |   4 |     |     |
+| f_00005      | TERMINATED |       |   5 |     |     |
+| f_00006      | TERMINATED |       |   6 |     |     |
+| f_00007      | TERMINATED |       |   7 |     |     |
+| f_00008      | TERMINATED |       |   8 |     |     |
+| f_00009      | TERMINATED |       |   9 |     |     |
+| f_00010      | TERMINATED |       |     |   0 |     |
+| f_00011      | TERMINATED |       |     |   1 |     |
+| f_00012      | TERMINATED |       |     |   2 |     |
+| f_00013      | TERMINATED |       |     |   3 |     |
+| f_00014      | TERMINATED |       |     |   4 |     |
+| f_00015      | TERMINATED |       |     |   5 |     |
+| f_00016      | TERMINATED |       |     |   6 |     |
+| f_00017      | TERMINATED |       |     |   7 |     |
+| f_00018      | TERMINATED |       |     |   8 |     |
+| f_00019      | TERMINATED |       |     |   9 |     |
+| f_00020      | TERMINATED |       |     |     |   0 |
+| f_00021      | TERMINATED |       |     |     |   1 |
+| f_00022      | TERMINATED |       |     |     |   2 |
+| f_00023      | TERMINATED |       |     |     |   3 |
+| f_00024      | TERMINATED |       |     |     |   4 |
+| f_00025      | TERMINATED |       |     |     |   5 |
+| f_00026      | TERMINATED |       |     |     |   6 |
+| f_00027      | TERMINATED |       |     |     |   7 |
+| f_00028      | TERMINATED |       |     |     |   8 |
+| f_00029      | TERMINATED |       |     |     |   9 |
++--------------+------------+-------+-----+-----+-----+"""
 
+
 class ProgressReporterTest(unittest.TestCase):
-    def mock_trial(self, status, start_time):
+    def mock_trial(self, status, i):
         mock = MagicMock()
         mock.status = status
-        mock.start_time = start_time
+        mock.trial_id = "%05d" % i
         return mock
 
     def testFairFilterTrials(self):

@@ -25,14 +142,15 @@ class ProgressReporterTest(unittest.TestCase):
         num_trials_under = 2  # num of trials for each underrepresented state
         num_trials_over = 10  # num of trials for each overrepresented state
 
+        i = 0
         for state in states_under:
             for _ in range(num_trials_under):
-                trials_by_state[state].append(
-                    self.mock_trial(state, time.time()))
+                trials_by_state[state].append(self.mock_trial(state, i))
+                i += 1
         for state in states_over:
             for _ in range(num_trials_over):
-                trials_by_state[state].append(
-                    self.mock_trial(state, time.time()))
+                trials_by_state[state].append(self.mock_trial(state, i))
+                i += 1
 
         filtered_trials_by_state = _fair_filter_trials(
             trials_by_state, max_trials=max_trials)

@@ -46,8 +164,7 @@ class ProgressReporterTest(unittest.TestCase):
             self.assertEqual(len(state_trials), expected_num_trials)
             # Make sure trials are sorted newest-first within state.
             for i in range(len(state_trials) - 1):
-                self.assertGreaterEqual(state_trials[i].start_time,
-                                        state_trials[i + 1].start_time)
+                assert state_trials[i].trial_id < state_trials[i + 1].trial_id
 
     def testAddMetricColumn(self):
         """Tests edge cases of add_metric_column."""

@@ -67,3 +184,50 @@ class ProgressReporterTest(unittest.TestCase):
         reporter = CLIReporter()
         reporter.add_metric_column("foo", "bar")
         self.assertIn("foo", reporter._metric_columns)
+
+    def testProgressStr(self):
+        trials = []
+        for i in range(5):
+            t = Mock()
+            if i == 0:
+                t.status = "TERMINATED"
+            elif i == 1:
+                t.status = "PENDING"
+            else:
+                t.status = "RUNNING"
+            t.trial_id = "%05d" % i
+            t.local_dir = "/foo"
+            t.location = "here"
+            t.config = {"a": i, "b": i * 2}
+            t.evaluated_params = t.config
+            t.last_result = {"config": {"a": i, "b": i * 2}}
+            t.__str__ = lambda self: self.trial_id
+            trials.append(t)
+        prog1 = trial_progress_str(trials, ["a", "b"], fmt="psql", max_rows=3)
+        print(prog1)
+        assert prog1 == EXPECTED_RESULT_1
+        prog2 = trial_progress_str(
+            trials, ["a", "b"], fmt="psql", max_rows=None)
+        print(prog2)
+        assert prog2 == EXPECTED_RESULT_2
+
+    def testEndToEndReporting(self):
+        with tempfile.NamedTemporaryFile(suffix=".py") as f:
+            f.write(END_TO_END_COMMAND.encode("utf-8"))
+            f.flush()
+            output = subprocess.check_output(["python3", f.name])
+            output = output.decode("utf-8")
+            try:
+                assert EXPECTED_END_TO_END_START in output
+                assert EXPECTED_END_TO_END_END in output
+            except Exception:
+                print("*** BEGIN OUTPUT ***")
+                print(output)
+                print("*** END OUTPUT ***")
+                raise
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))

@@ -688,7 +688,7 @@ class BOHBSuite(unittest.TestCase):
 class _MockTrial(Trial):
     def __init__(self, i, config):
         self.trainable_name = "trial_{}".format(i)
-        self.trial_id = Trial.generate_id()
+        self.trial_id = str(i)
         self.config = config
         self.experiment_tag = "{}tag".format(i)
         self.trial_name_creator = None

@@ -204,7 +204,7 @@ class Trial:
         self.sync_on_checkpoint = sync_on_checkpoint
         self.checkpoint_manager = CheckpointManager(
             keep_checkpoints_num, checkpoint_score_attr,
-            checkpoint_deleter(str(self), self.runner))
+            checkpoint_deleter(self._trainable_name(), self.runner))
         checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
         self.checkpoint_manager.newest_persistent_checkpoint = checkpoint

@@ -271,7 +271,8 @@ class Trial:
         if not self.result_logger:
             if not self.logdir:
                 self.logdir = Trial.create_logdir(
-                    str(self) + "_" + self.experiment_tag, self.local_dir)
+                    self._trainable_name() + "_" + self.experiment_tag,
+                    self.local_dir)
             else:
                 os.makedirs(self.logdir, exist_ok=True)

@@ -296,7 +297,8 @@ class Trial:
 
     def set_runner(self, runner):
         self.runner = runner
-        self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)
+        self.checkpoint_manager.delete = checkpoint_deleter(
+            self._trainable_name(), runner)
 
     def set_location(self, location):
         """Sets the location of the trial."""

@@ -465,9 +467,12 @@ class Trial:
         return self.saving_to is not None
 
     def __repr__(self):
-        return str(self)
+        return self._trainable_name(include_trial_id=True)
 
     def __str__(self):
+        return self._trainable_name(include_trial_id=True)
+
+    def _trainable_name(self, include_trial_id=False):
         """Combines ``env`` with ``trainable_name`` and ``trial_id``.
 
         Can be overridden with a custom string creator.

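Aside (not part of the diff): the refactor lets callers ask for the trial's name with or without the trial_id suffix. A toy class mirroring just this logic (not the real Trial):

    class T:
        trainable_name = "f"
        trial_id = "00001"

        def _trainable_name(self, include_trial_id=False):
            identifier = self.trainable_name
            if include_trial_id:
                identifier += "_" + self.trial_id
            return identifier.replace("/", "_")

    assert T()._trainable_name() == "f"
    assert T()._trainable_name(include_trial_id=True) == "f_00001"
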
@@ -482,6 +487,7 @@
             identifier = "{}_{}".format(self.trainable_name, env)
         else:
             identifier = self.trainable_name
-        identifier += "_" + self.trial_id
+        if include_trial_id:
+            identifier += "_" + self.trial_id
         return identifier.replace("/", "_")

@@ -62,7 +62,7 @@ def _report_progress(runner, reporter, done=False):
     if reporter.should_report(trials, done=done):
         sched_debug_str = runner.scheduler_alg.debug_string()
         executor_debug_str = runner.trial_executor.debug_string()
-        reporter.report(trials, sched_debug_str, executor_debug_str)
+        reporter.report(trials, done, sched_debug_str, executor_debug_str)
 
 
 def run(run_or_experiment,

@@ -242,10 +242,6 @@ def run(run_or_experiment,
         experiments = run_or_experiment
     else:
         experiments = [run_or_experiment]
-    if len(experiments) > 1:
-        logger.info(
-            "Running multiple concurrent experiments is experimental and may "
-            "not work with certain features.")
 
     for i, exp in enumerate(experiments):
         if not isinstance(exp, Experiment):

@@ -349,9 +345,6 @@ def run(run_or_experiment,
     trials = runner.get_trials()
     if return_trials:
         return trials
-    logger.info("Returning an analysis object by default. You can call "
-                "`analysis.trials` to retrieve a list of trials. "
-                "This message will be removed in future versions of Tune.")
     return ExperimentAnalysis(runner.checkpoint_file, trials=trials)