mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Fix SearchAlg finishing early (#3081)
* Fix trial search alg finishing early
* Fix lint
* fix lint
* nit fix
This commit is contained in:
parent
221d1663c1
commit
eff7cb4458
3 changed files with 32 additions and 2 deletions
|
@@ -223,7 +223,7 @@ For TensorFlow model training, this would look something like this `(full tensor
|
|||
.. code-block:: python
|
||||
|
||||
class MyClass(Trainable):
|
||||
def _setup(self):
|
||||
def _setup(self, config):
|
||||
self.saver = tf.train.Saver()
|
||||
self.sess = ...
|
||||
self.iteration = 0
|
||||
|
|
|
@@ -20,7 +20,8 @@ from ray.tune.experiment import Experiment
|
|||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
SuggestionAlgorithm)
|
||||
from ray.tune.suggest.variant_generator import RecursiveDependencyError
|
||||
|
||||
|
||||
|
@@ -1385,6 +1386,31 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
self.assertTrue(searcher.is_finished())
|
||||
self.assertTrue(runner.is_finished())
|
||||
|
||||
def testSearchAlgFinishes(self):
    """SearchAlg changing state in `next_trials` does not crash.

    The stub search algorithm below marks itself finished and hands back
    no trials the first time it is asked for more. A single runner step
    must complete without raising, after which both the search algorithm
    and the runner are expected to report that they are finished.
    """

    class FinishFastAlg(SuggestionAlgorithm):
        # Flips the finished flag as a side effect of being polled and
        # never produces a trial.
        def next_trials(self):
            self._finished = True
            return []

    ray.init(num_cpus=4, num_gpus=2)

    stop_criteria = {"training_iteration": 1}
    spec = {
        "run": "__fake",
        "num_samples": 3,
        "stop": stop_criteria,
    }

    search_alg = FinishFastAlg()
    search_alg.add_configurations([Experiment.from_json("test", spec)])

    trial_runner = TrialRunner(search_alg=search_alg)
    # This should not fail
    trial_runner.step()

    for component in (search_alg, trial_runner):
        self.assertTrue(component.is_finished())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@@ -114,6 +114,10 @@ class TrialRunner(object):
|
|||
self.trial_executor.start_trial(next_trial)
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events()
|
||||
elif self.is_finished():
|
||||
# We check `is_finished` again here because the experiment
|
||||
# may have finished while getting the next trial.
|
||||
pass
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
|
|
Loading…
Add table
Reference in a new issue