diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index 7dbf02ef8..551f17027 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -10,7 +10,7 @@ from six.moves import queue
 
 from ray.tune import TuneError
 from ray.tune.trainable import Trainable
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
 
 logger = logging.getLogger(__name__)
 
@@ -249,8 +249,8 @@ def wrap_function(train_func):
             output = train_func(config, reporter)
             # If train_func returns, we need to notify the main event loop
             # of the last result while avoiding double logging. This is done
-            # with the keyword "__duplicate__" -- see tune/trial_runner.py,
-            reporter(done=True, __duplicate__=True)
+            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
+            reporter(**{RESULT_DUPLICATE: True})
             return output
 
     return WrappedFunc
diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py
index 47b536186..2978fe540 100644
--- a/python/ray/tune/result.py
+++ b/python/ray/tune/result.py
@@ -51,6 +51,10 @@ TRAINING_ITERATION = "training_iteration"
 # __sphinx_doc_end__
 # yapf: enable
 
+# __duplicate__ is a magic keyword used internally to
+# avoid double-logging results when using the Function API.
+RESULT_DUPLICATE = "__duplicate__"
+
 # Where Tune writes result files by default
 DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR")
                        or os.path.expanduser("~/ray_results"))
diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py
index 7e3ee1071..32e5253f2 100644
--- a/python/ray/tune/tests/test_trial_runner.py
+++ b/python/ray/tune/tests/test_trial_runner.py
@@ -20,8 +20,9 @@ from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.schedulers import TrialScheduler, FIFOScheduler
 from ray.tune.registry import _global_registry, TRAINABLE_CLASS
 from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
-                             EPISODES_TOTAL, TRAINING_ITERATION,
-                             TIMESTEPS_THIS_ITER)
+                             HOSTNAME, NODE_IP, PID, EPISODES_TOTAL,
+                             TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
+                             TIME_THIS_ITER_S, TIME_TOTAL_S)
 from ray.tune.logger import Logger
 from ray.tune.util import pin_in_object_store, get_pinned_object
 from ray.tune.experiment import Experiment
@@ -109,15 +110,28 @@ class TrainableFunctionApiTest(unittest.TestCase):
             raise_on_failed_trial=False,
             scheduler=MockScheduler())
 
-        # Only compare these result fields. Metadata handling
-        # may be different across APIs.
-        COMPARE_FIELDS = {field for res in results for field in res}
+        # Ignore these fields
+        NO_COMPARE_FIELDS = {
+            HOSTNAME,
+            NODE_IP,
+            PID,
+            TIME_THIS_ITER_S,
+            TIME_TOTAL_S,
+            DONE,  # This is ignored because FunctionAPI has different handling
+            "timestamp",
+            "time_since_restore",
+            "experiment_id",
+            "date",
+        }
 
         self.assertEqual(len(class_output), len(results))
         self.assertEqual(len(function_output), len(results))
 
         def as_comparable_result(result):
-            return {k: v for k, v in result.items() if k in COMPARE_FIELDS}
+            return {
+                k: v
+                for k, v in result.items() if k not in NO_COMPARE_FIELDS
+            }
 
         function_comparable = [
             as_comparable_result(result) for result in function_output
@@ -133,6 +147,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
             as_comparable_result(scheduler_notif[0]),
             as_comparable_result(scheduler_notif[1]))
 
+        # Make sure the last result is the same.
+        self.assertEqual(
+            as_comparable_result(trials[0].last_result),
+            as_comparable_result(trials[1].last_result))
+
         return function_output, trials
 
     def testPinObject(self):
@@ -583,11 +602,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
         # check if the correct number of results were reported.
         self.assertEqual(len(logs1), len(results1))
 
-        # We should not double-log
-        trial = [t for t in trials if "function" in str(t)][0]
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[DONE], False)
-
         def check_no_missing(reported_result, result):
             common_results = [reported_result[k] == result[k] for k in result]
             return all(common_results)
diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py
index 7661d01dd..31d766af3 100644
--- a/python/ray/tune/trainable.py
+++ b/python/ray/tune/trainable.py
@@ -17,9 +17,10 @@ import uuid
 
 import ray
 from ray.tune.logger import UnifiedLogger
-from ray.tune.result import (
-    DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
-    TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
+from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
+                             TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
+                             EPISODES_THIS_ITER, EPISODES_TOTAL,
+                             TRAINING_ITERATION, RESULT_DUPLICATE)
 from ray.tune.trial import Resources
 
 logger = logging.getLogger(__name__)
@@ -150,6 +151,10 @@ class Trainable(object):
         result = self._train()
         assert isinstance(result, dict), "_train() needs to return a dict."
 
+        # We do not modify internal state nor update this result if duplicate.
+        if RESULT_DUPLICATE in result:
+            return result
+
         result = result.copy()
 
         self._iteration += 1
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 1dc33c03d..6be2b7375 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -13,7 +13,7 @@ import traceback
 
 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
 from ray.tune.util import warn_if_slow
@@ -407,6 +407,16 @@ class TrialRunner(object):
     def _process_trial(self, trial):
         try:
             result = self.trial_executor.fetch_result(trial)
+
+            is_duplicate = RESULT_DUPLICATE in result
+            # TrialScheduler and SearchAlgorithm still receive a
+            # notification because there may be special handling for
+            # the `on_trial_complete` hook.
+            if is_duplicate:
+                logger.debug("Trial finished without logging 'done'.")
+                result = trial.last_result
+                result.update(done=True)
+
             self._total_time += result[TIME_THIS_ITER_S]
 
             if trial.should_stop(result):
@@ -426,12 +436,7 @@ class TrialRunner(object):
                     self._search_alg.on_trial_complete(
                         trial.trial_id, early_terminated=True)
 
-                # __duplicate__ is a magic keyword used internally to
-                # avoid double-logging results when using the Function API.
-                # TrialScheduler and SearchAlgorithm still receive a
-                # notification because there may be special handling for
-                # the `on_trial_complete` hook.
-                if "__duplicate__" not in result:
+                if not is_duplicate:
                     trial.update_last_result(
                         result, terminate=(decision == TrialScheduler.STOP))
 
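
To make the control flow in this patch easier to follow, here is a minimal,
self-contained sketch of the RESULT_DUPLICATE protocol. The Reporter, Trial,
train_func, and process_result names below are hypothetical stand-ins for
illustration, not Ray's actual classes; only the RESULT_DUPLICATE handling
mirrors what the patch does.

    # Toy sketch (hypothetical names) of the RESULT_DUPLICATE flow.
    RESULT_DUPLICATE = "__duplicate__"


    class Reporter:
        """Collects results reported by a user training function."""

        def __init__(self):
            self.results = []

        def __call__(self, **metrics):
            self.results.append(metrics)


    class Trial:
        """Tracks the last result that was actually logged."""

        def __init__(self):
            self.last_result = {}

        def update_last_result(self, result):
            self.last_result = result


    def train_func(config, reporter):
        # User code reports intermediate results, then simply returns.
        for step in range(3):
            reporter(mean_accuracy=0.1 * step, training_iteration=step)
        # On return, the wrapper sends only the sentinel, mirroring
        # `reporter(**{RESULT_DUPLICATE: True})` in the patch.
        reporter(**{RESULT_DUPLICATE: True})


    def process_result(trial, result):
        # Mirrors TrialRunner._process_trial: a duplicate means "the
        # function returned"; reuse the last logged result, mark it done,
        # and skip update_last_result so nothing is logged twice.
        if RESULT_DUPLICATE in result:
            return dict(trial.last_result, done=True)
        trial.update_last_result(result)
        return result


    if __name__ == "__main__":
        reporter, trial = Reporter(), Trial()
        train_func({}, reporter)
        for result in reporter.results:
            processed = process_result(trial, result)
        # The final processed result is the last real result with
        # done=True, and it was never logged a second time.
        assert processed["training_iteration"] == 2 and processed["done"]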
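One note on the design choice: on the runner side, the sentinel result is replaced by trial.last_result with done=True, so the TrialScheduler and SearchAlgorithm still observe a terminal result even though nothing new is written to the logs. This is also why the new test assertion can require trials[0].last_result and trials[1].last_result to compare equal after filtering out NO_COMPARE_FIELDS: the Function API trial's last_result is now the final real result rather than a bare done-marker.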