From 898243d538c38c41c34eb28e930585585408ef8a Mon Sep 17 00:00:00 2001
From: Kai Fricke
Date: Wed, 24 Mar 2021 02:19:41 +0100
Subject: [PATCH] [tune] Limit maximum number of pending trials. Add convergence test. (#14835)

---
 doc/source/tune/user-guide.rst                 |   5 +-
 python/ray/tune/BUILD                          |   6 +-
 python/ray/tune/ray_trial_executor.py          |  12 +-
 python/ray/tune/stopper.py                     |   1 +
 python/ray/tune/suggest/ax.py                  |   5 +-
 python/ray/tune/suggest/dragonfly.py           |  12 --
 python/ray/tune/suggest/nevergrad.py           |  10 +-
 python/ray/tune/tests/test_cluster.py          |   4 +-
 python/ray/tune/tests/test_convergence.py      | 156 ++++++++++++++++++
 .../test_convergence_gaussian_process.py       |  56 -------
 python/ray/tune/tests/test_trial_runner.py     |   5 +-
 python/ray/tune/tests/test_trial_runner_3.py   |  17 +-
 python/ray/tune/tests/test_trial_runner_pg.py  |   8 +-
 python/ray/tune/tests/test_tune_restore.py     |  17 +-
 python/ray/tune/trial_executor.py              |   4 +
 python/ray/tune/trial_runner.py                |  19 ++-
 python/ray/tune/utils/placement_groups.py      |  38 ++++-
 17 files changed, 257 insertions(+), 118 deletions(-)
 create mode 100644 python/ray/tune/tests/test_convergence.py
 delete mode 100644 python/ray/tune/tests/test_convergence_gaussian_process.py

diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst
index 6a78a264e..743f69219 100644
--- a/doc/source/tune/user-guide.rst
+++ b/doc/source/tune/user-guide.rst
@@ -759,7 +759,7 @@ These are the environment variables Ray Tune currently considers:
 * **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those
   with the parameter values in them)
 * **TUNE_MAX_PENDING_TRIALS_PG**: Maximum number of pending trials when placement groups are used. Defaults
-  to ``1000``.
+  to ``auto``, which resolves to ``1000`` for random/grid search and ``1`` for all other search algorithms.
 * **TUNE_PLACEMENT_GROUP_AUTO_DISABLED**: Ray Tune automatically uses placement groups
   instead of the legacy resource requests. Setting this to 1 enables legacy placement.
 * **TUNE_PLACEMENT_GROUP_CLEANUP_DISABLED**: Ray Tune cleans up existing placement groups
@@ -767,6 +767,9 @@ These are the environment variables Ray Tune currently considers:
   that scheduled placement groups are removed when multiple calls to ``tune.run()`` are
   done in the same script. You might want to disable this if you run multiple Tune runs
   in parallel from different scripts. Set to 1 to disable.
+* **TUNE_PLACEMENT_GROUP_PREFIX**: Prefix for placement groups created by Ray Tune. This prefix is used,
+  for example, to identify placement groups that should be cleaned up when a tuning run starts or stops.
+  It is initialized to a unique name at the start of the first run.
 * **TUNE_PLACEMENT_GROUP_WAIT_S**: Default time the trial executor waits for placement groups to be
   placed before continuing the tuning loop. Setting this to a float will block for that many seconds.
   This is mostly used for testing purposes.
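As a minimal usage sketch of the ``auto`` default described above (illustrative only; the trainable ``my_objective`` and the value ``16`` are placeholder assumptions): the variable must be set before ``tune.run()`` constructs the ``TrialRunner``, which reads it once via ``os.getenv``:

    import os

    # Set before tune.run() creates the TrialRunner, which reads the variable
    # once at startup; "16" is an arbitrary example value instead of "auto".
    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "16"

    from ray import tune

    def my_objective(config):  # placeholder trainable for illustration
        tune.report(loss=config["x"] ** 2)

    tune.run(
        my_objective,
        config={"x": tune.uniform(0, 20)},
        num_samples=8,
    )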
Defaults diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 327a29524..357bc219d 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -50,9 +50,9 @@ py_test( ) py_test( - name = "test_convergence_gaussian_process", - size = "small", - srcs = ["tests/test_convergence_gaussian_process.py"], + name = "test_convergence", + size = "medium", + srcs = ["tests/test_convergence.py"], deps = [":tune_lib"], tags = ["exclusive"], ) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index ce5a72727..08c49941f 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -19,7 +19,8 @@ from ray.tune.error import AbortTrialExecution, TuneError from ray.tune.logger import NoopLogger from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE from ray.tune.resources import Resources -from ray.tune.utils.placement_groups import PlacementGroupManager +from ray.tune.utils.placement_groups import PlacementGroupManager, \ + get_tune_pg_prefix from ray.tune.utils.trainable import TrainableUtil from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo from ray.tune.trial_executor import TrialExecutor @@ -160,7 +161,7 @@ class RayTrialExecutor(TrialExecutor): self._avail_resources = Resources(cpu=0, gpu=0) self._committed_resources = Resources(cpu=0, gpu=0) - self._pg_manager = PlacementGroupManager() + self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix()) self._staged_trials = set() self._just_staged_trials = set() self._trial_just_finished = False @@ -197,6 +198,9 @@ class RayTrialExecutor(TrialExecutor): """Returns True if trials have recently been staged.""" return self._pg_manager.in_staging_grace_period() + def set_max_pending_trials(self, max_pending: int): + self._pg_manager.set_max_staging(max_pending) + def stage_and_update_status(self, trials: List[Trial]): """Check and update statuses of scheduled placement groups. 
@@ -783,7 +787,9 @@ class RayTrialExecutor(TrialExecutor): """ if trial.uses_placement_groups: - return trial in self._staged_trials or self._pg_manager.can_stage() + return trial in self._staged_trials or self._pg_manager.can_stage( + ) or self._pg_manager.has_ready( + trial, update=True) return self.has_resources(trial.resources) diff --git a/python/ray/tune/stopper.py b/python/ray/tune/stopper.py index bc0940bd7..2279b4ed3 100644 --- a/python/ray/tune/stopper.py +++ b/python/ray/tune/stopper.py @@ -213,6 +213,7 @@ class ExperimentPlateauStopper(Stopper): return self.has_plateaued() and self._iterations >= self._patience +# Deprecate: 1.4 class EarlyStopping(ExperimentPlateauStopper): def __init__(self, *args, **kwargs): warnings.warn( diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index 85aa79f30..9b7867fbc 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -17,10 +17,11 @@ except ImportError: # This exception only exists in newer Ax releases for python 3.7 try: + from ax.exceptions.core import DataRequiredError from ax.exceptions.generation_strategy import \ MaxParallelismReachedException except ImportError: - MaxParallelismReachedException = Exception + MaxParallelismReachedException = DataRequiredError = Exception import logging @@ -262,7 +263,7 @@ class AxSearch(Searcher): else: try: parameters, trial_index = self._ax.get_next_trial() - except MaxParallelismReachedException: + except (MaxParallelismReachedException, DataRequiredError): return None self._live_trial_mapping[trial_id] = trial_index diff --git a/python/ray/tune/suggest/dragonfly.py b/python/ray/tune/suggest/dragonfly.py index a46cea143..0fdbcdafa 100644 --- a/python/ray/tune/suggest/dragonfly.py +++ b/python/ray/tune/suggest/dragonfly.py @@ -4,7 +4,6 @@ from __future__ import print_function import inspect import logging -import pickle from typing import Dict, List, Optional, Union from ray.tune.result import DEFAULT_METRIC @@ -331,17 +330,6 @@ class DragonflySearch(Searcher): self._opt.tell([(trial_info, self._metric_op * result[self._metric])]) - def save(self, checkpoint_path: str): - trials_object = (self._initial_points, self._opt) - with open(checkpoint_path, "wb") as outputFile: - pickle.dump(trials_object, outputFile) - - def restore(self, checkpoint_dir: str): - with open(checkpoint_dir, "rb") as inputFile: - trials_object = pickle.load(inputFile) - self._initial_points = trials_object[0] - self._opt = trials_object[1] - @staticmethod def convert_search_space(spec: Dict) -> List[Dict]: resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) diff --git a/python/ray/tune/suggest/nevergrad.py b/python/ray/tune/suggest/nevergrad.py index 879ef1a9a..c00d2894a 100644 --- a/python/ray/tune/suggest/nevergrad.py +++ b/python/ray/tune/suggest/nevergrad.py @@ -1,6 +1,7 @@ +import inspect import logging import pickle -from typing import Dict, Optional, Union, List, Sequence +from typing import Dict, Optional, Type, Union, List, Sequence from ray.tune.result import DEFAULT_METRIC from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ @@ -108,7 +109,8 @@ class NevergradSearch(Searcher): """ def __init__(self, - optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None, + optimizer: Union[None, Optimizer, Type[Optimizer], + ConfiguredOptimizer] = None, space: Optional[Union[Dict, Parameter]] = None, metric: Optional[str] = None, mode: Optional[str] = None, @@ -154,7 +156,9 @@ class NevergradSearch(Searcher): "parameter.") 
self._parameters = space self._nevergrad_opt = optimizer - elif isinstance(optimizer, ConfiguredOptimizer): + elif (inspect.isclass(optimizer) + and issubclass(optimizer, Optimizer)) or isinstance( + optimizer, ConfiguredOptimizer): self._opt_factory = optimizer self._parameters = None self._space = space diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 484a3ac98..4ca751b72 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -364,11 +364,11 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): @pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"]) @pytest.mark.parametrize("with_pg", [True, False]) -@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) -@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id, with_pg): """Removing a node in full cluster causes Trial to be requeued.""" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + if not with_pg: os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1" diff --git a/python/ray/tune/tests/test_convergence.py b/python/ray/tune/tests/test_convergence.py new file mode 100644 index 000000000..6b29af9f6 --- /dev/null +++ b/python/ray/tune/tests/test_convergence.py @@ -0,0 +1,156 @@ +import math +import numpy as np + +import ray +from ray import tune +from ray.tune.stopper import ExperimentPlateauStopper +from ray.tune.suggest import ConcurrencyLimiter +import unittest + + +def loss(config, reporter): + x = config.get("x") + reporter(loss=x**2) # A simple function to optimize + + +class ConvergenceTest(unittest.TestCase): + """Test convergence in gaussian process.""" + + @classmethod + def setUpClass(cls) -> None: + ray.init(local_mode=False, num_cpus=1, num_gpus=0) + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def _testConvergence(self, searcher, top=3, patience=20): + # This is the space of parameters to explore + space = {"x": tune.uniform(0, 20)} + + resources_per_trial = {"cpu": 1, "gpu": 0} + + analysis = tune.run( + loss, + metric="loss", + mode="min", + stop=ExperimentPlateauStopper( + metric="loss", top=top, patience=patience), + search_alg=searcher, + config=space, + num_samples=100, # Number of iterations + resources_per_trial=resources_per_trial, + raise_on_failed_trial=False, + fail_fast=True, + reuse_actors=True, + verbose=1) + print(f"Num trials: {len(analysis.trials)}. 
" + f"Best result: {analysis.best_config['x']}") + + return analysis + + def testConvergenceAx(self): + from ray.tune.suggest.ax import AxSearch + + np.random.seed(0) + + searcher = AxSearch() + analysis = self._testConvergence(searcher, patience=10) + + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5) + + def testConvergenceBayesOpt(self): + from ray.tune.suggest.bayesopt import BayesOptSearch + + np.random.seed(0) + + # Following bayesian optimization + searcher = BayesOptSearch(random_search_steps=10) + searcher.repeat_float_precision = 5 + searcher = ConcurrencyLimiter(searcher, 1) + + analysis = self._testConvergence(searcher, patience=100) + + assert len(analysis.trials) < 50 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5) + + def testConvergenceDragonfly(self): + from ray.tune.suggest.dragonfly import DragonflySearch + + np.random.seed(0) + searcher = DragonflySearch(domain="euclidean", optimizer="bandit") + analysis = self._testConvergence(searcher) + + assert len(analysis.trials) < 100 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5) + + def testConvergenceHEBO(self): + from ray.tune.suggest.hebo import HEBOSearch + + np.random.seed(0) + searcher = HEBOSearch() + analysis = self._testConvergence(searcher) + + assert len(analysis.trials) < 100 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2) + + def testConvergenceHyperopt(self): + from ray.tune.suggest.hyperopt import HyperOptSearch + + np.random.seed(0) + searcher = HyperOptSearch(random_state_seed=1234) + analysis = self._testConvergence(searcher, patience=50, top=5) + + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2) + + def testConvergenceNevergrad(self): + from ray.tune.suggest.nevergrad import NevergradSearch + import nevergrad as ng + + np.random.seed(0) + searcher = NevergradSearch(optimizer=ng.optimizers.PSO) + analysis = self._testConvergence(searcher, patience=50, top=5) + + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3) + + def testConvergenceOptuna(self): + from ray.tune.suggest.optuna import OptunaSearch + + np.random.seed(1) + searcher = OptunaSearch() + analysis = self._testConvergence( + searcher, + top=5, + ) + + # This assertion is much weaker than in the BO case, but TPE + # don't converge too close. 
It is still unlikely to get to this + # tolerance with random search (~0.01% chance) + assert len(analysis.trials) < 100 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2) + + def testConvergenceSkOpt(self): + from ray.tune.suggest.skopt import SkOptSearch + + np.random.seed(0) + searcher = SkOptSearch() + analysis = self._testConvergence(searcher) + + assert len(analysis.trials) < 100 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3) + + def testConvergenceZoopt(self): + from ray.tune.suggest.zoopt import ZOOptSearch + + np.random.seed(0) + searcher = ZOOptSearch(budget=100) + analysis = self._testConvergence(searcher) + + assert len(analysis.trials) < 100 + assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/tests/test_convergence_gaussian_process.py b/python/ray/tune/tests/test_convergence_gaussian_process.py deleted file mode 100644 index 811720014..000000000 --- a/python/ray/tune/tests/test_convergence_gaussian_process.py +++ /dev/null @@ -1,56 +0,0 @@ -import math -import numpy as np - -import ray -from ray import tune -from ray.tune.suggest.bayesopt import BayesOptSearch -from ray.tune.suggest import ConcurrencyLimiter -import unittest - - -def loss(config, reporter): - x = config.get("x") - reporter(loss=x**2) # A simple function to optimize - - -class ConvergenceTest(unittest.TestCase): - """Test convergence in gaussian process.""" - - def shutDown(self): - ray.shutdown() - - def test_convergence_gaussian_process(self): - np.random.seed(0) - ray.init(local_mode=True, num_cpus=1, num_gpus=1) - - # This is the space of parameters to explore - space = {"x": tune.uniform(0, 20)} - - resources_per_trial = {"cpu": 1, "gpu": 0} - - # Following bayesian optimization - gp = BayesOptSearch(random_search_steps=10) - gp.repeat_float_precision = 5 - gp = ConcurrencyLimiter(gp, 1) - - # Execution of the BO. - analysis = tune.run( - loss, - metric="loss", - mode="min", - # stop=EarlyStopping("loss", mode="min", patience=5), - search_alg=gp, - config=space, - num_samples=100, # Number of iterations - resources_per_trial=resources_per_trial, - raise_on_failed_trial=False, - fail_fast=True, - verbose=1) - assert len(analysis.trials) in {13, 40, 43} # it is 43 on the cluster? 
- assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index f2cd6c62a..5691e02bf 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1,7 +1,6 @@ import os import sys import unittest -from unittest.mock import patch import ray from ray.rllib import _register_all @@ -294,9 +293,9 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(runner.trial_executor._committed_resources.cpu, 2) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def testQueueFilling(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + ray.init(num_cpus=4) def f1(config): diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index 9306785ad..b6935e637 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -31,6 +31,8 @@ class TrialRunnerTest3(unittest.TestCase): # Block for results even when placement groups are pending os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default + self.tmpdir = tempfile.mkdtemp() def tearDown(self): @@ -114,11 +116,10 @@ class TrialRunnerTest3(unittest.TestCase): self.assertEqual(trials[2].status, Trial.RUNNING) self.assertEqual(trials[-1].status, Trial.TERMINATED) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def testSearchAlgNotification(self): """Checks notification of trial to the Search Algorithm.""" os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" # Don't finish early + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" ray.init(num_cpus=4, num_gpus=2) experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}} @@ -235,10 +236,9 @@ class TrialRunnerTest3(unittest.TestCase): self.assertTrue(search_alg.is_finished()) self.assertTrue(runner.is_finished()) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def testSearchAlgFinishes(self): """Empty SearchAlg changing state in `next_trials` does not crash.""" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" class FinishFastAlg(_MockSuggestionAlgorithm): _index = 0 @@ -295,7 +295,7 @@ class TrialRunnerTest3(unittest.TestCase): def __init__(self, index): self.index = index self.returned_result = [] - super().__init__(metric="result", mode="max") + super().__init__(metric="episode_reward_mean", mode="max") def suggest(self, trial_id): self.index += 1 @@ -506,10 +506,10 @@ class TrialRunnerTest3(unittest.TestCase): runner2.step() # Process save self.assertRaises(TuneError, runner2.step) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def testTrialNoSave(self): """Check that non-checkpointing trials are not saved.""" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + ray.init(num_cpus=3) runner = TrialRunner( @@ -635,10 +635,9 @@ class TrialRunnerTest3(unittest.TestCase): self.assertEqual(trial.last_result[TRAINING_ITERATION], 9) self.assertEqual(num_checkpoints(trial), 3) - 
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) def testUserCheckpoint(self): os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" # Don't finish early + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" ray.init(num_cpus=3) runner = TrialRunner( diff --git a/python/ray/tune/tests/test_trial_runner_pg.py b/python/ray/tune/tests/test_trial_runner_pg.py index 971a8c874..25b999640 100644 --- a/python/ray/tune/tests/test_trial_runner_pg.py +++ b/python/ray/tune/tests/test_trial_runner_pg.py @@ -3,7 +3,6 @@ import os import time import numpy as np import unittest -from unittest.mock import patch import ray from ray import tune @@ -19,6 +18,7 @@ from ray.rllib import _register_all class TrialRunnerPlacementGroupTest(unittest.TestCase): def setUp(self): os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "10000" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default self.head_cpus = 8 self.head_gpus = 4 self.head_custom = 16 @@ -154,15 +154,13 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase): """Assert that reuse actors doesn't leak placement groups""" self.testPlacementGroupRequests(reuse_actors=True) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 6) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 6) def testPlacementGroupLimitedRequests(self): """Assert that maximum number of placement groups is enforced.""" + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "6" self.testPlacementGroupRequests(scheduled=6) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 6) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 6) def testPlacementGroupLimitedRequestsWithActorReuse(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "6" self.testPlacementGroupRequests(reuse_actors=True, scheduled=6) def testPlacementGroupDistributedTraining(self, reuse_actors=False): diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index 7b220daef..261276d5c 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -6,7 +6,6 @@ import shutil import tempfile import time import unittest -from unittest.mock import patch import skopt import numpy as np @@ -206,9 +205,9 @@ class TuneFailResumeGridTest(unittest.TestCase): shutil.rmtree(self.logdir) ray.shutdown() - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) def testFailResumeGridSearch(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + config = dict( num_samples=3, fail_fast=True, @@ -237,9 +236,9 @@ class TuneFailResumeGridTest(unittest.TestCase): test2_counter = Counter([t.config["test2"] for t in analysis.trials]) assert all(v == 9 for v in test2_counter.values()) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) def testFailResumeWithPreset(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + search_alg = BasicVariantGenerator(points_to_evaluate=[{ "test": -1, "test2": -1 @@ -280,9 +279,9 @@ class TuneFailResumeGridTest(unittest.TestCase): assert test2_counter.pop(-1) == 4 assert all(v == 10 for v in test2_counter.values()) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) def testFailResumeAfterPreset(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + 
search_alg = BasicVariantGenerator(points_to_evaluate=[{ "test": -1, "test2": -1 @@ -324,9 +323,9 @@ class TuneFailResumeGridTest(unittest.TestCase): assert test2_counter.pop(-1) == 4 assert all(v == 10 for v in test2_counter.values()) - @patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1) - @patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1) def testMultiExperimentFail(self): + os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1" + experiments = [] for i in range(3): experiments.append( diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 2938514a4..6502fa79d 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -285,3 +285,7 @@ class TrialExecutor: def in_staging_grace_period(self) -> bool: """Returns True if trials have recently been staged.""" return False + + def set_max_pending_trials(self, max_pending: int): + """Set the maximum number of allowed pending trials.""" + pass diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 28e8746d8..30d65c4ce 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -22,7 +22,6 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm from ray.tune.utils import warn_if_slow, flatten_dict from ray.tune.utils.log import Verbosity, has_verbosity -from ray.tune.utils.placement_groups import TUNE_MAX_PENDING_TRIALS_PG from ray.tune.utils.serialization import TuneFunctionDecoder, \ TuneFunctionEncoder from ray.tune.web_server import TuneServer @@ -215,9 +214,18 @@ class TrialRunner: self.trial_executor = trial_executor or RayTrialExecutor() self._pending_trial_queue_times = {} - # Setting this to 0 still allows adding one new (pending) trial, - # but it will prevent us from trying to fill the trial list - self._max_pending_trials = 0 # Can be updated in `self.add_trial()` + # Set the number of maximum pending trials + max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto") + if max_pending_trials == "auto": + # Auto detect + if isinstance(self._search_alg, BasicVariantGenerator): + self._max_pending_trials = 1000 + else: + self._max_pending_trials = 1 + else: + # Manual override + self._max_pending_trials = int(max_pending_trials) + self.trial_executor.set_max_pending_trials(self._max_pending_trials) self._metric = metric @@ -557,9 +565,6 @@ class TrialRunner: Args: trial (Trial): Trial to queue. """ - if trial.uses_placement_groups: - self._max_pending_trials = TUNE_MAX_PENDING_TRIALS_PG - self._trials.append(trial) with warn_if_slow("scheduler.on_trial_add"): self._scheduler_alg.on_trial_add(self, trial) diff --git a/python/ray/tune/utils/placement_groups.py b/python/ray/tune/utils/placement_groups.py index 01c6cf025..596b3dd5e 100644 --- a/python/ray/tune/utils/placement_groups.py +++ b/python/ray/tune/utils/placement_groups.py @@ -16,9 +16,36 @@ from ray.util.placement_group import PlacementGroup, get_placement_group, \ if TYPE_CHECKING: from ray.tune.trial import Trial -TUNE_MAX_PENDING_TRIALS_PG = int(os.getenv("TUNE_MAX_PENDING_TRIALS_PG", 1000)) TUNE_PLACEMENT_GROUP_REMOVAL_DELAY = 2. +_tune_pg_prefix = None + + +def get_tune_pg_prefix(): + """Get the tune placement group name prefix. + + This will store the prefix in a global variable so that subsequent runs + can use this identifier to clean up placement groups before starting their + run. 
+ + Can be overwritten with the ``TUNE_PLACEMENT_GROUP_PREFIX`` env variable. + """ + global _tune_pg_prefix + + if _tune_pg_prefix: + return _tune_pg_prefix + + # Else: check env variable + env_prefix = os.getenv("TUNE_PLACEMENT_GROUP_PREFIX", "") + + if env_prefix: + _tune_pg_prefix = env_prefix + return _tune_pg_prefix + + # Else: create and store unique prefix + _tune_pg_prefix = f"__tune_{uuid.uuid4().hex[:8]}__" + return _tune_pg_prefix + class PlacementGroupFactory: """Wrapper class that creates placement groups for trials. @@ -187,7 +214,7 @@ class PlacementGroupManager: prefix (str): Prefix for the placement group names that are created. """ - def __init__(self, prefix: str = "_tune__"): + def __init__(self, prefix: str = "__tune__", max_staging: int = 1000): self._prefix = prefix # Sets of staged placement groups by factory @@ -220,6 +247,11 @@ class PlacementGroupManager: self._grace_period = float( os.getenv("TUNE_TRIAL_STARTUP_GRACE_PERIOD", 10.)) + self._max_staging = max_staging + + def set_max_staging(self, max_staging: int): + self._max_staging = max_staging + def remove_pg(self, pg: PlacementGroup): """Schedule placement group for (delayed) removal. @@ -314,7 +346,7 @@ class PlacementGroupManager: def can_stage(self): """Return True if we can stage another placement group.""" - return len(self._staging_futures) < TUNE_MAX_PENDING_TRIALS_PG + return len(self._staging_futures) < self._max_staging def update_status(self): """Update placement group status.
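To make the prefix resolution above concrete, a minimal sketch (the prefix string is an arbitrary example, and it assumes the environment variable is set before the first call in the process):

    import os
    from ray.tune.utils.placement_groups import get_tune_pg_prefix

    os.environ["TUNE_PLACEMENT_GROUP_PREFIX"] = "__tune_my_experiment__"  # arbitrary example
    print(get_tune_pg_prefix())  # -> "__tune_my_experiment__"

    # The result is cached in the module-level global, so subsequent calls in the
    # same process (e.g. later tune.run() invocations) reuse the same prefix and
    # can clean up placement groups left over from earlier runs.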