[tune] Limit maximum number of pending trials. Add convergence test. (#14835)

This commit is contained in:
Kai Fricke 2021-03-24 02:19:41 +01:00 committed by GitHub
parent 5d763b3f49
commit 898243d538
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 257 additions and 118 deletions

View file

@ -759,7 +759,7 @@ These are the environment variables Ray Tune currently considers:
* **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those
with the parameter values in them)
* **TUNE_MAX_PENDING_TRIALS_PG**: Maximum number of pending trials when placement groups are used. Defaults
to ``1000``.
to ``auto``, which will be updated to ``1000`` for random/grid search and ``1`` for any other search algorithms.
* **TUNE_PLACEMENT_GROUP_AUTO_DISABLED**: Ray Tune automatically uses placement groups
instead of the legacy resource requests. Setting this to 1 enables legacy placement.
* **TUNE_PLACEMENT_GROUP_CLEANUP_DISABLED**: Ray Tune cleans up existing placement groups
@ -767,6 +767,9 @@ These are the environment variables Ray Tune currently considers:
that scheduled placement groups are removed when multiple calls to ``tune.run()`` are
done in the same script. You might want to disable this if you run multiple Tune runs in
parallel from different scripts. Set to 1 to disable.
* **TUNE_PLACEMENT_GROUP_PREFIX**: Prefix for placement groups created by Ray Tune. This prefix is used
e.g. to identify placement groups that should be cleaned up on start/stop of the tuning run. This is
initialized to a unique name at the start of the first run.
* **TUNE_PLACEMENT_GROUP_WAIT_S**: Default time the trial executor waits for placement
groups to be placed before continuing the tuning loop. Setting this to a float
will block for that many seconds. This is mostly used for testing purposes. Defaults

View file

@ -50,9 +50,9 @@ py_test(
)
py_test(
name = "test_convergence_gaussian_process",
size = "small",
srcs = ["tests/test_convergence_gaussian_process.py"],
name = "test_convergence",
size = "medium",
srcs = ["tests/test_convergence.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)

View file

@ -19,7 +19,8 @@ from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
from ray.tune.utils.placement_groups import PlacementGroupManager
from ray.tune.utils.placement_groups import PlacementGroupManager, \
get_tune_pg_prefix
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
from ray.tune.trial_executor import TrialExecutor
@ -160,7 +161,7 @@ class RayTrialExecutor(TrialExecutor):
self._avail_resources = Resources(cpu=0, gpu=0)
self._committed_resources = Resources(cpu=0, gpu=0)
self._pg_manager = PlacementGroupManager()
self._pg_manager = PlacementGroupManager(prefix=get_tune_pg_prefix())
self._staged_trials = set()
self._just_staged_trials = set()
self._trial_just_finished = False
@ -197,6 +198,9 @@ class RayTrialExecutor(TrialExecutor):
"""Returns True if trials have recently been staged."""
return self._pg_manager.in_staging_grace_period()
def set_max_pending_trials(self, max_pending: int):
self._pg_manager.set_max_staging(max_pending)
def stage_and_update_status(self, trials: List[Trial]):
"""Check and update statuses of scheduled placement groups.
@ -783,7 +787,9 @@ class RayTrialExecutor(TrialExecutor):
"""
if trial.uses_placement_groups:
return trial in self._staged_trials or self._pg_manager.can_stage()
return trial in self._staged_trials or self._pg_manager.can_stage(
) or self._pg_manager.has_ready(
trial, update=True)
return self.has_resources(trial.resources)

View file

@ -213,6 +213,7 @@ class ExperimentPlateauStopper(Stopper):
return self.has_plateaued() and self._iterations >= self._patience
# Deprecate: 1.4
class EarlyStopping(ExperimentPlateauStopper):
def __init__(self, *args, **kwargs):
warnings.warn(

View file

@ -17,10 +17,11 @@ except ImportError:
# This exception only exists in newer Ax releases for python 3.7
try:
from ax.exceptions.core import DataRequiredError
from ax.exceptions.generation_strategy import \
MaxParallelismReachedException
except ImportError:
MaxParallelismReachedException = Exception
MaxParallelismReachedException = DataRequiredError = Exception
import logging
@ -262,7 +263,7 @@ class AxSearch(Searcher):
else:
try:
parameters, trial_index = self._ax.get_next_trial()
except MaxParallelismReachedException:
except (MaxParallelismReachedException, DataRequiredError):
return None
self._live_trial_mapping[trial_id] = trial_index

View file

@ -4,7 +4,6 @@ from __future__ import print_function
import inspect
import logging
import pickle
from typing import Dict, List, Optional, Union
from ray.tune.result import DEFAULT_METRIC
@ -331,17 +330,6 @@ class DragonflySearch(Searcher):
self._opt.tell([(trial_info,
self._metric_op * result[self._metric])])
def save(self, checkpoint_path: str):
trials_object = (self._initial_points, self._opt)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir: str):
with open(checkpoint_dir, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._initial_points = trials_object[0]
self._opt = trials_object[1]
@staticmethod
def convert_search_space(spec: Dict) -> List[Dict]:
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

View file

@ -1,6 +1,7 @@
import inspect
import logging
import pickle
from typing import Dict, Optional, Union, List, Sequence
from typing import Dict, Optional, Type, Union, List, Sequence
from ray.tune.result import DEFAULT_METRIC
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
@ -108,7 +109,8 @@ class NevergradSearch(Searcher):
"""
def __init__(self,
optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None,
optimizer: Union[None, Optimizer, Type[Optimizer],
ConfiguredOptimizer] = None,
space: Optional[Union[Dict, Parameter]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
@ -154,7 +156,9 @@ class NevergradSearch(Searcher):
"parameter.")
self._parameters = space
self._nevergrad_opt = optimizer
elif isinstance(optimizer, ConfiguredOptimizer):
elif (inspect.isclass(optimizer)
and issubclass(optimizer, Optimizer)) or isinstance(
optimizer, ConfiguredOptimizer):
self._opt_factory = optimizer
self._parameters = None
self._space = space

View file

@ -364,11 +364,11 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
@pytest.mark.parametrize("with_pg", [True, False])
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id,
with_pg):
"""Removing a node in full cluster causes Trial to be requeued."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
if not with_pg:
os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1"

View file

@ -0,0 +1,156 @@
import math
import numpy as np
import ray
from ray import tune
from ray.tune.stopper import ExperimentPlateauStopper
from ray.tune.suggest import ConcurrencyLimiter
import unittest
def loss(config, reporter):
x = config.get("x")
reporter(loss=x**2) # A simple function to optimize
class ConvergenceTest(unittest.TestCase):
"""Test convergence in gaussian process."""
@classmethod
def setUpClass(cls) -> None:
ray.init(local_mode=False, num_cpus=1, num_gpus=0)
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def _testConvergence(self, searcher, top=3, patience=20):
# This is the space of parameters to explore
space = {"x": tune.uniform(0, 20)}
resources_per_trial = {"cpu": 1, "gpu": 0}
analysis = tune.run(
loss,
metric="loss",
mode="min",
stop=ExperimentPlateauStopper(
metric="loss", top=top, patience=patience),
search_alg=searcher,
config=space,
num_samples=100, # Number of iterations
resources_per_trial=resources_per_trial,
raise_on_failed_trial=False,
fail_fast=True,
reuse_actors=True,
verbose=1)
print(f"Num trials: {len(analysis.trials)}. "
f"Best result: {analysis.best_config['x']}")
return analysis
def testConvergenceAx(self):
from ray.tune.suggest.ax import AxSearch
np.random.seed(0)
searcher = AxSearch()
analysis = self._testConvergence(searcher, patience=10)
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5)
def testConvergenceBayesOpt(self):
from ray.tune.suggest.bayesopt import BayesOptSearch
np.random.seed(0)
# Following bayesian optimization
searcher = BayesOptSearch(random_search_steps=10)
searcher.repeat_float_precision = 5
searcher = ConcurrencyLimiter(searcher, 1)
analysis = self._testConvergence(searcher, patience=100)
assert len(analysis.trials) < 50
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5)
def testConvergenceDragonfly(self):
from ray.tune.suggest.dragonfly import DragonflySearch
np.random.seed(0)
searcher = DragonflySearch(domain="euclidean", optimizer="bandit")
analysis = self._testConvergence(searcher)
assert len(analysis.trials) < 100
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5)
def testConvergenceHEBO(self):
from ray.tune.suggest.hebo import HEBOSearch
np.random.seed(0)
searcher = HEBOSearch()
analysis = self._testConvergence(searcher)
assert len(analysis.trials) < 100
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2)
def testConvergenceHyperopt(self):
from ray.tune.suggest.hyperopt import HyperOptSearch
np.random.seed(0)
searcher = HyperOptSearch(random_state_seed=1234)
analysis = self._testConvergence(searcher, patience=50, top=5)
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2)
def testConvergenceNevergrad(self):
from ray.tune.suggest.nevergrad import NevergradSearch
import nevergrad as ng
np.random.seed(0)
searcher = NevergradSearch(optimizer=ng.optimizers.PSO)
analysis = self._testConvergence(searcher, patience=50, top=5)
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3)
def testConvergenceOptuna(self):
from ray.tune.suggest.optuna import OptunaSearch
np.random.seed(1)
searcher = OptunaSearch()
analysis = self._testConvergence(
searcher,
top=5,
)
# This assertion is much weaker than in the BO case, but TPE
# don't converge too close. It is still unlikely to get to this
# tolerance with random search (~0.01% chance)
assert len(analysis.trials) < 100
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-2)
def testConvergenceSkOpt(self):
from ray.tune.suggest.skopt import SkOptSearch
np.random.seed(0)
searcher = SkOptSearch()
analysis = self._testConvergence(searcher)
assert len(analysis.trials) < 100
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3)
def testConvergenceZoopt(self):
from ray.tune.suggest.zoopt import ZOOptSearch
np.random.seed(0)
searcher = ZOOptSearch(budget=100)
analysis = self._testConvergence(searcher)
assert len(analysis.trials) < 100
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-3)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,56 +0,0 @@
import math
import numpy as np
import ray
from ray import tune
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.suggest import ConcurrencyLimiter
import unittest
def loss(config, reporter):
x = config.get("x")
reporter(loss=x**2) # A simple function to optimize
class ConvergenceTest(unittest.TestCase):
"""Test convergence in gaussian process."""
def shutDown(self):
ray.shutdown()
def test_convergence_gaussian_process(self):
np.random.seed(0)
ray.init(local_mode=True, num_cpus=1, num_gpus=1)
# This is the space of parameters to explore
space = {"x": tune.uniform(0, 20)}
resources_per_trial = {"cpu": 1, "gpu": 0}
# Following bayesian optimization
gp = BayesOptSearch(random_search_steps=10)
gp.repeat_float_precision = 5
gp = ConcurrencyLimiter(gp, 1)
# Execution of the BO.
analysis = tune.run(
loss,
metric="loss",
mode="min",
# stop=EarlyStopping("loss", mode="min", patience=5),
search_alg=gp,
config=space,
num_samples=100, # Number of iterations
resources_per_trial=resources_per_trial,
raise_on_failed_trial=False,
fail_fast=True,
verbose=1)
assert len(analysis.trials) in {13, 40, 43} # it is 43 on the cluster?
assert math.isclose(analysis.best_config["x"], 0, abs_tol=1e-5)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,7 +1,6 @@
import os
import sys
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
@ -294,9 +293,9 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(runner.trial_executor._committed_resources.cpu, 2)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testQueueFilling(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
ray.init(num_cpus=4)
def f1(config):

View file

@ -31,6 +31,8 @@ class TrialRunnerTest3(unittest.TestCase):
# Block for results even when placement groups are pending
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default
self.tmpdir = tempfile.mkdtemp()
def tearDown(self):
@ -114,11 +116,10 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(trials[2].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.TERMINATED)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testSearchAlgNotification(self):
"""Checks notification of trial to the Search Algorithm."""
os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" # Don't finish early
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
@ -235,10 +236,9 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertTrue(search_alg.is_finished())
self.assertTrue(runner.is_finished())
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testSearchAlgFinishes(self):
"""Empty SearchAlg changing state in `next_trials` does not crash."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
class FinishFastAlg(_MockSuggestionAlgorithm):
_index = 0
@ -295,7 +295,7 @@ class TrialRunnerTest3(unittest.TestCase):
def __init__(self, index):
self.index = index
self.returned_result = []
super().__init__(metric="result", mode="max")
super().__init__(metric="episode_reward_mean", mode="max")
def suggest(self, trial_id):
self.index += 1
@ -506,10 +506,10 @@ class TrialRunnerTest3(unittest.TestCase):
runner2.step() # Process save
self.assertRaises(TuneError, runner2.step)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testTrialNoSave(self):
"""Check that non-checkpointing trials are not saved."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
ray.init(num_cpus=3)
runner = TrialRunner(
@ -635,10 +635,9 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(trial.last_result[TRAINING_ITERATION], 9)
self.assertEqual(num_checkpoints(trial), 3)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testUserCheckpoint(self):
os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" # Don't finish early
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
ray.init(num_cpus=3)
runner = TrialRunner(

View file

@ -3,7 +3,6 @@ import os
import time
import numpy as np
import unittest
from unittest.mock import patch
import ray
from ray import tune
@ -19,6 +18,7 @@ from ray.rllib import _register_all
class TrialRunnerPlacementGroupTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "10000"
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "auto" # Reset default
self.head_cpus = 8
self.head_gpus = 4
self.head_custom = 16
@ -154,15 +154,13 @@ class TrialRunnerPlacementGroupTest(unittest.TestCase):
"""Assert that reuse actors doesn't leak placement groups"""
self.testPlacementGroupRequests(reuse_actors=True)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 6)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 6)
def testPlacementGroupLimitedRequests(self):
"""Assert that maximum number of placement groups is enforced."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "6"
self.testPlacementGroupRequests(scheduled=6)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 6)
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 6)
def testPlacementGroupLimitedRequestsWithActorReuse(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "6"
self.testPlacementGroupRequests(reuse_actors=True, scheduled=6)
def testPlacementGroupDistributedTraining(self, reuse_actors=False):

View file

@ -6,7 +6,6 @@ import shutil
import tempfile
import time
import unittest
from unittest.mock import patch
import skopt
import numpy as np
@ -206,9 +205,9 @@ class TuneFailResumeGridTest(unittest.TestCase):
shutil.rmtree(self.logdir)
ray.shutdown()
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testFailResumeGridSearch(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
config = dict(
num_samples=3,
fail_fast=True,
@ -237,9 +236,9 @@ class TuneFailResumeGridTest(unittest.TestCase):
test2_counter = Counter([t.config["test2"] for t in analysis.trials])
assert all(v == 9 for v in test2_counter.values())
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testFailResumeWithPreset(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
search_alg = BasicVariantGenerator(points_to_evaluate=[{
"test": -1,
"test2": -1
@ -280,9 +279,9 @@ class TuneFailResumeGridTest(unittest.TestCase):
assert test2_counter.pop(-1) == 4
assert all(v == 10 for v in test2_counter.values())
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testFailResumeAfterPreset(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
search_alg = BasicVariantGenerator(points_to_evaluate=[{
"test": -1,
"test2": -1
@ -324,9 +323,9 @@ class TuneFailResumeGridTest(unittest.TestCase):
assert test2_counter.pop(-1) == 4
assert all(v == 10 for v in test2_counter.values())
@patch("ray.tune.utils.placement_groups.TUNE_MAX_PENDING_TRIALS_PG", 1)
@patch("ray.tune.trial_runner.TUNE_MAX_PENDING_TRIALS_PG", 1)
def testMultiExperimentFail(self):
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
experiments = []
for i in range(3):
experiments.append(

View file

@ -285,3 +285,7 @@ class TrialExecutor:
def in_staging_grace_period(self) -> bool:
"""Returns True if trials have recently been staged."""
return False
def set_max_pending_trials(self, max_pending: int):
"""Set the maximum number of allowed pending trials."""
pass

View file

@ -22,7 +22,6 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.utils.placement_groups import TUNE_MAX_PENDING_TRIALS_PG
from ray.tune.utils.serialization import TuneFunctionDecoder, \
TuneFunctionEncoder
from ray.tune.web_server import TuneServer
@ -215,9 +214,18 @@ class TrialRunner:
self.trial_executor = trial_executor or RayTrialExecutor()
self._pending_trial_queue_times = {}
# Setting this to 0 still allows adding one new (pending) trial,
# but it will prevent us from trying to fill the trial list
self._max_pending_trials = 0 # Can be updated in `self.add_trial()`
# Set the number of maximum pending trials
max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto")
if max_pending_trials == "auto":
# Auto detect
if isinstance(self._search_alg, BasicVariantGenerator):
self._max_pending_trials = 1000
else:
self._max_pending_trials = 1
else:
# Manual override
self._max_pending_trials = int(max_pending_trials)
self.trial_executor.set_max_pending_trials(self._max_pending_trials)
self._metric = metric
@ -557,9 +565,6 @@ class TrialRunner:
Args:
trial (Trial): Trial to queue.
"""
if trial.uses_placement_groups:
self._max_pending_trials = TUNE_MAX_PENDING_TRIALS_PG
self._trials.append(trial)
with warn_if_slow("scheduler.on_trial_add"):
self._scheduler_alg.on_trial_add(self, trial)

View file

@ -16,9 +16,36 @@ from ray.util.placement_group import PlacementGroup, get_placement_group, \
if TYPE_CHECKING:
from ray.tune.trial import Trial
TUNE_MAX_PENDING_TRIALS_PG = int(os.getenv("TUNE_MAX_PENDING_TRIALS_PG", 1000))
TUNE_PLACEMENT_GROUP_REMOVAL_DELAY = 2.
_tune_pg_prefix = None
def get_tune_pg_prefix():
"""Get the tune placement group name prefix.
This will store the prefix in a global variable so that subsequent runs
can use this identifier to clean up placement groups before starting their
run.
Can be overwritten with the ``TUNE_PLACEMENT_GROUP_PREFIX`` env variable.
"""
global _tune_pg_prefix
if _tune_pg_prefix:
return _tune_pg_prefix
# Else: check env variable
env_prefix = os.getenv("TUNE_PLACEMENT_GROUP_PREFIX", "")
if env_prefix:
_tune_pg_prefix = env_prefix
return _tune_pg_prefix
# Else: create and store unique prefix
_tune_pg_prefix = f"__tune_{uuid.uuid4().hex[:8]}__"
return _tune_pg_prefix
class PlacementGroupFactory:
"""Wrapper class that creates placement groups for trials.
@ -187,7 +214,7 @@ class PlacementGroupManager:
prefix (str): Prefix for the placement group names that are created.
"""
def __init__(self, prefix: str = "_tune__"):
def __init__(self, prefix: str = "__tune__", max_staging: int = 1000):
self._prefix = prefix
# Sets of staged placement groups by factory
@ -220,6 +247,11 @@ class PlacementGroupManager:
self._grace_period = float(
os.getenv("TUNE_TRIAL_STARTUP_GRACE_PERIOD", 10.))
self._max_staging = max_staging
def set_max_staging(self, max_staging: int):
self._max_staging = max_staging
def remove_pg(self, pg: PlacementGroup):
"""Schedule placement group for (delayed) removal.
@ -314,7 +346,7 @@ class PlacementGroupManager:
def can_stage(self):
"""Return True if we can stage another placement group."""
return len(self._staging_futures) < TUNE_MAX_PENDING_TRIALS_PG
return len(self._staging_futures) < self._max_staging
def update_status(self):
"""Update placement group status.