[tune] Tune experiment analysis improvements (#10645)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Kai Fricke authored 2020-09-09 05:00:52 +01:00, committed by GitHub
parent d9c68fca5c
commit d7c7aba99c
14 changed files with 247 additions and 38 deletions


@@ -120,7 +120,7 @@ This example runs a parallel grid search to optimize an example objective function
print("Best config: ", analysis.get_best_config(metric="mean_loss"))
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
df = analysis.results_df
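The new `results_df` property returns a pandas DataFrame indexed by trial ID, so standard pandas operations apply. A minimal sketch, assuming the `mean_loss` metric from the objective above:

.. code-block:: python

    # Illustrative follow-up: show the trials with the lowest reported loss
    print(df.sort_values("mean_loss").head())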
If TensorBoard is installed, automatically visualize all trial results:


@@ -18,7 +18,7 @@ Here are some example operations for obtaining a summary of your experiment:
.. code-block:: python
# Get a dataframe for the last reported results of all of the trials
df = analysis.dataframe()
df = analysis.results_df
# Get a dataframe for the max accuracy seen for each trial
df = analysis.dataframe(metric="mean_accuracy", mode="max")
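As a hedged illustration of the two access patterns above (`mean_accuracy` is the metric from this example; the threshold filter is hypothetical):

.. code-block:: python

    # Last reported result per trial, indexed by trial ID
    last_results = analysis.results_df

    # Best mean_accuracy ever reported by each trial
    best_per_trial = analysis.dataframe(metric="mean_accuracy", mode="max")

    # Hypothetical follow-up: keep only trials that crossed a threshold
    good_trials = best_per_trial[best_per_trial["mean_accuracy"] > 0.9]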


@@ -219,16 +219,24 @@ Analysis
analysis = tune.run(trainable, search_alg=algo, stop={"training_iteration": 20})
# Get the best hyperparameters
best_hyperparameters = analysis.get_best_config()
best_trial = analysis.best_trial # Get best trial
best_config = analysis.best_config # Get best trial's hyperparameters
best_logdir = analysis.best_logdir # Get best trial's logdir
best_checkpoint = analysis.best_checkpoint # Get best trial's best checkpoint
best_result = analysis.best_result # Get best trial's last results
best_result_df = analysis.best_result_df # Get best result as pandas dataframe
This object can also retrieve all training runs as dataframes, allowing you to do ad-hoc data analysis over your results.
.. code-block:: python
# Get a dataframe for the max score seen for each trial
# Get a dataframe with the last results for each trial
df_results = analysis.results_df
# Get a dataframe of results for a specific metric and mode
df = analysis.dataframe(metric="score", mode="max")
What's Next?
-------------


@@ -806,7 +806,7 @@ class TuneCollector(threading.Thread):
# search through all the sub_directories in log directory
analysis = Analysis(str(self._logdir))
df = analysis.dataframe()
df = analysis.dataframe(metric="episode_reward_mean", mode="max")
if len(df) == 0 or "trial_id" not in df.columns:
return


@@ -1,11 +1,17 @@
import json
import logging
import os
from typing import Dict
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.utils import flatten_dict
try:
import pandas as pd
from pandas import DataFrame
except ImportError:
pd = None
DataFrame = None
from ray.tune.error import TuneError
from ray.tune.result import EXPR_PROGRESS_FILE, EXPR_PARAM_FILE,\
@@ -80,6 +86,9 @@ class Analysis:
Returns:
pd.DataFrame: Constructed from a result dict of each trial.
"""
metric = self._validate_metric(metric)
mode = self._validate_mode(mode)
rows = self._retrieve_rows(metric=metric, mode=mode)
all_configs = self.get_all_configs(prefix=True)
for path, config in all_configs.items():
@@ -227,6 +236,9 @@ class Analysis:
mode = self._validate_mode(mode)
checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)
if not checkpoint_paths:
logger.error(f"No checkpoints have been found for trial {trial}.")
return None
if mode == "max":
return max(checkpoint_paths, key=lambda x: x[1])[0]
else:
@@ -316,7 +328,150 @@ class ExperimentAnalysis(Analysis):
os.path.dirname(experiment_checkpoint_path), default_metric,
default_mode)
def get_best_trial(self, metric=None, mode=None, scope="all"):
@property
def best_trial(self) -> Trial:
"""Get the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_trial(metric, mode, scope)` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_trial`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use the "
"`get_best_trial(metric, mode)` method to set the metric "
"and mode explicitly.")
return self.get_best_trial(self.default_metric, self.default_mode)
@property
def best_config(self) -> Dict:
"""Get the config of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_config(metric, mode, scope)` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_config`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use the "
"`get_best_config(metric, mode)` method to set the metric "
"and mode explicitly.")
return self.get_best_config(self.default_metric, self.default_mode)
@property
def best_checkpoint(self) -> Checkpoint:
"""Get the checkpoint of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_checkpoint(trial, metric, mode)` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_checkpoint`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use the "
"`get_best_checkpoint(trial, metric, mode)` method to set the "
"metric and mode explicitly.")
best_trial = self.best_trial
return self.get_best_checkpoint(best_trial, self.default_metric,
self.default_mode)
@property
def best_logdir(self) -> str:
"""Get the logdir of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_logdir(metric, mode)` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_logdir`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use the "
"`get_best_logdir(metric, mode, scope)` method to set the "
"metric and mode explicitly.")
return self.get_best_logdir(self.default_metric, self.default_mode)
@property
def best_dataframe(self) -> DataFrame:
"""Get the full result dataframe of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_logdir(metric, mode)` and look up the dataframe
in the `self.trial_dataframes` dict.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_result`, pass a `metric` and `mode` "
"parameter to `tune.run()`.")
best_logdir = self.best_logdir
return self.trial_dataframes[best_logdir]
@property
def best_result(self) -> Dict:
"""Get the last result of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_trial(metric, mode, scope).last_result` instead.
"""
if not self.default_metric or not self.default_mode:
raise ValueError(
"To fetch the `best_result`, pass a `metric` and `mode` "
"parameter to `tune.run()`. Alternatively, use "
"`get_best_trial(metric, mode).last_result` to set "
"the metric and mode explicitly and fetch the last result.")
return self.best_trial.last_result
@property
def best_result_df(self) -> DataFrame:
"""Get the best result of the experiment as a pandas dataframe.
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
If you didn't pass these parameters, use
`get_best_trial(metric, mode, scope).last_result` instead.
"""
if not pd:
raise ValueError("`best_result_df` requires pandas. Install with "
"`pip install pandas`.")
best_result = flatten_dict(self.best_result, delimiter=".")
return pd.DataFrame.from_records([best_result], index="trial_id")
@property
def results(self) -> Dict[str, Dict]:
"""Get the last result of the all trials of the experiment"""
return {trial.trial_id: trial.last_result for trial in self.trials}
@property
def results_df(self) -> DataFrame:
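"""Get the last result of all trials of the experiment as a pandas dataframe"""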
if not pd:
raise ValueError("`best_result_df` requires pandas. Install with "
"`pip install pandas`.")
return pd.DataFrame.from_records(
[
flatten_dict(trial.last_result, delimiter=".")
for trial in self.trials
],
index="trial_id")
def get_best_trial(self, metric=None, mode=None, scope="last"):
"""Retrieve the best trial object.
Compares all trials' scores on ``metric``.
@@ -380,7 +535,7 @@ class ExperimentAnalysis(Analysis):
"parameter?")
return best_trial
def get_best_config(self, metric=None, mode=None, scope="all"):
def get_best_config(self, metric=None, mode=None, scope="last"):
"""Retrieve the best config corresponding to the trial.
Compares all trials' scores on `metric`.
@@ -407,7 +562,7 @@ class ExperimentAnalysis(Analysis):
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.config if best_trial else None
def get_best_logdir(self, metric=None, mode=None, scope="all"):
def get_best_logdir(self, metric=None, mode=None, scope="last"):
"""Retrieve the logdir corresponding to the best trial.
Compares all trials' scores on `metric`.
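Taken together, a hedged sketch of the defaults-versus-explicit behavior implemented above (the `accuracy` metric name and the trainable are illustrative):

.. code-block:: python

    # Without metric/mode passed to tune.run(), the properties raise ValueError
    analysis = tune.run(trainable)
    try:
        best = analysis.best_trial
    except ValueError:
        best = analysis.get_best_trial(metric="accuracy", mode="max")

    # With defaults set, the properties resolve directly
    analysis = tune.run(trainable, metric="accuracy", mode="max")
    best_checkpoint_path = analysis.best_checkpoint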


@@ -116,7 +116,8 @@ def list_trials(experiment_path,
_check_tabulate()
try:
checkpoints_df = Analysis(experiment_path).dataframe()
checkpoints_df = Analysis(experiment_path).dataframe(
metric="episode_reward_mean", mode="max")
except TuneError:
raise click.ClickException("No trial data found!")


@@ -160,6 +160,6 @@ if __name__ == "__main__":
# demo of the trained Generators
if not args.smoke_test:
logdirs = analysis.dataframe()["logdir"].tolist()
logdirs = analysis.results_df["logdir"].tolist()
model_paths = [os.path.join(d, "exported_models") for d in logdirs]
demo_gan(analysis, model_paths)


@@ -285,8 +285,10 @@ class BayesOptSearch(Searcher):
analysis (ExperimentAnalysis): Optionally, the previous analysis
to integrate.
"""
for (_, report), params in zip(analysis.dataframe().iterrows(),
analysis.get_all_configs().values()):
for (_, report), params in zip(
analysis.dataframe(metric=self._metric,
mode=self._mode).iterrows(),
analysis.get_all_configs().values()):
# We add the obtained results to the
# gaussian process optimizer
self._register_result(params, report)
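For reference, a hedged sketch of warm-starting a new search from a previous experiment with `register_analysis`; the objective, search space, metric name, and checkpoint path are assumptions, not part of this change:

.. code-block:: python

    from ray import tune
    from ray.tune import ExperimentAnalysis
    from ray.tune.suggest.bayesopt import BayesOptSearch

    def objective(config):
        # Illustrative objective
        tune.report(mean_loss=(config["lr"] - 0.01) ** 2)

    # Previous run, restored from its experiment checkpoint (path is illustrative)
    previous = ExperimentAnalysis(
        "~/ray_results/prev_exp/experiment_state-2020-09-01.json")

    algo = BayesOptSearch(
        space={"lr": (1e-4, 1e-1)}, metric="mean_loss", mode="min")
    # Seed the underlying Gaussian process with the earlier results
    algo.register_analysis(previous)

    tune.run(objective, search_alg=algo, num_samples=20)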


@@ -39,5 +39,5 @@ print("Best config: ", analysis.get_best_config(
metric="mean_loss", mode="min"))
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
df = analysis.results_df
# __quick_start_end__


@@ -520,7 +520,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
analysis = tune.run(train, num_samples=10, stop=stopper)
self.assertTrue(
all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe()) <= top)
self.assertTrue(
len(analysis.dataframe(metric="test", mode="max")) <= top)
patience = 5
stopper = EarlyStopping("test", top=top, mode="min", patience=patience)
@@ -528,14 +529,16 @@ class TrainableFunctionApiTest(unittest.TestCase):
analysis = tune.run(train, num_samples=20, stop=stopper)
self.assertTrue(
all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe()) <= patience)
self.assertTrue(
len(analysis.dataframe(metric="test", mode="max")) <= patience)
stopper = EarlyStopping("test", top=top, mode="min")
analysis = tune.run(train, num_samples=10, stop=stopper)
self.assertTrue(
all(t.status == Trial.TERMINATED for t in analysis.trials))
self.assertTrue(len(analysis.dataframe()) <= top)
self.assertTrue(
len(analysis.dataframe(metric="test", mode="max")) <= top)
def testBadStoppingFunction(self):
def train(config, reporter):


@@ -7,7 +7,7 @@ import pandas as pd
from numpy import nan
import ray
from ray.tune import run, sample_from
from ray import tune
from ray.tune.examples.async_hyperband_example import MyTrainableClass
@@ -26,7 +26,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
ray.shutdown()
def run_test_exp(self):
self.ea = run(
self.ea = tune.run(
MyTrainableClass,
name=self.test_name,
local_dir=self.test_dir,
@@ -34,13 +34,14 @@ class ExperimentAnalysisSuite(unittest.TestCase):
checkpoint_freq=1,
num_samples=self.num_samples,
config={
"width": sample_from(
"width": tune.sample_from(
lambda spec: 10 + int(90 * random.random())),
"height": sample_from(lambda spec: int(100 * random.random())),
"height": tune.sample_from(
lambda spec: int(100 * random.random())),
})
def nan_test_exp(self):
nan_ea = run(
nan_ea = tune.run(
lambda x: nan,
name="testing_nan",
local_dir=self.test_dir,
@@ -48,14 +49,15 @@ class ExperimentAnalysisSuite(unittest.TestCase):
checkpoint_freq=1,
num_samples=self.num_samples,
config={
"width": sample_from(
"width": tune.sample_from(
lambda spec: 10 + int(90 * random.random())),
"height": sample_from(lambda spec: int(100 * random.random())),
"height": tune.sample_from(
lambda spec: int(100 * random.random())),
})
return nan_ea
def testDataframe(self):
df = self.ea.dataframe()
df = self.ea.dataframe(self.metric, mode="max")
self.assertTrue(isinstance(df, pd.DataFrame))
self.assertEquals(df.shape[0], self.num_samples)
@@ -143,21 +145,50 @@ class ExperimentAnalysisSuite(unittest.TestCase):
self.assertEqual(df.training_iteration.max(), 1)
def testIgnoreOtherExperiment(self):
analysis = run(
analysis = tune.run(
MyTrainableClass,
name="test_example",
local_dir=self.test_dir,
stop={"training_iteration": 1},
num_samples=1,
config={
"width": sample_from(
"width": tune.sample_from(
lambda spec: 10 + int(90 * random.random())),
"height": sample_from(lambda spec: int(100 * random.random())),
"height": tune.sample_from(
lambda spec: int(100 * random.random())),
})
df = analysis.dataframe()
df = analysis.dataframe(self.metric, mode="max")
self.assertEquals(df.shape[0], 1)
class ExperimentAnalysisPropertySuite(unittest.TestCase):
def testBestProperties(self):
def train(config):
for i in range(10):
with tune.checkpoint_dir(i):
pass
tune.report(res=config["base"] + i)
ea = tune.run(
train,
config={"base": tune.grid_search([100, 200, 300])},
metric="res",
mode="max")
trials = ea.trials
self.assertEquals(ea.best_trial, trials[2])
self.assertEquals(ea.best_config, trials[2].config)
self.assertEquals(ea.best_logdir, trials[2].logdir)
self.assertEquals(ea.best_checkpoint, trials[2].checkpoint.value)
self.assertTrue(
all(ea.best_dataframe["trial_id"] == trials[2].trial_id))
self.assertEquals(ea.results_df.loc[trials[2].trial_id, "res"], 309)
self.assertEquals(ea.best_result["res"], 309)
self.assertEquals(ea.best_result_df.loc[trials[2].trial_id, "res"],
309)
if __name__ == "__main__":
import pytest
import sys


@@ -83,10 +83,10 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
num_samples=1,
config={"id": grid_search(list(range(5)))})
max_all = ea.get_best_trial("score",
"max").metric_analysis["score"]["max"]
min_all = ea.get_best_trial("score",
"min").metric_analysis["score"]["min"]
max_all = ea.get_best_trial("score", "max",
"all").metric_analysis["score"]["max"]
min_all = ea.get_best_trial("score", "min",
"all").metric_analysis["score"]["min"]
max_last = ea.get_best_trial("score", "max",
"last").metric_analysis["score"]["last"]
max_avg = ea.get_best_trial("score", "max",
@@ -149,7 +149,7 @@ class AnalysisSuite(unittest.TestCase):
def testDataframe(self):
analysis = Analysis(self.test_dir)
df = analysis.dataframe()
df = analysis.dataframe(self.metric, mode="max")
self.assertTrue(isinstance(df, pd.DataFrame))
self.assertEqual(df.shape[0], self.num_samples * 2)


@@ -82,15 +82,24 @@ class PopulationBasedTrainingSynchTest(unittest.TestCase):
def testAsynchFail(self):
analysis = self.synchSetup(False)
self.assertTrue(any(analysis.dataframe()["mean_accuracy"] != 33))
self.assertTrue(
any(
analysis.dataframe(metric="mean_accuracy", mode="max")
["mean_accuracy"] != 33))
def testSynchPass(self):
analysis = self.synchSetup(True)
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
self.assertTrue(
all(
analysis.dataframe(metric="mean_accuracy", mode="max")[
"mean_accuracy"] == 33))
def testSynchPassLast(self):
analysis = self.synchSetup(True, param=[30, 20, 10])
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
self.assertTrue(
all(
analysis.dataframe(metric="mean_accuracy", mode="max")[
"mean_accuracy"] == 33))
class PopulationBasedTrainingConfigTest(unittest.TestCase):


@@ -166,7 +166,7 @@ analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
# __run_analysis_begin__
import os
df = analysis.dataframe()
df = analysis.results_df
logdir = analysis.get_best_logdir("mean_accuracy", mode="max")
state_dict = torch.load(os.path.join(logdir, "model.pth"))
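With the properties added in this commit, the same lookup can be written more compactly when `metric="mean_accuracy"` and `mode="max"` were passed to `tune.run()`; a hedged alternative:

.. code-block:: python

    # Assumes tune.run(..., metric="mean_accuracy", mode="max") was used
    logdir = analysis.best_logdir
    state_dict = torch.load(os.path.join(logdir, "model.pth"))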