Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[tune] Fix tests for Function API for better consistency (#4421)

Parent: 80ef8c19aa
Commit: 828dc08ac8

5 changed files with 52 additions and 24 deletions
@@ -10,7 +10,7 @@ from six.moves import queue

 from ray.tune import TuneError
 from ray.tune.trainable import Trainable
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE

 logger = logging.getLogger(__name__)

@@ -249,8 +249,8 @@ def wrap_function(train_func):
             output = train_func(config, reporter)
             # If train_func returns, we need to notify the main event loop
             # of the last result while avoiding double logging. This is done
-            # with the keyword "__duplicate__" -- see tune/trial_runner.py,
-            reporter(done=True, __duplicate__=True)
+            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
+            reporter(**{RESULT_DUPLICATE: True})
             return output

     return WrappedFunc
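For context, a minimal sketch of the reporting flow this hunk changes: when the user's training function returns, the wrapper emits one final result keyed by the RESULT_DUPLICATE sentinel instead of a synthetic done=True record, so the last real result is not logged twice. The Reporter class and run_wrapped helper below are illustrative stand-ins, not Ray Tune APIs.

# Illustrative sketch only: Reporter and run_wrapped are made-up stand-ins;
# RESULT_DUPLICATE mirrors ray.tune.result.RESULT_DUPLICATE.
RESULT_DUPLICATE = "__duplicate__"


class Reporter:
    """Collects results the way Tune's reporter queue would."""

    def __init__(self):
        self.results = []

    def __call__(self, **metrics):
        self.results.append(metrics)


def run_wrapped(train_func, config):
    reporter = Reporter()
    output = train_func(config, reporter)
    # Signal completion with the sentinel key instead of a fake result,
    # so the runner can reuse the last real result (see the trial_runner hunk).
    reporter(**{RESULT_DUPLICATE: True})
    return output, reporter.results


def my_train(config, reporter):
    for i in range(3):
        reporter(mean_accuracy=i * config["lr"], training_iteration=i)


_, results = run_wrapped(my_train, {"lr": 0.1})
assert results[-1] == {RESULT_DUPLICATE: True}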
@@ -51,6 +51,10 @@ TRAINING_ITERATION = "training_iteration"
 # __sphinx_doc_end__
 # yapf: enable

+# __duplicate__ is a magic keyword used internally to
+# avoid double-logging results when using the Function API.
+RESULT_DUPLICATE = "__duplicate__"
+
 # Where Tune writes result files by default
 DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR")
                        or os.path.expanduser("~/ray_results"))
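As an aside, the DEFAULT_RESULTS_DIR fallback shown in the context lines resolves as sketched below; the paths are only examples.

# Mirrors the fallback above: TUNE_RESULT_DIR wins when set, otherwise ~/ray_results.
import os

os.environ.pop("TUNE_RESULT_DIR", None)
print(os.environ.get("TUNE_RESULT_DIR") or os.path.expanduser("~/ray_results"))
# e.g. /home/user/ray_results

os.environ["TUNE_RESULT_DIR"] = "/tmp/my_tune_results"  # hypothetical path
print(os.environ.get("TUNE_RESULT_DIR") or os.path.expanduser("~/ray_results"))
# /tmp/my_tune_results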
@@ -20,8 +20,9 @@ from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.schedulers import TrialScheduler, FIFOScheduler
 from ray.tune.registry import _global_registry, TRAINABLE_CLASS
 from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
-                             EPISODES_TOTAL, TRAINING_ITERATION,
-                             TIMESTEPS_THIS_ITER)
+                             HOSTNAME, NODE_IP, PID, EPISODES_TOTAL,
+                             TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
+                             TIME_THIS_ITER_S, TIME_TOTAL_S)
 from ray.tune.logger import Logger
 from ray.tune.util import pin_in_object_store, get_pinned_object
 from ray.tune.experiment import Experiment
@@ -109,15 +110,28 @@ class TrainableFunctionApiTest(unittest.TestCase):
             raise_on_failed_trial=False,
             scheduler=MockScheduler())

-        # Only compare these result fields. Metadata handling
-        # may be different across APIs.
-        COMPARE_FIELDS = {field for res in results for field in res}
+        # Ignore these fields
+        NO_COMPARE_FIELDS = {
+            HOSTNAME,
+            NODE_IP,
+            PID,
+            TIME_THIS_ITER_S,
+            TIME_TOTAL_S,
+            DONE,  # This is ignored because FunctionAPI has different handling
+            "timestamp",
+            "time_since_restore",
+            "experiment_id",
+            "date",
+        }

         self.assertEqual(len(class_output), len(results))
         self.assertEqual(len(function_output), len(results))

         def as_comparable_result(result):
-            return {k: v for k, v in result.items() if k in COMPARE_FIELDS}
+            return {
+                k: v
+                for k, v in result.items() if k not in NO_COMPARE_FIELDS
+            }

         function_comparable = [
             as_comparable_result(result) for result in function_output
@@ -133,6 +147,11 @@ class TrainableFunctionApiTest(unittest.TestCase):
             as_comparable_result(scheduler_notif[0]),
             as_comparable_result(scheduler_notif[1]))

+        # Make sure the last result is the same.
+        self.assertEqual(
+            as_comparable_result(trials[0].last_result),
+            as_comparable_result(trials[1].last_result))
+
         return function_output, trials

     def testPinObject(self):
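The switch from the COMPARE_FIELDS allow-list to the NO_COMPARE_FIELDS deny-list can be sketched outside the test harness; the result dicts and field names below are illustrative placeholders, not values produced by Tune.

# Standalone sketch of the comparison helper used in the test above.
NO_COMPARE_FIELDS = {"hostname", "node_ip", "pid", "timestamp", "done"}


def as_comparable_result(result):
    # Keep every field except the run-specific metadata expected to differ.
    return {k: v for k, v in result.items() if k not in NO_COMPARE_FIELDS}


class_result = {"mean_accuracy": 0.9, "training_iteration": 3, "pid": 1234, "done": True}
function_result = {"mean_accuracy": 0.9, "training_iteration": 3, "pid": 5678, "done": False}

# The two APIs report the same metrics once the volatile fields are stripped.
assert as_comparable_result(class_result) == as_comparable_result(function_result)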
@@ -583,11 +602,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
         # check if the correct number of results were reported.
         self.assertEqual(len(logs1), len(results1))

-        # We should not double-log
-        trial = [t for t in trials if "function" in str(t)][0]
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[DONE], False)
-
         def check_no_missing(reported_result, result):
             common_results = [reported_result[k] == result[k] for k in result]
             return all(common_results)
@@ -17,9 +17,10 @@ import uuid

 import ray
 from ray.tune.logger import UnifiedLogger
-from ray.tune.result import (
-    DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
-    TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
+from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
+                             TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
+                             EPISODES_THIS_ITER, EPISODES_TOTAL,
+                             TRAINING_ITERATION, RESULT_DUPLICATE)
 from ray.tune.trial import Resources

 logger = logging.getLogger(__name__)
@@ -150,6 +151,10 @@ class Trainable(object):
         result = self._train()
         assert isinstance(result, dict), "_train() needs to return a dict."

+        # We do not modify internal state nor update this result if duplicate.
+        if RESULT_DUPLICATE in result:
+            return result
+
         result = result.copy()

         self._iteration += 1
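A compact, self-contained sketch of the early return this hunk adds: a result carrying the duplicate sentinel is handed back untouched, so iteration counters and other internal state stay put. MiniTrainable below is an illustrative stand-in, not ray.tune.Trainable.

# MiniTrainable is illustrative only; RESULT_DUPLICATE mirrors ray.tune.result.
RESULT_DUPLICATE = "__duplicate__"


class MiniTrainable:
    def __init__(self):
        self._iteration = 0

    def _train(self):
        # Pretend the wrapped function already finished and reported the sentinel.
        return {RESULT_DUPLICATE: True}

    def train(self):
        result = self._train()
        assert isinstance(result, dict), "_train() needs to return a dict."
        # Do not modify internal state for a duplicate/final sentinel result.
        if RESULT_DUPLICATE in result:
            return result
        result = result.copy()
        self._iteration += 1
        return result


t = MiniTrainable()
print(t.train())     # {'__duplicate__': True}
print(t._iteration)  # 0 -- the sentinel did not advance the iteration counter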
@@ -13,7 +13,7 @@ import traceback

 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S
+from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
 from ray.tune.util import warn_if_slow
@@ -407,6 +407,16 @@ class TrialRunner(object):
     def _process_trial(self, trial):
         try:
             result = self.trial_executor.fetch_result(trial)
+
+            is_duplicate = RESULT_DUPLICATE in result
+            # TrialScheduler and SearchAlgorithm still receive a
+            # notification because there may be special handling for
+            # the `on_trial_complete` hook.
+            if is_duplicate:
+                logger.debug("Trial finished without logging 'done'.")
+                result = trial.last_result
+                result.update(done=True)
+
             self._total_time += result[TIME_THIS_ITER_S]

             if trial.should_stop(result):
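To make the new duplicate handling concrete, here is a minimal plain-Python model of the flow above: a sentinel result is swapped for the trial's last real result, marked done, and then excluded from update_last_result. The Trial class and result dicts here are simplified assumptions, not Ray Tune objects.

# Simplified model of the duplicate handling in _process_trial; everything here
# is an illustrative assumption except the RESULT_DUPLICATE sentinel itself.
RESULT_DUPLICATE = "__duplicate__"
TIME_THIS_ITER_S = "time_this_iter_s"


class Trial:
    def __init__(self):
        self.last_result = {"mean_accuracy": 0.9, TIME_THIS_ITER_S: 1.5, "done": False}

    def update_last_result(self, result):
        self.last_result = result


def process_result(trial, result, total_time):
    is_duplicate = RESULT_DUPLICATE in result
    if is_duplicate:
        # Reuse the last real result instead of logging the sentinel,
        # but mark the trial as finished.
        result = trial.last_result
        result.update(done=True)
    total_time += result[TIME_THIS_ITER_S]
    if not is_duplicate:
        trial.update_last_result(result)
    return result, total_time


trial = Trial()
final, total = process_result(trial, {RESULT_DUPLICATE: True}, total_time=0.0)
assert final["done"] and final["mean_accuracy"] == 0.9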
@@ -426,12 +436,7 @@ class TrialRunner(object):
                 self._search_alg.on_trial_complete(
                     trial.trial_id, early_terminated=True)

-            # __duplicate__ is a magic keyword used internally to
-            # avoid double-logging results when using the Function API.
-            # TrialScheduler and SearchAlgorithm still receive a
-            # notification because there may be special handling for
-            # the `on_trial_complete` hook.
-            if "__duplicate__" not in result:
+            if not is_duplicate:
                 trial.update_last_result(
                     result, terminate=(decision == TrialScheduler.STOP))
