[tune] Fix SearchAlg finishing early (#3081)

* Fix trial search alg finishing early

* Fix lint

* fix lint

* nit fix
This commit is contained in:
Richard Liaw 2018-10-22 12:17:13 -07:00 committed by Eric Liang
parent 221d1663c1
commit eff7cb4458
3 changed files with 32 additions and 2 deletions

View file

@ -223,7 +223,7 @@ For TensorFlow model training, this would look something like this `(full tensor
.. code-block:: python
class MyClass(Trainable):
def _setup(self):
def _setup(self, config):
self.saver = tf.train.Saver()
self.sess = ...
self.iteration = 0

View file

@ -20,7 +20,8 @@ from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import grid_search, BasicVariantGenerator
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SuggestionAlgorithm)
from ray.tune.suggest.variant_generator import RecursiveDependencyError
@ -1385,6 +1386,31 @@ class TrialRunnerTest(unittest.TestCase):
self.assertTrue(searcher.is_finished())
self.assertTrue(runner.is_finished())
def testSearchAlgFinishes(self):
    """Runner survives a search algorithm that finishes in `next_trials`.

    The stub algorithm sets its `_finished` flag and returns no trials the
    first time `next_trials` is called. A single `step()` of the runner
    must not raise, and afterwards both the algorithm and the runner are
    expected to report that they are finished.
    """

    class FinishFastAlg(SuggestionAlgorithm):
        """Stub that flags itself finished as soon as trials are requested."""

        def next_trials(self):
            self._finished = True
            return []

    ray.init(num_cpus=4, num_gpus=2)

    stop_criteria = {"training_iteration": 1}
    spec = {"run": "__fake", "num_samples": 3, "stop": stop_criteria}

    alg = FinishFastAlg()
    alg.add_configurations([Experiment.from_json("test", spec)])

    trial_runner = TrialRunner(search_alg=alg)
    # The step itself is the check: it must complete without raising even
    # though the algorithm changes state while trials are being fetched.
    trial_runner.step()

    # Same order as before: the search algorithm first, then the runner.
    for component in (alg, trial_runner):
        self.assertTrue(component.is_finished())
# Script entry point: when this file is executed directly, hand control to
# unittest's command-line runner with verbose (level 2) output.
if __name__ == "__main__":
unittest.main(verbosity=2)

View file

@ -114,6 +114,10 @@ class TrialRunner(object):
self.trial_executor.start_trial(next_trial)
elif self.trial_executor.get_running_trials():
self._process_events()
elif self.is_finished():
# We check `is_finished` again here because the experiment
# may have finished while getting the next trial.
pass
else:
for trial in self._trials:
if trial.status == Trial.PENDING: