Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[tune] Readable trial progress output (#5822)
* Cleaner, tabulated progress output.
* Minor HTML changes, trial ID instead of name
* Revert basic variant changes
* Cleanup, address richard's comments, add progress_reporter.py
* Add tabulate dependency
* Added more info to table, auto-hide columns with no data.
* lint
* Address comments
* Replace experiment tag w/ trial ID
* Fixed tests.
* Fixed test
* Added requirement
* Fix formatting
Parent: 24b79fd0a6
Commit: a851d7eb87
13 changed files with 231 additions and 163 deletions
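The tabulated output added by this commit is built on the `tabulate` package (added as a dependency below). As a rough illustration of the table style the new reporter switches to, the sketch below renders a few hypothetical trial rows with `tablefmt="psql"`; the row values and headers are made up for illustration and are not output from this commit.

# Minimal sketch of the table style produced via `tabulate`; rows/headers are hypothetical.
from tabulate import tabulate

rows = [
    ["train_fn_00001", "00001", "RUNNING", "pid=4021", 3, 0.42],
    ["train_fn_00002", "00002", "TERMINATED", "pid=4022", 10, 0.31],
]
headers = ["Trial name", "ID", "status", "loc", "iter", "loss"]

# tablefmt="psql" draws the boxed, PostgreSQL-style table used for CLI output.
print(tabulate(rows, headers=headers, tablefmt="psql", showindex=False))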
@@ -84,7 +84,7 @@ To run this example, you will need to install the following:

 .. code-block:: bash

-    $ pip install ray torch torchvision filelock
+    $ pip install ray[tune] torch torchvision filelock

 This example runs a parallel grid search to train a Convolutional Neural Network using PyTorch.
@@ -17,6 +17,7 @@ sphinx-click
 sphinx-gallery
 sphinx-jsonschema
 sphinx_rtd_theme
+tabulate
 pandas
 flask
 uvicorn
@@ -29,7 +29,7 @@ Quick Start

 .. code-block:: bash

-    $ pip install ray torch torchvision filelock
+    $ pip install ray[tune] torch torchvision filelock

 This example runs a small grid search to train a CNN using PyTorch and Tune.
@@ -16,5 +16,5 @@ __all__ = [
     "Trainable", "TuneError", "grid_search", "register_env",
     "register_trainable", "run", "run_experiments", "Experiment", "function",
     "sample_from", "track", "uniform", "choice", "randint", "randn",
-    "loguniform", "ExperimentAnalysis", "Analysis"
+    "loguniform", "progress_reporter", "ExperimentAnalysis", "Analysis"
 ]
@@ -11,8 +11,8 @@ from datetime import datetime

 import pandas as pd
 from pandas.api.types import is_string_dtype, is_numeric_dtype
-from ray.tune.result import (TRAINING_ITERATION, MEAN_ACCURACY, MEAN_LOSS,
-                             TIME_TOTAL_S, TRIAL_ID, CONFIG_PREFIX)
+from ray.tune.result import (DEFAULT_EXPERIMENT_INFO_KEYS, DEFAULT_RESULT_KEYS,
+                             CONFIG_PREFIX)
 from ray.tune.analysis import Analysis
 from ray.tune import TuneError

 try:
@@ -26,9 +26,7 @@ EDITOR = os.getenv("EDITOR", "vim")

 TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"

-DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", "experiment_tag",
-                                TRAINING_ITERATION, TIME_TOTAL_S,
-                                MEAN_ACCURACY, MEAN_LOSS, TRIAL_ID)
+DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS

 DEFAULT_PROJECT_INFO_KEYS = (
     "name",
@@ -127,7 +125,7 @@ def list_trials(experiment_path,
         raise click.ClickException("No trial data found!")

     def key_filter(k):
-        return k in DEFAULT_EXPERIMENT_INFO_KEYS or k.startswith(CONFIG_PREFIX)
+        return k in DEFAULT_CLI_KEYS or k.startswith(CONFIG_PREFIX)

     col_keys = [k for k in checkpoints_df.columns if key_filter(k)]
python/ray/tune/progress_reporter.py (new file, 174 lines)

@@ -0,0 +1,174 @@
from __future__ import print_function

import os

from ray.tune.result import (DEFAULT_RESULT_KEYS, CONFIG_PREFIX, PID,
                             EPISODE_REWARD_MEAN, MEAN_ACCURACY, MEAN_LOSS,
                             HOSTNAME, TRAINING_ITERATION, TIME_TOTAL_S)
from ray.tune.util import flatten_dict

try:
    from tabulate import tabulate
except ImportError:
    raise ImportError("ray.tune in ray > 0.7.5 requires 'tabulate'. "
                      "Please re-run 'pip install ray[tune]' or "
                      "'pip install ray[rllib]'.")

DEFAULT_PROGRESS_KEYS = DEFAULT_RESULT_KEYS + (EPISODE_REWARD_MEAN, )
# Truncated representations of column names (to accommodate small screens).
REPORTED_REPRESENTATIONS = {
    EPISODE_REWARD_MEAN: "reward",
    MEAN_ACCURACY: "acc",
    MEAN_LOSS: "loss",
    TIME_TOTAL_S: "total time (s)",
    TRAINING_ITERATION: "iter",
}


class ProgressReporter(object):
    def report(self, trial_runner):
        """Reports progress across all trials of the trial runner.

        Args:
            trial_runner: Trial runner to report on.
        """
        raise NotImplementedError


class JupyterNotebookReporter(ProgressReporter):
    def __init__(self, overwrite):
        """Initializes a new JupyterNotebookReporter.

        Args:
            overwrite (bool): Flag for overwriting the last reported progress.
        """
        self.overwrite = overwrite

    def report(self, trial_runner):
        delim = "<br>"
        messages = [
            "== Status ==",
            memory_debug_str(),
            trial_runner.debug_string(delim=delim),
            trial_progress_str(trial_runner.get_trials(), fmt="html")
        ]
        from IPython.display import clear_output
        from IPython.core.display import display, HTML
        if self.overwrite:
            clear_output(wait=True)
        display(HTML(delim.join(messages) + delim))


class CLIReporter(ProgressReporter):
    def report(self, trial_runner):
        messages = [
            "== Status ==",
            memory_debug_str(),
            trial_runner.debug_string(),
            trial_progress_str(trial_runner.get_trials())
        ]
        print("\n".join(messages) + "\n")


def memory_debug_str():
    try:
        import psutil
        total_gb = psutil.virtual_memory().total / (1024**3)
        used_gb = total_gb - psutil.virtual_memory().available / (1024**3)
        if used_gb > total_gb * 0.9:
            warn = (": ***LOW MEMORY*** less than 10% of the memory on "
                    "this node is available for use. This can cause "
                    "unexpected crashes. Consider "
                    "reducing the memory used by your application "
                    "or reducing the Ray object store size by setting "
                    "`object_store_memory` when calling `ray.init`.")
        else:
            warn = ""
        return "Memory usage on this node: {}/{} GiB{}".format(
            round(used_gb, 1), round(total_gb, 1), warn)
    except ImportError:
        return ("Unknown memory usage. Please run `pip install psutil` "
                "(or ray[debug]) to resolve)")


def trial_progress_str(trials, metrics=None, fmt="psql", max_rows=100):
    """Returns a human readable message for printing to the console.

    This contains a table where each row represents a trial, its parameters
    and the current values of its metrics.

    Args:
        trials (List[Trial]): List of trials to get progress string for.
        metrics (List[str]): Names of metrics to include. Defaults to
            metrics defined in DEFAULT_RESULT_KEYS.
        fmt (str): Output format (see tablefmt in tabulate API).
        max_rows (int): Maximum number of rows in the trial table.
    """
    messages = []
    delim = "<br>" if fmt == "html" else "\n"
    if len(trials) < 1:
        return delim.join(messages)

    num_trials = len(trials)
    trials_per_state = {}
    for t in trials:
        trials_per_state[t.status] = trials_per_state.get(t.status, 0) + 1
    messages.append("Number of trials: {} ({})".format(num_trials,
                                                       trials_per_state))
    for local_dir in sorted({t.local_dir for t in trials}):
        messages.append("Result logdir: {}".format(local_dir))

    if num_trials > max_rows:
        overflow = num_trials - max_rows
        # TODO(ujvl): suggestion for users to view more rows.
        messages.append("Table truncated to {} rows ({} overflow).".format(
            max_rows, overflow))

    # Pre-process trials to figure out what columns to show.
    keys = list(metrics or DEFAULT_PROGRESS_KEYS)
    keys = [k for k in keys if any(t.last_result.get(k) for t in trials)]
    has_failed = any(t.error_file for t in trials)
    # Build rows.
    trial_table = []
    params = list(set().union(*[t.evaluated_params for t in trials]))
    for trial in trials[:min(num_trials, max_rows)]:
        trial_table.append(_get_trial_info(trial, params, keys, has_failed))
    # Parse columns.
    parsed_columns = [REPORTED_REPRESENTATIONS.get(k, k) for k in keys]
    columns = ["Trial name", "ID", "status", "loc"]
    columns += ["failures", "error file"] if has_failed else []
    columns += params + parsed_columns
    messages.append(
        tabulate(trial_table, headers=columns, tablefmt=fmt, showindex=False))
    return delim.join(messages)


def _get_trial_info(trial, parameters, metrics, include_error_data=False):
    """Returns the following information about a trial:

    name | ID | status | loc | # failures | error_file | params... | metrics...

    Args:
        trial (Trial): Trial to get information for.
        parameters (List[str]): Names of trial parameters to include.
        metrics (List[str]): Names of metrics to include.
        include_error_data (bool): Include error file and # of failures.
    """
    result = flatten_dict(trial.last_result)
    trial_info = [str(trial), trial.trial_id, trial.status]
    trial_info += [_location_str(result.get(HOSTNAME), result.get(PID))]
    if include_error_data:
        # TODO(ujvl): File path is too long to display in a single row.
        trial_info += [trial.num_failures, trial.error_file]
    trial_info += [result.get(CONFIG_PREFIX + param) for param in parameters]
    trial_info += [result.get(metric) for metric in metrics]
    return trial_info


def _location_str(hostname, pid):
    if not pid:
        return ""
    elif hostname == os.uname()[1]:
        return "pid={}".format(pid)
    else:
        return "{}:{}".format(hostname, pid)
@@ -18,6 +18,9 @@ HOSTNAME = "hostname"
 # (Auto-filled) The auto-assigned id of the trial.
 TRIAL_ID = "trial_id"

+# (Auto-filled) The auto-assigned id of the trial.
+EXPERIMENT_TAG = "experiment_tag"
+
 # (Auto-filled) The node ip of the machine hosting the training process.
 NODE_IP = "node_ip"
@@ -57,6 +60,11 @@ TRAINING_ITERATION = "training_iteration"
 # __sphinx_doc_end__
 # yapf: enable

+DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
+
+DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, MEAN_ACCURACY,
+                       MEAN_LOSS)
+
 # __duplicate__ is a magic keyword used internally to
 # avoid double-logging results when using the Function API.
 RESULT_DUPLICATE = "__duplicate__"
@@ -21,10 +21,10 @@ from ray.tune import register_env, register_trainable, run_experiments
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.schedulers import TrialScheduler, FIFOScheduler
 from ray.tune.registry import _global_registry, TRAINABLE_CLASS
-from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
-                             HOSTNAME, NODE_IP, PID, EPISODES_TOTAL,
-                             TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
-                             TIME_THIS_ITER_S, TIME_TOTAL_S, TRIAL_ID)
+from ray.tune.result import (
+    DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
+    EPISODES_TOTAL, TRAINING_ITERATION, TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
+    TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
 from ray.tune.logger import Logger
 from ray.tune.util import pin_in_object_store, get_pinned_object, flatten_dict
 from ray.tune.experiment import Experiment
@@ -117,6 +117,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
             HOSTNAME,
             NODE_IP,
             TRIAL_ID,
+            EXPERIMENT_TAG,
             PID,
             TIME_THIS_ITER_S,
             TIME_TOTAL_S,
@@ -1244,7 +1245,7 @@ class VariantGeneratorTest(unittest.TestCase):
         }, "tune-pong")
         trials = list(trials)
         self.assertEqual(len(trials), 2)
-        self.assertEqual(str(trials[0]), "PPO_Pong-v0_0")
+        self.assertTrue("PPO_Pong-v0" in str(trials[0]))
         self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
         self.assertEqual(trials[0].trainable_name, "PPO")
         self.assertEqual(trials[0].experiment_tag, "0")
@@ -696,6 +696,7 @@ class BOHBSuite(unittest.TestCase):
 class _MockTrial(Trial):
     def __init__(self, i, config):
         self.trainable_name = "trial_{}".format(i)
+        self.trial_id = Trial.generate_id()
         self.config = config
         self.experiment_tag = "{}tag".format(i)
         self.trial_name_creator = None
@@ -17,9 +17,7 @@ from ray.tune.logger import pretty_print, UnifiedLogger
 # need because there are cyclic imports that may cause specific names to not
 # have been defined yet. See https://github.com/ray-project/ray/issues/1716.
 import ray.tune.registry
-from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
-                             TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL,
-                             EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY)
+from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
 from ray.utils import binary_to_hex, hex_to_binary
 from ray.tune.resources import Resources, json_to_resources, resources_to_json
@@ -313,54 +311,6 @@ class Trial(object):
         else:
             return False

-    def progress_string(self):
-        """Returns a progress message for printing out to the console."""
-
-        if not self.last_result:
-            return self._status_string()
-
-        def location_string(hostname, pid):
-            if hostname == os.uname()[1]:
-                return "pid={}".format(pid)
-            else:
-                return "{} pid={}".format(hostname, pid)
-
-        pieces = [
-            "{}".format(self._status_string()), "[{}]".format(
-                self.resources.summary_string()), "[{}]".format(
-                    location_string(
-                        self.last_result.get(HOSTNAME),
-                        self.last_result.get(PID))), "{} s".format(
-                            int(self.last_result.get(TIME_TOTAL_S, 0)))
-        ]
-
-        if self.last_result.get(TRAINING_ITERATION) is not None:
-            pieces.append("{} iter".format(
-                self.last_result[TRAINING_ITERATION]))
-
-        if self.last_result.get(TIMESTEPS_TOTAL) is not None:
-            pieces.append("{} ts".format(self.last_result[TIMESTEPS_TOTAL]))
-
-        if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
-            pieces.append("{} rew".format(
-                format(self.last_result[EPISODE_REWARD_MEAN], ".3g")))
-
-        if self.last_result.get(MEAN_LOSS) is not None:
-            pieces.append("{} loss".format(
-                format(self.last_result[MEAN_LOSS], ".3g")))
-
-        if self.last_result.get(MEAN_ACCURACY) is not None:
-            pieces.append("{} acc".format(
-                format(self.last_result[MEAN_ACCURACY], ".3g")))
-
-        return ", ".join(pieces)
-
-    def _status_string(self):
-        return "{}{}".format(
-            self.status, ", {} failures: {}".format(self.num_failures,
-                                                    self.error_file)
-            if self.error_file else "")
-
     def has_checkpoint(self):
         return self._checkpoint.value is not None
@@ -380,6 +330,8 @@ class Trial(object):

     def update_last_result(self, result, terminate=False):
         result.update(trial_id=self.trial_id, done=terminate)
+        if self.experiment_tag:
+            result.update(experiment_tag=self.experiment_tag)
         if self.verbose and (terminate or time.time() - self.last_debug >
                              DEBUG_PRINT_INTERVAL):
             print("Result for {}:".format(self))
@@ -429,7 +381,7 @@ class Trial(object):
         return str(self)

     def __str__(self):
-        """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
+        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

         Can be overriden with a custom string creator.
         """
@@ -443,8 +395,7 @@ class Trial(object):
             identifier = "{}_{}".format(self.trainable_name, env)
         else:
             identifier = self.trainable_name
-        if self.experiment_tag:
-            identifier += "_" + self.experiment_tag
+        identifier += "_" + self.trial_id
         return identifier.replace("/", "_")

     def __getstate__(self):
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function

 import click
 import collections
 from datetime import datetime
 import json
 import logging
@@ -396,88 +395,12 @@ class TrialRunner(object):
         self._scheduler_alg.on_trial_add(self, trial)
         self.trial_executor.try_checkpoint_metadata(trial)

-    def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
-        """Returns a human readable message for printing to the console."""
-        messages = self._debug_messages()
-        states = collections.defaultdict(set)
-        limit_per_state = collections.Counter()
-        for t in self._trials:
-            states[t.status].add(t)
-
-        # Show at most max_debug total, but divide the limit fairly
-        while max_debug > 0:
-            start_num = max_debug
-            for s in states:
-                if limit_per_state[s] >= len(states[s]):
-                    continue
-                max_debug -= 1
-                limit_per_state[s] += 1
-            if max_debug == start_num:
-                break
-
-        for local_dir in sorted({t.local_dir for t in self._trials}):
-            messages.append("Result logdir: {}".format(local_dir))
-
-        num_trials_per_state = {
-            state: len(trials)
-            for state, trials in states.items()
-        }
-        total_number_of_trials = sum(num_trials_per_state.values())
-        if total_number_of_trials > 0:
-            messages.append("Number of trials: {} ({})"
-                            "".format(total_number_of_trials,
-                                      num_trials_per_state))
-
-        for state, trials in sorted(states.items()):
-            limit = limit_per_state[state]
-            messages.append("{} trials:".format(state))
-            sorted_trials = sorted(
-                trials, key=lambda t: _naturalize(t.experiment_tag))
-            if len(trials) > limit:
-                tail_length = limit // 2
-                first = sorted_trials[:tail_length]
-                for t in first:
-                    messages.append(" - {}:\t{}".format(
-                        t, t.progress_string()))
-                messages.append(
-                    " ... {} not shown".format(len(trials) - tail_length * 2))
-                last = sorted_trials[-tail_length:]
-                for t in last:
-                    messages.append(" - {}:\t{}".format(
-                        t, t.progress_string()))
-            else:
-                for t in sorted_trials:
-                    messages.append(" - {}:\t{}".format(
-                        t, t.progress_string()))
-
-        return "\n".join(messages) + "\n"
-
-    def _debug_messages(self):
-        messages = ["== Status =="]
-        messages.append(self._scheduler_alg.debug_string())
-        messages.append(self.trial_executor.debug_string())
-        messages.append(self._memory_debug_string())
-        return messages
-
-    def _memory_debug_string(self):
-        try:
-            import psutil
-            total_gb = psutil.virtual_memory().total / (1024**3)
-            used_gb = total_gb - psutil.virtual_memory().available / (1024**3)
-            if used_gb > total_gb * 0.9:
-                warn = (": ***LOW MEMORY*** less than 10% of the memory on "
-                        "this node is available for use. This can cause "
-                        "unexpected crashes. Consider "
-                        "reducing the memory used by your application "
-                        "or reducing the Ray object store size by setting "
-                        "`object_store_memory` when calling `ray.init`.")
-            else:
-                warn = ""
-            return "Memory usage on this node: {}/{} GiB{}".format(
-                round(used_gb, 1), round(total_gb, 1), warn)
-        except ImportError:
-            return ("Unknown memory usage. Please run `pip install psutil` "
-                    "(or ray[debug]) to resolve)")
+    def debug_string(self, delim="\n"):
+        messages = [
+            self._scheduler_alg.debug_string(),
+            self.trial_executor.debug_string()
+        ]
+        return delim.join(messages)

     def has_resources(self, resources):
         """Returns whether this runner has at least the specified resources."""
@@ -13,6 +13,7 @@ from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.syncer import wait_for_sync
 from ray.tune.trial_runner import TrialRunner
+from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
 from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
                                  FIFOScheduler, MedianStoppingRule)
 from ray.tune.web_server import TuneServer
@@ -26,6 +27,12 @@ _SCHEDULERS = {
     "AsyncHyperBand": AsyncHyperBandScheduler,
 }

+try:
+    class_name = get_ipython().__class__.__name__
+    IS_NOTEBOOK = True if "Terminal" not in class_name else False
+except NameError:
+    IS_NOTEBOOK = False
+

 def _make_scheduler(args):
     if args.scheduler in _SCHEDULERS:
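As an aside, the notebook check above relies on `get_ipython()` only being defined inside IPython/Jupyter sessions. A standalone sketch of the same idea (an assumption-level illustration, not code from this commit):

# Hypothetical illustration of the detection used above: get_ipython() exists
# only under IPython/Jupyter, and its class name distinguishes a terminal
# IPython shell from a notebook kernel.
def looks_like_notebook():
    try:
        shell = get_ipython().__class__.__name__  # NameError in plain Python
    except NameError:
        return False
    return "Terminal" not in shell

print(looks_like_notebook())  # False when run as a plain script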
@@ -181,13 +188,13 @@ def run(run_or_experiment,
         >>> tune.run(mytrainable, num_samples=5, reuse_actors=True)

         >>> tune.run(
-                "PG",
-                num_samples=5,
-                config={
-                    "env": "CartPole-v0",
-                    "lr": tune.sample_from(lambda _: np.random.rand())
-                }
-            )
+        >>>     "PG",
+        >>>     num_samples=5,
+        >>>     config={
+        >>>         "env": "CartPole-v0",
+        >>>         "lr": tune.sample_from(lambda _: np.random.rand())
+        >>>     }
+        >>> )
     """
     trial_executor = trial_executor or RayTrialExecutor(
         queue_trials=queue_trials,
@@ -238,15 +245,17 @@ def run(run_or_experiment,

     runner.add_experiment(experiment)

-    if verbose:
-        print(runner.debug_string(max_debug=99999))
+    if IS_NOTEBOOK:
+        reporter = JupyterNotebookReporter(overwrite=verbose < 2)
+    else:
+        reporter = CLIReporter()

     last_debug = 0
     while not runner.is_finished():
         runner.step()
         if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
             if verbose:
-                print(runner.debug_string())
+                reporter.report(runner)
             last_debug = time.time()

     try:
@@ -255,7 +264,7 @@ def run(run_or_experiment,
         logger.exception("Trial Runner checkpointing failed.")

     if verbose:
-        print(runner.debug_string(max_debug=99999))
+        reporter.report(runner)

     wait_for_sync()
@@ -73,11 +73,13 @@ if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":

 extras = {
     "rllib": [
-        "pyyaml", "gym[atari]", "opencv-python-headless", "lz4", "scipy"
+        "pyyaml", "gym[atari]", "opencv-python-headless", "lz4", "scipy",
+        "tabulate"
     ],
     "debug": ["psutil", "setproctitle", "py-spy >= 0.2.0"],
     "dashboard": ["aiohttp", "psutil", "setproctitle"],
     "serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas"],
+    "tune": ["tabulate"],
 }