From 5921e87ecd4e359fad60dab55f45855456d591e5 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 23 Sep 2020 19:08:12 +0100 Subject: [PATCH] [tune] Only add new trial when there is no pending trial (#10979) --- python/ray/tune/tests/test_trial_runner.py | 38 ++++++++++++++++++++++ python/ray/tune/trial_runner.py | 5 ++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 1dbac72ae..cf1232f7b 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -288,6 +288,44 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(runner.trial_executor._committed_resources.cpu, 2) + def testQueueFilling(self): + ray.init(num_cpus=4) + + def f1(config): + for i in range(10): + yield i + + tune.register_trainable("f1", f1) + + search_alg = BasicVariantGenerator() + search_alg.add_configurations({ + "foo": { + "run": "f1", + "num_samples": 100, + "config": { + "a": tune.sample_from(lambda spec: 5.0 / 7), + "b": tune.sample_from(lambda spec: "long" * 40) + }, + "resources_per_trial": { + "cpu": 2 + } + } + }) + + runner = TrialRunner(search_alg=search_alg) + + runner.step() + runner.step() + runner.step() + self.assertEqual(len(runner._trials), 3) + + runner.step() + self.assertEqual(len(runner._trials), 3) + + self.assertEqual(runner._trials[0].status, Trial.RUNNING) + self.assertEqual(runner._trials[1].status, Trial.RUNNING) + self.assertEqual(runner._trials[2].status, Trial.PENDING) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 0bd134718..d13073ad4 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -454,7 +454,10 @@ class TrialRunner: """ trials_done = all(trial.is_finished() for trial in self._trials) wait_for_trial = trials_done and not self._search_alg.is_finished() - self._update_trial_queue(blocking=wait_for_trial) + # Only fetch a new trial if we have no pending trial + if not any(trial.status == Trial.PENDING for trial in self._trials) \ + or wait_for_trial: + self._update_trial_queue(blocking=wait_for_trial) with warn_if_slow("choose_trial_to_run"): trial = self._scheduler_alg.choose_trial_to_run(self) logger.debug("Running trial {}".format(trial))