[tune] Fix tests for Function API for better consistency (#4421)

Richard Liaw 2019-03-20 22:31:38 -07:00 committed by GitHub
parent 80ef8c19aa
commit 828dc08ac8
5 changed files with 52 additions and 24 deletions

python/ray/tune/function_runner.py

@@ -10,7 +10,7 @@ from six.moves import queue

 from ray.tune import TuneError
 from ray.tune.trainable import Trainable
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE

 logger = logging.getLogger(__name__)
@@ -249,8 +249,8 @@ def wrap_function(train_func):
         output = train_func(config, reporter)
         # If train_func returns, we need to notify the main event loop
         # of the last result while avoiding double logging. This is done
-        # with the keyword "__duplicate__" -- see tune/trial_runner.py,
-        reporter(done=True, __duplicate__=True)
+        # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
+        reporter(**{RESULT_DUPLICATE: True})
         return output

     return WrappedFunc
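
Note: a minimal sketch of the mechanism this hunk implements, using a toy reporter and a hypothetical run_wrapped helper (neither is Tune's real API). When train_func returns, the wrapper emits exactly one extra result flagged with RESULT_DUPLICATE, so the last real result is never logged twice:

RESULT_DUPLICATE = "__duplicate__"

def train_func(config, reporter):
    # A user-defined function-API trainable: reports, then simply returns.
    for step in range(3):
        reporter(mean_accuracy=0.5 + 0.1 * step)

def run_wrapped(train_func, config):
    results = []
    reporter = lambda **kw: results.append(kw)
    output = train_func(config, reporter)
    # Signal completion without re-logging the last real result.
    reporter(**{RESULT_DUPLICATE: True})
    return results

print(run_wrapped(train_func, {}))
# [{'mean_accuracy': 0.5}, {'mean_accuracy': 0.6}, {'mean_accuracy': 0.7},
#  {'__duplicate__': True}]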

python/ray/tune/result.py

@@ -51,6 +51,10 @@ TRAINING_ITERATION = "training_iteration"
 # __sphinx_doc_end__
 # yapf: enable

+# __duplicate__ is a magic keyword used internally to
+# avoid double-logging results when using the Function API.
+RESULT_DUPLICATE = "__duplicate__"
+
 # Where Tune writes result files by default
 DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR")
                        or os.path.expanduser("~/ray_results"))
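
Note: reporter(**{RESULT_DUPLICATE: True}) and reporter(__duplicate__=True) are equivalent calls; unpacking a dict keyed by the module constant keeps the magic string defined in exactly one place. A tiny self-contained illustration:

RESULT_DUPLICATE = "__duplicate__"

def reporter(**kwargs):
    return kwargs

# Both calls produce the same kwargs; only the first tracks the constant.
assert reporter(**{RESULT_DUPLICATE: True}) == reporter(__duplicate__=True)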

Tune test suite (TrainableFunctionApiTest)

@@ -20,8 +20,9 @@ from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.schedulers import TrialScheduler, FIFOScheduler
 from ray.tune.registry import _global_registry, TRAINABLE_CLASS
 from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
-                             EPISODES_TOTAL, TRAINING_ITERATION,
-                             TIMESTEPS_THIS_ITER)
+                             HOSTNAME, NODE_IP, PID, EPISODES_TOTAL,
+                             TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
+                             TIME_THIS_ITER_S, TIME_TOTAL_S)
 from ray.tune.logger import Logger
 from ray.tune.util import pin_in_object_store, get_pinned_object
 from ray.tune.experiment import Experiment
@@ -109,15 +110,28 @@ class TrainableFunctionApiTest(unittest.TestCase):
             raise_on_failed_trial=False,
             scheduler=MockScheduler())

-        # Only compare these result fields. Metadata handling
-        # may be different across APIs.
-        COMPARE_FIELDS = {field for res in results for field in res}
+        # Ignore these fields
+        NO_COMPARE_FIELDS = {
+            HOSTNAME,
+            NODE_IP,
+            PID,
+            TIME_THIS_ITER_S,
+            TIME_TOTAL_S,
+            DONE,  # This is ignored because FunctionAPI has different handling
+            "timestamp",
+            "time_since_restore",
+            "experiment_id",
+            "date",
+        }

         self.assertEqual(len(class_output), len(results))
         self.assertEqual(len(function_output), len(results))

         def as_comparable_result(result):
-            return {k: v for k, v in result.items() if k in COMPARE_FIELDS}
+            return {
+                k: v
+                for k, v in result.items() if k not in NO_COMPARE_FIELDS
+            }

         function_comparable = [
             as_comparable_result(result) for result in function_output
@@ -133,6 +147,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
             as_comparable_result(scheduler_notif[0]),
             as_comparable_result(scheduler_notif[1]))

+        # Make sure the last result is the same.
+        self.assertEqual(
+            as_comparable_result(trials[0].last_result),
+            as_comparable_result(trials[1].last_result))
+
         return function_output, trials

     def testPinObject(self):
@@ -583,11 +602,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
         # check if the correct number of results were reported.
         self.assertEqual(len(logs1), len(results1))

-        # We should not double-log
-        trial = [t for t in trials if "function" in str(t)][0]
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[DONE], False)
-
         def check_no_missing(reported_result, result):
             common_results = [reported_result[k] == result[k] for k in result]
             return all(common_results)
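
Note: the test's comparison strategy in a nutshell: blacklist volatile metadata rather than whitelist observed fields, so a field missing from one API now fails the comparison instead of being silently skipped. A standalone sketch, with literal strings standing in for the imported constants:

NO_COMPARE_FIELDS = {"hostname", "node_ip", "pid", "time_this_iter_s",
                     "time_total_s", "done", "timestamp",
                     "time_since_restore", "experiment_id", "date"}

def as_comparable_result(result):
    # Keep only fields that must agree between the Class and Function APIs.
    return {k: v for k, v in result.items() if k not in NO_COMPARE_FIELDS}

class_res = {"mean_accuracy": 0.9, "pid": 1234, "done": True}
func_res = {"mean_accuracy": 0.9, "pid": 5678, "done": False}
assert as_comparable_result(class_res) == as_comparable_result(func_res)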

python/ray/tune/trainable.py

@@ -17,9 +17,10 @@ import uuid

 import ray
 from ray.tune.logger import UnifiedLogger
-from ray.tune.result import (
-    DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
-    TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
+from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
+                             TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
+                             EPISODES_THIS_ITER, EPISODES_TOTAL,
+                             TRAINING_ITERATION, RESULT_DUPLICATE)
 from ray.tune.trial import Resources

 logger = logging.getLogger(__name__)
@@ -150,6 +151,10 @@ class Trainable(object):
         result = self._train()
         assert isinstance(result, dict), "_train() needs to return a dict."

+        # We do not modify internal state nor update this result if duplicate.
+        if RESULT_DUPLICATE in result:
+            return result
+
         result = result.copy()

         self._iteration += 1
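
Note: a stand-in class (not ray.tune.trainable.Trainable) showing the early-return contract this hunk adds: a duplicate result passes through train() without advancing iteration state or being copied and decorated:

RESULT_DUPLICATE = "__duplicate__"

class MiniTrainable:
    def __init__(self, results):
        self._results = iter(results)
        self._iteration = 0

    def _train(self):
        return next(self._results)

    def train(self):
        result = self._train()
        if RESULT_DUPLICATE in result:
            return result          # pass through untouched
        result = dict(result)      # never mutate the caller's dict
        self._iteration += 1
        result["training_iteration"] = self._iteration
        return result

t = MiniTrainable([{"acc": 0.9}, {RESULT_DUPLICATE: True}])
assert t.train()["training_iteration"] == 1
assert t.train() == {RESULT_DUPLICATE: True} and t._iteration == 1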

python/ray/tune/trial_runner.py

@@ -13,7 +13,7 @@ import traceback

 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
 from ray.tune.util import warn_if_slow
@@ -407,6 +407,16 @@ class TrialRunner(object):
     def _process_trial(self, trial):
         try:
             result = self.trial_executor.fetch_result(trial)
+
+            is_duplicate = RESULT_DUPLICATE in result
+            # TrialScheduler and SearchAlgorithm still receive a
+            # notification because there may be special handling for
+            # the `on_trial_complete` hook.
+            if is_duplicate:
+                logger.debug("Trial finished without logging 'done'.")
+                result = trial.last_result
+                result.update(done=True)
+
             self._total_time += result[TIME_THIS_ITER_S]

             if trial.should_stop(result):
@@ -426,12 +436,7 @@ class TrialRunner(object):
                     self._search_alg.on_trial_complete(
                         trial.trial_id, early_terminated=True)

-            # __duplicate__ is a magic keyword used internally to
-            # avoid double-logging results when using the Function API.
-            # TrialScheduler and SearchAlgorithm still receive a
-            # notification because there may be special handling for
-            # the `on_trial_complete` hook.
-            if "__duplicate__" not in result:
+            if not is_duplicate:
                 trial.update_last_result(
                     result, terminate=(decision == TrialScheduler.STOP))
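
Note: the consumer side in one simplified function (not the real TrialRunner): on a duplicate, re-issue the trial's last real result marked done and skip update_last_result; the scheduler and search-algorithm hooks would still fire in between:

RESULT_DUPLICATE = "__duplicate__"

def process_result(trial_last_result, result, update_last_result):
    is_duplicate = RESULT_DUPLICATE in result
    if is_duplicate:
        # Reuse the last real result instead of logging the sentinel.
        result = dict(trial_last_result)
        result["done"] = True
    # ... scheduler / search algorithm hooks still fire here ...
    if not is_duplicate:
        update_last_result(result)
    return result

logged = []
final = process_result({"acc": 0.9, "done": False},
                       {RESULT_DUPLICATE: True}, logged.append)
assert final == {"acc": 0.9, "done": True} and logged == []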