[tune] Only add new trial when there is no pending trial (#10979)

This commit is contained in:
Kai Fricke 2020-09-23 19:08:12 +01:00 committed by GitHub
parent 7dbd0ff824
commit 5921e87ecd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 1 deletions

View file

@ -288,6 +288,44 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
def testQueueFilling(self):
ray.init(num_cpus=4)
def f1(config):
for i in range(10):
yield i
tune.register_trainable("f1", f1)
search_alg = BasicVariantGenerator()
search_alg.add_configurations({
"foo": {
"run": "f1",
"num_samples": 100,
"config": {
"a": tune.sample_from(lambda spec: 5.0 / 7),
"b": tune.sample_from(lambda spec: "long" * 40)
},
"resources_per_trial": {
"cpu": 2
}
}
})
runner = TrialRunner(search_alg=search_alg)
runner.step()
runner.step()
runner.step()
self.assertEqual(len(runner._trials), 3)
runner.step()
self.assertEqual(len(runner._trials), 3)
self.assertEqual(runner._trials[0].status, Trial.RUNNING)
self.assertEqual(runner._trials[1].status, Trial.RUNNING)
self.assertEqual(runner._trials[2].status, Trial.PENDING)
if __name__ == "__main__":
import pytest

View file

@ -454,7 +454,10 @@ class TrialRunner:
"""
trials_done = all(trial.is_finished() for trial in self._trials)
wait_for_trial = trials_done and not self._search_alg.is_finished()
self._update_trial_queue(blocking=wait_for_trial)
# Only fetch a new trial if we have no pending trial
if not any(trial.status == Trial.PENDING for trial in self._trials) \
or wait_for_trial:
self._update_trial_queue(blocking=wait_for_trial)
with warn_if_slow("choose_trial_to_run"):
trial = self._scheduler_alg.choose_trial_to_run(self)
logger.debug("Running trial {}".format(trial))