[Tune] Remove queue_trials. (#19472)
This commit is contained in:
parent 580b58a68f
commit a632cb439f

9 changed files with 50 additions and 213 deletions
@@ -169,15 +169,11 @@ class RayTrialExecutor(TrialExecutor):
     """An implementation of TrialExecutor based on Ray."""
 
     def __init__(self,
-                 queue_trials: bool = False,
                  reuse_actors: bool = False,
                  result_buffer_length: Optional[int] = None,
                  refresh_period: Optional[float] = None,
                  wait_for_placement_group: Optional[float] = None):
-        super(RayTrialExecutor, self).__init__(queue_trials)
-        # Check for if we are launching a trial without resources in kick off
-        # autoscaler.
-        self._trial_queued = False
+        super(RayTrialExecutor, self).__init__()
         self._running = {}
         # Since trial resume after paused should not run
         # trial.train.remote(), thus no more new remote object ref generated.
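A quick orientation for reviewers: after this hunk the executor is constructed without a queue_trials argument. A minimal sketch under that assumption, using only keyword arguments that remain in the signature above (the chosen values are illustrative):

from ray.tune.ray_trial_executor import RayTrialExecutor

# queue_trials is no longer accepted; only the remaining keyword arguments
# from the new signature are passed.
executor = RayTrialExecutor(reuse_actors=True, result_buffer_length=1)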
@@ -881,9 +877,9 @@ class RayTrialExecutor(TrialExecutor):
     def has_resources_for_trial(self, trial: Trial) -> bool:
         """Returns whether this runner has resources available for this trial.
 
-        If using placement groups, this will return True as long as we
-        didn't reach the maximum number of pending trials. It will also return
-        True if the trial placement group is already staged.
+        This will return True as long as we didn't reach the maximum number
+        of pending trials. It will also return True if the trial placement
+        group is already staged.
 
         Args:
             trial: Trial object which should be scheduled.
@@ -924,19 +920,6 @@ class RayTrialExecutor(TrialExecutor):
         if have_space:
             # The assumption right now is that we block all trials if one
             # trial is queued.
-            self._trial_queued = False
             return True
 
-        can_overcommit = self._queue_trials and not self._trial_queued
-        if can_overcommit:
-            self._trial_queued = True
-            logger.warning(
-                "Allowing trial to start even though the "
-                "cluster does not have enough free resources. Trial actors "
-                "may appear to hang until enough resources are added to the "
-                "cluster (e.g., via autoscaling). You can disable this "
-                "behavior by specifying `queue_trials=False` in "
-                "ray.tune.run().")
-            return True
-
         return False
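The overcommit path deleted here is superseded by the pending-trials limit. A minimal sketch of the replacement, assuming the TUNE_MAX_PENDING_TRIALS_PG environment variable named in the deprecation message added further down; the value 4 is purely illustrative:

import os

# Instead of queue_trials=True, cap how many trials may sit pending
# (for example while the autoscaler adds nodes).
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "4"  # illustrative value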
@@ -265,31 +265,24 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         register_trainable("B", B)
 
-        def f(cpus, gpus, queue_trials):
-            return run_experiments(
-                {
-                    "foo": {
-                        "run": "B",
-                        "config": {
-                            "cpu": cpus,
-                            "gpu": gpus,
-                        },
-                    }
-                },
-                queue_trials=queue_trials)[0]
+        def f(cpus, gpus):
+            return run_experiments({
+                "foo": {
+                    "run": "B",
+                    "config": {
+                        "cpu": cpus,
+                        "gpu": gpus,
+                    },
+                }
+            })[0]
 
         # Should all succeed
-        self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
-        self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
-        self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
+        self.assertEqual(f(0, 0).status, Trial.TERMINATED)
 
         # Too large resource request
-        self.assertRaises(TuneError, lambda: f(100, 100, False))
-        self.assertRaises(TuneError, lambda: f(0, 100, False))
-        self.assertRaises(TuneError, lambda: f(100, 0, False))
-
-        # TODO(ekl) how can we test this is queued (hangs)?
-        # f(100, 0, True)
+        self.assertRaises(TuneError, lambda: f(100, 100))
+        self.assertRaises(TuneError, lambda: f(0, 100))
+        self.assertRaises(TuneError, lambda: f(100, 0))
 
     def testRewriteEnv(self):
         def train(config, reporter):
@@ -17,8 +17,6 @@ from ray._private.test_utils import run_string_as_driver_nonblocking
 from ray.tune import register_trainable
 from ray.tune.experiment import Experiment
-from ray.tune.error import TuneError
-from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.resources import Resources
 from ray.tune.suggest import BasicVariantGenerator
 from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
 from ray.tune.utils.trainable import TrainableUtil
@@ -218,60 +216,6 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
     runner.step()
 
 
-def test_queue_trials(start_connected_emptyhead_cluster):
-    """Tests explicit oversubscription for autoscaling.
-
-    Tune oversubscribes a trial when `queue_trials=True`, but
-    does not block other trials from running.
-    """
-    os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1"
-
-    cluster = start_connected_emptyhead_cluster
-    runner = TrialRunner()
-
-    def create_trial(cpu, gpu=0):
-        kwargs = {
-            "resources": Resources(cpu=cpu, gpu=gpu),
-            "stopping_criterion": {
-                "training_iteration": 3
-            }
-        }
-        return Trial("__fake", **kwargs)
-
-    runner.add_trial(create_trial(cpu=1))
-    with pytest.raises(TuneError):
-        runner.step()  # run 1
-
-    del runner
-
-    executor = RayTrialExecutor(queue_trials=True)
-    runner = TrialRunner(trial_executor=executor)
-    cluster.add_node(num_cpus=2)
-    cluster.wait_for_nodes()
-
-    cpu_only = create_trial(cpu=1)
-    runner.add_trial(cpu_only)
-    runner.step()  # add cpu_only trial
-
-    gpu_trial = create_trial(cpu=1, gpu=1)
-    runner.add_trial(gpu_trial)
-    runner.step()  # queue gpu_trial
-
-    # This tests that the cpu_only trial should bypass the queued trial.
-    for i in range(3):
-        runner.step()
-    assert cpu_only.status == Trial.TERMINATED
-    assert gpu_trial.status == Trial.RUNNING
-
-    # Scale up
-    cluster.add_node(num_cpus=1, num_gpus=1)
-    cluster.wait_for_nodes()
-
-    for i in range(3):
-        runner.step()
-    assert gpu_trial.status == Trial.TERMINATED
-
-
 @pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
 def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
     """Removing a node while cluster has space should migrate trial.
@@ -85,7 +85,7 @@ class RayTrialExecutorTest(unittest.TestCase):
         os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
         os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
 
-        self.trial_executor = RayTrialExecutor(queue_trials=False)
+        self.trial_executor = RayTrialExecutor()
         ray.init(num_cpus=2, ignore_reinit_error=True)
         _register_all()  # Needed for flaky tests
 
@@ -190,7 +190,7 @@ class RayTrialExecutorTest(unittest.TestCase):
         os.environ["TUNE_RESULT_BUFFER_MIN_TIME_S"] = "1"
 
         # Need a new trial executor so the ENV vars are parsed again
-        self.trial_executor = RayTrialExecutor(queue_trials=False)
+        self.trial_executor = RayTrialExecutor()
 
         base = max(result_buffer_length, 1)
@@ -298,7 +298,7 @@ class RayTrialExecutorTest(unittest.TestCase):
         }, "grid_search")
         trial = trials[0]
         os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
-        self.trial_executor = RayTrialExecutor(queue_trials=False)
+        self.trial_executor = RayTrialExecutor()
         os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
         self.trial_executor.start_trial(trial)
         self.assertEqual(Trial.RUNNING, trial.status)
@@ -336,74 +336,6 @@ class RayTrialExecutorTest(unittest.TestCase):
         trial.on_checkpoint(checkpoint)
 
 
-class RayExecutorQueueTest(unittest.TestCase):
-    def setUp(self):
-        self.cluster = Cluster(
-            initialize_head=True,
-            connect=True,
-            head_node_args={
-                "num_cpus": 1,
-                "_system_config": {
-                    "num_heartbeats_timeout": 10
-                }
-            })
-        self.trial_executor = RayTrialExecutor(
-            queue_trials=True, refresh_period=0)
-        # Pytest doesn't play nicely with imports
-        _register_all()
-
-    def tearDown(self):
-        ray.shutdown()
-        self.cluster.shutdown()
-        _register_all()  # re-register the evicted objects
-
-    def testQueueTrial(self):
-        """Tests that reset handles NotImplemented properly."""
-
-        def create_trial(cpu, gpu=0):
-            return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))
-
-        cpu_only = create_trial(1, 0)
-        self.assertTrue(self.trial_executor.has_resources_for_trial(cpu_only))
-        self.trial_executor.start_trial(cpu_only)
-
-        gpu_only = create_trial(0, 1)
-        self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_only))
-
-    def testHeadBlocking(self):
-        # Once resource requests are deprecated, remove this test
-        os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1"
-
-        def create_trial(cpu, gpu=0):
-            return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))
-
-        gpu_trial = create_trial(1, 1)
-        self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_trial))
-        self.trial_executor.start_trial(gpu_trial)
-
-        # TODO(rliaw): This behavior is probably undesirable, but right now
-        # trials with different resource requirements is not often used.
-        cpu_only_trial = create_trial(1, 0)
-        self.assertFalse(
-            self.trial_executor.has_resources_for_trial(cpu_only_trial))
-
-        self.cluster.add_node(num_cpus=1, num_gpus=1)
-        self.cluster.wait_for_nodes()
-
-        self.assertTrue(
-            self.trial_executor.has_resources_for_trial(cpu_only_trial))
-        self.trial_executor.start_trial(cpu_only_trial)
-
-        cpu_only_trial2 = create_trial(1, 0)
-        self.assertTrue(
-            self.trial_executor.has_resources_for_trial(cpu_only_trial2))
-        self.trial_executor.start_trial(cpu_only_trial2)
-
-        cpu_only_trial3 = create_trial(1, 0)
-        self.assertFalse(
-            self.trial_executor.has_resources_for_trial(cpu_only_trial3))
-
-
 class RayExecutorPlacementGroupTest(unittest.TestCase):
     def setUp(self):
         self.head_cpus = 8
@@ -537,7 +469,7 @@ class RayExecutorPlacementGroupTest(unittest.TestCase):
 class LocalModeExecutorTest(RayTrialExecutorTest):
     def setUp(self):
         ray.init(local_mode=True)
-        self.trial_executor = RayTrialExecutor(queue_trials=False)
+        self.trial_executor = RayTrialExecutor()
 
     def tearDown(self):
         ray.shutdown()
@@ -89,8 +89,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
             checkpoint_freq=1,
             fail_fast=True,
             config={"a": tune.sample_from(lambda _: param_a())},
-            trial_executor=CustomExecutor(
-                queue_trials=False, reuse_actors=False),
+            trial_executor=CustomExecutor(reuse_actors=False),
         )
 
@@ -118,16 +118,9 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
     and starting/stopping trials.
     """
 
-    def __init__(self, queue_trials: bool = False):
+    def __init__(self):
         """Initializes a new TrialExecutor.
-
-        Args:
-            queue_trials (bool): Whether to queue trials when the cluster does
-                not currently have enough resources to launch one. This should
-                be set to True when running on an autoscaling cluster to enable
-                automatic scale-up.
         """
-        self._queue_trials = queue_trials
         self._cached_trial_state = {}
         self._trials_to_cache = set()
         # The next two variables are used to keep track of if there is any
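With the base initializer losing its only parameter, subclasses now call it without arguments, exactly as the RayTrialExecutor hunk at the top of this diff does. A minimal sketch of a custom executor after this change, in the spirit of the CustomExecutor helper used in the PBT test above (the class and its single keyword are illustrative, not part of this PR):

from ray.tune.ray_trial_executor import RayTrialExecutor

class CustomExecutor(RayTrialExecutor):  # illustrative subclass
    def __init__(self, reuse_actors: bool = False):
        # No queue_trials argument is forwarded to the base class anymore.
        super().__init__(reuse_actors=reuse_actors)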
@@ -338,29 +331,21 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
             trials (List[Trial]): The list of trials. Note, refrain from
                 providing TrialRunner directly here.
         """
-        if self._queue_trials:
-            return
         self._may_warn_insufficient_resources(trials)
         for trial in trials:
             if trial.uses_placement_groups:
                 return
+            # TODO(xwjiang): The rest should be gone in a follow up PR
+            # to remove non-pg case.
             if trial.status == Trial.PENDING:
                 if not self.has_resources_for_trial(trial):
                     resource_string = trial.resources.summary_string()
                     trial_resource_help_msg = trial.get_trainable_cls(
                     ).resource_help(trial.config)
-                    autoscaling_msg = ""
-                    if is_ray_cluster():
-                        autoscaling_msg = (
-                            "Pass `queue_trials=True` in ray.tune.run() or "
-                            "on the command line to queue trials until the "
-                            "cluster scales up or resources become available. "
-                        )
                     raise TuneError(
                         "Insufficient cluster resources to launch trial: "
                         f"trial requested {resource_string}, but the cluster "
                         f"has only {self.resource_string()}. "
-                        f"{autoscaling_msg}"
                         f"{trial_resource_help_msg} ")
             elif trial.status == Trial.PAUSED:
                 raise TuneError("There are paused trials, but no more pending "
@@ -100,15 +100,15 @@ def run(
         restore: Optional[str] = None,
         server_port: Optional[int] = None,
         resume: bool = False,
-        queue_trials: bool = False,
         reuse_actors: bool = False,
         trial_executor: Optional[RayTrialExecutor] = None,
         raise_on_failed_trial: bool = True,
         callbacks: Optional[Sequence[Callback]] = None,
         max_concurrent_trials: Optional[int] = None,
         # Deprecated args
+        queue_trials: Optional[bool] = None,
         loggers: Optional[Sequence[Type[Logger]]] = None,
-        _remote: bool = None,
+        _remote: Optional[bool] = None,
 ) -> ExperimentAnalysis:
     """Executes training.
 
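Migration sketch for callers of tune.run (hedged; my_trainable is a placeholder, not part of this PR). The keyword moves into the deprecated block above, so existing call sites simply drop it:

from ray import tune

def my_trainable(config):  # placeholder trainable for illustration
    tune.report(score=0)

# Before: tune.run(my_trainable, queue_trials=True)
# After: drop the argument; pending-trial capacity is handled automatically.
analysis = tune.run(my_trainable)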
@@ -261,10 +261,6 @@ def run(
             ERRORED trials upon resume - previous trial artifacts will
             be left untouched. If resume is set but checkpoint does not exist,
             ValueError will be thrown.
-        queue_trials (bool): Whether to queue trials when the cluster does
-            not currently have enough resources to launch one. This should
-            be set to True when running on an autoscaling cluster to enable
-            automatic scale-up.
         reuse_actors (bool): Whether to reuse actors between different trials
             when possible. This can drastically speed up experiments that start
             and stop actors often (e.g., PBT in time-multiplexing mode). This
@@ -294,6 +290,15 @@ def run(
         TuneError: Any trials failed and `raise_on_failed_trial` is True.
     """
 
+    # To be removed in 1.9.
+    if queue_trials is not None:
+        raise DeprecationWarning(
+            "`queue_trials` has been deprecated and is replaced by "
+            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
+            "Per default at least one Trial is queued at all times, "
+            "so you likely don't need to change anything other than "
+            "removing this argument from your call to `tune.run()`")
+
     # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
     # remote_run_kwargs must be defined before any other
     # code is ran to ensure that at this point,
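Because the new guard raises rather than warns, stale call sites fail fast. A small self-contained sketch of what such a caller sees (the trainable is again a placeholder):

from ray import tune

def my_trainable(config):  # placeholder trainable for illustration
    tune.report(score=0)

try:
    tune.run(my_trainable, queue_trials=True)
except DeprecationWarning as exc:
    print(exc)  # the message points at TUNE_MAX_PENDING_TRIALS_PG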
@@ -409,9 +414,7 @@
         result_buffer_length = 1
 
     trial_executor = trial_executor or RayTrialExecutor(
-        reuse_actors=reuse_actors,
-        queue_trials=queue_trials,
-        result_buffer_length=result_buffer_length)
+        reuse_actors=reuse_actors, result_buffer_length=result_buffer_length)
     if isinstance(run_or_experiment, list):
         experiments = run_or_experiment
     else:
@@ -653,13 +656,14 @@ def run_experiments(
         verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
         progress_reporter: Optional[ProgressReporter] = None,
         resume: bool = False,
-        queue_trials: bool = False,
         reuse_actors: bool = False,
         trial_executor: Optional[RayTrialExecutor] = None,
         raise_on_failed_trial: bool = True,
         concurrent: bool = True,
+        # Deprecated args.
+        queue_trials: Optional[bool] = None,
         callbacks: Optional[Sequence[Callback]] = None,
-        _remote: bool = None):
+        _remote: Optional[bool] = None):
     """Runs and blocks until all trials finish.
 
     Examples:
@@ -673,6 +677,15 @@
         List of Trial objects, holding data for each executed trial.
 
     """
+    # To be removed in 1.9.
+    if queue_trials is not None:
+        raise DeprecationWarning(
+            "`queue_trials` has been deprecated and is replaced by "
+            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
+            "Per default at least one Trial is queued at all times, "
+            "so you likely don't need to change anything other than "
+            "removing this argument from your call to `tune.run()`")
+
     if _remote is None:
         _remote = ray.util.client.ray.is_connected()
 
@@ -696,7 +709,6 @@
         verbose,
         progress_reporter,
         resume,
-        queue_trials,
         reuse_actors,
         trial_executor,
         raise_on_failed_trial,
@@ -716,7 +728,6 @@
         verbose=verbose,
         progress_reporter=progress_reporter,
         resume=resume,
-        queue_trials=queue_trials,
         reuse_actors=reuse_actors,
         trial_executor=trial_executor,
         raise_on_failed_trial=raise_on_failed_trial,
@@ -731,7 +742,6 @@
         verbose=verbose,
         progress_reporter=progress_reporter,
         resume=resume,
-        queue_trials=queue_trials,
         reuse_actors=reuse_actors,
         trial_executor=trial_executor,
         raise_on_failed_trial=raise_on_failed_trial,
@@ -129,7 +129,6 @@ analysis = tune.run(
         FailureInjectorCallback(time_between_checks=90),
         ProgressCallback()
     ],
-    queue_trials=True,
     stop={"training_iteration": 1} if args.smoke_test else None)
 
 print(analysis.get_best_config(metric="val_loss", mode="min"))
@@ -123,13 +123,6 @@ def create_parser(parser_creator=None):
         help="Whether to attempt to enable tracing for eager mode.")
     parser.add_argument(
         "--env", default=None, type=str, help="The gym environment to use.")
-    parser.add_argument(
-        "--queue-trials",
-        action="store_true",
-        help=(
-            "Whether to queue trials when the cluster does not currently have "
-            "enough resources to launch one. This should be set to True when "
-            "running on an autoscaling cluster to enable automatic scale-up."))
     parser.add_argument(
         "-f",
         "--config-file",
@@ -267,7 +260,6 @@ def run(args, parser):
         experiments,
         scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
         resume=args.resume,
-        queue_trials=args.queue_trials,
         verbose=verbose,
         progress_reporter=progress_reporter,
         concurrent=True)