mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
[tune] Only add new trial when there is no pending trial (#10979)
This commit is contained in:
parent
7dbd0ff824
commit
5921e87ecd
2 changed files with 42 additions and 1 deletions
|
@ -288,6 +288,44 @@ class TrialRunnerTest(unittest.TestCase):
|
||||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||||
self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
|
self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
|
||||||
|
|
||||||
|
def testQueueFilling(self):
|
||||||
|
ray.init(num_cpus=4)
|
||||||
|
|
||||||
|
def f1(config):
|
||||||
|
for i in range(10):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
tune.register_trainable("f1", f1)
|
||||||
|
|
||||||
|
search_alg = BasicVariantGenerator()
|
||||||
|
search_alg.add_configurations({
|
||||||
|
"foo": {
|
||||||
|
"run": "f1",
|
||||||
|
"num_samples": 100,
|
||||||
|
"config": {
|
||||||
|
"a": tune.sample_from(lambda spec: 5.0 / 7),
|
||||||
|
"b": tune.sample_from(lambda spec: "long" * 40)
|
||||||
|
},
|
||||||
|
"resources_per_trial": {
|
||||||
|
"cpu": 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
runner = TrialRunner(search_alg=search_alg)
|
||||||
|
|
||||||
|
runner.step()
|
||||||
|
runner.step()
|
||||||
|
runner.step()
|
||||||
|
self.assertEqual(len(runner._trials), 3)
|
||||||
|
|
||||||
|
runner.step()
|
||||||
|
self.assertEqual(len(runner._trials), 3)
|
||||||
|
|
||||||
|
self.assertEqual(runner._trials[0].status, Trial.RUNNING)
|
||||||
|
self.assertEqual(runner._trials[1].status, Trial.RUNNING)
|
||||||
|
self.assertEqual(runner._trials[2].status, Trial.PENDING)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -454,7 +454,10 @@ class TrialRunner:
|
||||||
"""
|
"""
|
||||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||||
wait_for_trial = trials_done and not self._search_alg.is_finished()
|
wait_for_trial = trials_done and not self._search_alg.is_finished()
|
||||||
self._update_trial_queue(blocking=wait_for_trial)
|
# Only fetch a new trial if we have no pending trial
|
||||||
|
if not any(trial.status == Trial.PENDING for trial in self._trials) \
|
||||||
|
or wait_for_trial:
|
||||||
|
self._update_trial_queue(blocking=wait_for_trial)
|
||||||
with warn_if_slow("choose_trial_to_run"):
|
with warn_if_slow("choose_trial_to_run"):
|
||||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||||
logger.debug("Running trial {}".format(trial))
|
logger.debug("Running trial {}".format(trial))
|
||||||
|
|
Loading…
Add table
Reference in a new issue