[Tune] Remove queue_trials. (#19472)

xwjiang2010 2021-10-22 01:24:54 -07:00 committed by GitHub
parent 580b58a68f
commit a632cb439f
9 changed files with 50 additions and 213 deletions

View file

@@ -169,15 +169,11 @@ class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray."""
def __init__(self,
queue_trials: bool = False,
reuse_actors: bool = False,
result_buffer_length: Optional[int] = None,
refresh_period: Optional[float] = None,
wait_for_placement_group: Optional[float] = None):
super(RayTrialExecutor, self).__init__(queue_trials)
# Check for if we are launching a trial without resources in kick off
# autoscaler.
self._trial_queued = False
super(RayTrialExecutor, self).__init__()
self._running = {}
# Since trial resume after paused should not run
# trial.train.remote(), thus no more new remote object ref generated.
@@ -881,9 +877,9 @@ class RayTrialExecutor(TrialExecutor):
def has_resources_for_trial(self, trial: Trial) -> bool:
"""Returns whether this runner has resources available for this trial.
If using placement groups, this will return True as long as we
didn't reach the maximum number of pending trials. It will also return
True if the trial placement group is already staged.
This will return True as long as we didn't reach the maximum number
of pending trials. It will also return True if the trial placement
group is already staged.
Args:
trial: Trial object which should be scheduled.
@@ -924,19 +920,6 @@ class RayTrialExecutor(TrialExecutor):
if have_space:
# The assumption right now is that we block all trials if one
# trial is queued.
self._trial_queued = False
return True
can_overcommit = self._queue_trials and not self._trial_queued
if can_overcommit:
self._trial_queued = True
logger.warning(
"Allowing trial to start even though the "
"cluster does not have enough free resources. Trial actors "
"may appear to hang until enough resources are added to the "
"cluster (e.g., via autoscaling). You can disable this "
"behavior by specifying `queue_trials=False` in "
"ray.tune.run().")
return True
return False
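
For callers that previously passed queue_trials=True to ray.tune.run(), a minimal migration sketch based on the deprecation message added later in this commit (the trainable name and the value "2" are illustrative assumptions, not taken from the diff):

import os

from ray import tune

# Replacement knob named in the deprecation message added by this commit;
# the value "2" is an illustrative assumption.
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "2"

# queue_trials is no longer passed; "my_trainable" stands in for a
# registered trainable.
analysis = tune.run(
    "my_trainable",
    num_samples=4,
    config={"lr": tune.grid_search([0.01, 0.1])},
)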

View file

@@ -265,31 +265,24 @@ class TrainableFunctionApiTest(unittest.TestCase):
register_trainable("B", B)
def f(cpus, gpus, queue_trials):
return run_experiments(
{
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
},
queue_trials=queue_trials)[0]
def f(cpus, gpus):
return run_experiments({
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
})[0]
# Should all succeed
self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
self.assertEqual(f(0, 0).status, Trial.TERMINATED)
# Too large resource request
self.assertRaises(TuneError, lambda: f(100, 100, False))
self.assertRaises(TuneError, lambda: f(0, 100, False))
self.assertRaises(TuneError, lambda: f(100, 0, False))
# TODO(ekl) how can we test this is queued (hangs)?
# f(100, 0, True)
self.assertRaises(TuneError, lambda: f(100, 100))
self.assertRaises(TuneError, lambda: f(0, 100))
self.assertRaises(TuneError, lambda: f(100, 0))
def testRewriteEnv(self):
def train(config, reporter):

View file

@@ -17,8 +17,6 @@ from ray._private.test_utils import run_string_as_driver_nonblocking
from ray.tune import register_trainable
from ray.tune.experiment import Experiment
from ray.tune.error import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.resources import Resources
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
from ray.tune.utils.trainable import TrainableUtil
@@ -218,60 +216,6 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
runner.step()
def test_queue_trials(start_connected_emptyhead_cluster):
"""Tests explicit oversubscription for autoscaling.
Tune oversubscribes a trial when `queue_trials=True`, but
does not block other trials from running.
"""
os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1"
cluster = start_connected_emptyhead_cluster
runner = TrialRunner()
def create_trial(cpu, gpu=0):
kwargs = {
"resources": Resources(cpu=cpu, gpu=gpu),
"stopping_criterion": {
"training_iteration": 3
}
}
return Trial("__fake", **kwargs)
runner.add_trial(create_trial(cpu=1))
with pytest.raises(TuneError):
runner.step() # run 1
del runner
executor = RayTrialExecutor(queue_trials=True)
runner = TrialRunner(trial_executor=executor)
cluster.add_node(num_cpus=2)
cluster.wait_for_nodes()
cpu_only = create_trial(cpu=1)
runner.add_trial(cpu_only)
runner.step() # add cpu_only trial
gpu_trial = create_trial(cpu=1, gpu=1)
runner.add_trial(gpu_trial)
runner.step() # queue gpu_trial
# This tests that the cpu_only trial should bypass the queued trial.
for i in range(3):
runner.step()
assert cpu_only.status == Trial.TERMINATED
assert gpu_trial.status == Trial.RUNNING
# Scale up
cluster.add_node(num_cpus=1, num_gpus=1)
cluster.wait_for_nodes()
for i in range(3):
runner.step()
assert gpu_trial.status == Trial.TERMINATED
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
"""Removing a node while cluster has space should migrate trial.

View file

@@ -85,7 +85,7 @@ class RayTrialExecutorTest(unittest.TestCase):
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
self.trial_executor = RayTrialExecutor(queue_trials=False)
self.trial_executor = RayTrialExecutor()
ray.init(num_cpus=2, ignore_reinit_error=True)
_register_all() # Needed for flaky tests
@@ -190,7 +190,7 @@ class RayTrialExecutorTest(unittest.TestCase):
os.environ["TUNE_RESULT_BUFFER_MIN_TIME_S"] = "1"
# Need a new trial executor so the ENV vars are parsed again
self.trial_executor = RayTrialExecutor(queue_trials=False)
self.trial_executor = RayTrialExecutor()
base = max(result_buffer_length, 1)
@@ -298,7 +298,7 @@ class RayTrialExecutorTest(unittest.TestCase):
}, "grid_search")
trial = trials[0]
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
self.trial_executor = RayTrialExecutor(queue_trials=False)
self.trial_executor = RayTrialExecutor()
os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
@@ -336,74 +336,6 @@ class RayTrialExecutorTest(unittest.TestCase):
trial.on_checkpoint(checkpoint)
class RayExecutorQueueTest(unittest.TestCase):
def setUp(self):
self.cluster = Cluster(
initialize_head=True,
connect=True,
head_node_args={
"num_cpus": 1,
"_system_config": {
"num_heartbeats_timeout": 10
}
})
self.trial_executor = RayTrialExecutor(
queue_trials=True, refresh_period=0)
# Pytest doesn't play nicely with imports
_register_all()
def tearDown(self):
ray.shutdown()
self.cluster.shutdown()
_register_all() # re-register the evicted objects
def testQueueTrial(self):
"""Tests that reset handles NotImplemented properly."""
def create_trial(cpu, gpu=0):
return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))
cpu_only = create_trial(1, 0)
self.assertTrue(self.trial_executor.has_resources_for_trial(cpu_only))
self.trial_executor.start_trial(cpu_only)
gpu_only = create_trial(0, 1)
self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_only))
def testHeadBlocking(self):
# Once resource requests are deprecated, remove this test
os.environ["TUNE_PLACEMENT_GROUP_AUTO_DISABLED"] = "1"
def create_trial(cpu, gpu=0):
return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))
gpu_trial = create_trial(1, 1)
self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_trial))
self.trial_executor.start_trial(gpu_trial)
# TODO(rliaw): This behavior is probably undesirable, but right now
# trials with different resource requirements is not often used.
cpu_only_trial = create_trial(1, 0)
self.assertFalse(
self.trial_executor.has_resources_for_trial(cpu_only_trial))
self.cluster.add_node(num_cpus=1, num_gpus=1)
self.cluster.wait_for_nodes()
self.assertTrue(
self.trial_executor.has_resources_for_trial(cpu_only_trial))
self.trial_executor.start_trial(cpu_only_trial)
cpu_only_trial2 = create_trial(1, 0)
self.assertTrue(
self.trial_executor.has_resources_for_trial(cpu_only_trial2))
self.trial_executor.start_trial(cpu_only_trial2)
cpu_only_trial3 = create_trial(1, 0)
self.assertFalse(
self.trial_executor.has_resources_for_trial(cpu_only_trial3))
class RayExecutorPlacementGroupTest(unittest.TestCase):
def setUp(self):
self.head_cpus = 8
@@ -537,7 +469,7 @@ class RayExecutorPlacementGroupTest(unittest.TestCase):
class LocalModeExecutorTest(RayTrialExecutorTest):
def setUp(self):
ray.init(local_mode=True)
self.trial_executor = RayTrialExecutor(queue_trials=False)
self.trial_executor = RayTrialExecutor()
def tearDown(self):
ray.shutdown()
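
A rough sketch of constructing the executor after this change, mirroring the updated setUp calls above; the kwargs shown are the ones that remain in the new __init__ signature, and "my_trainable" is a placeholder for a registered trainable:

from ray import tune
from ray.tune.ray_trial_executor import RayTrialExecutor

# No queue_trials argument anymore; the remaining kwargs keep their defaults.
executor = RayTrialExecutor(reuse_actors=False, result_buffer_length=None)

# Passed through tune.run exactly as before.
tune.run("my_trainable", trial_executor=executor, num_samples=1)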

View file

@@ -89,8 +89,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
checkpoint_freq=1,
fail_fast=True,
config={"a": tune.sample_from(lambda _: param_a())},
trial_executor=CustomExecutor(
queue_trials=False, reuse_actors=False),
trial_executor=CustomExecutor(reuse_actors=False),
)

View file

@@ -118,16 +118,9 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
and starting/stopping trials.
"""
def __init__(self, queue_trials: bool = False):
def __init__(self):
"""Initializes a new TrialExecutor.
Args:
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
"""
self._queue_trials = queue_trials
self._cached_trial_state = {}
self._trials_to_cache = set()
# The next two variables are used to keep track of if there is any
@@ -338,29 +331,21 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
trials (List[Trial]): The list of trials. Note, refrain from
providing TrialRunner directly here.
"""
if self._queue_trials:
return
self._may_warn_insufficient_resources(trials)
for trial in trials:
if trial.uses_placement_groups:
return
# TODO(xwjiang): The rest should be gone in a follow up PR
# to remove non-pg case.
if trial.status == Trial.PENDING:
if not self.has_resources_for_trial(trial):
resource_string = trial.resources.summary_string()
trial_resource_help_msg = trial.get_trainable_cls(
).resource_help(trial.config)
autoscaling_msg = ""
if is_ray_cluster():
autoscaling_msg = (
"Pass `queue_trials=True` in ray.tune.run() or "
"on the command line to queue trials until the "
"cluster scales up or resources become available. "
)
raise TuneError(
"Insufficient cluster resources to launch trial: "
f"trial requested {resource_string}, but the cluster "
f"has only {self.resource_string()}. "
f"{autoscaling_msg}"
f"{trial_resource_help_msg} ")
elif trial.status == Trial.PAUSED:
raise TuneError("There are paused trials, but no more pending "

View file

@@ -100,15 +100,15 @@ def run(
restore: Optional[str] = None,
server_port: Optional[int] = None,
resume: bool = False,
queue_trials: bool = False,
reuse_actors: bool = False,
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
callbacks: Optional[Sequence[Callback]] = None,
max_concurrent_trials: Optional[int] = None,
# Deprecated args
queue_trials: Optional[bool] = None,
loggers: Optional[Sequence[Type[Logger]]] = None,
_remote: bool = None,
_remote: Optional[bool] = None,
) -> ExperimentAnalysis:
"""Executes training.
@@ -261,10 +261,6 @@ def run(
ERRORED trials upon resume - previous trial artifacts will
be left untouched. If resume is set but checkpoint does not exist,
ValueError will be thrown.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
reuse_actors (bool): Whether to reuse actors between different trials
when possible. This can drastically speed up experiments that start
and stop actors often (e.g., PBT in time-multiplexing mode). This
@@ -294,6 +290,15 @@ def run(
TuneError: Any trials failed and `raise_on_failed_trial` is True.
"""
# To be removed in 1.9.
if queue_trials is not None:
raise DeprecationWarning(
"`queue_trials` has been deprecated and is replaced by "
"the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
"Per default at least one Trial is queued at all times, "
"so you likely don't need to change anything other than "
"removing this argument from your call to `tune.run()`")
# NO CODE IS TO BE ADDED ABOVE THIS COMMENT
# remote_run_kwargs must be defined before any other
# code is ran to ensure that at this point,
@@ -409,9 +414,7 @@ def run(
result_buffer_length = 1
trial_executor = trial_executor or RayTrialExecutor(
reuse_actors=reuse_actors,
queue_trials=queue_trials,
result_buffer_length=result_buffer_length)
reuse_actors=reuse_actors, result_buffer_length=result_buffer_length)
if isinstance(run_or_experiment, list):
experiments = run_or_experiment
else:
@@ -653,13 +656,14 @@ def run_experiments(
verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
progress_reporter: Optional[ProgressReporter] = None,
resume: bool = False,
queue_trials: bool = False,
reuse_actors: bool = False,
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
concurrent: bool = True,
# Deprecated args.
queue_trials: Optional[bool] = None,
callbacks: Optional[Sequence[Callback]] = None,
_remote: bool = None):
_remote: Optional[bool] = None):
"""Runs and blocks until all trials finish.
Examples:
@@ -673,6 +677,15 @@ def run_experiments(
List of Trial objects, holding data for each executed trial.
"""
# To be removed in 1.9.
if queue_trials is not None:
raise DeprecationWarning(
"`queue_trials` has been deprecated and is replaced by "
"the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
"Per default at least one Trial is queued at all times, "
"so you likely don't need to change anything other than "
"removing this argument from your call to `tune.run()`")
if _remote is None:
_remote = ray.util.client.ray.is_connected()
@@ -696,7 +709,6 @@ def run_experiments(
verbose,
progress_reporter,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
@@ -716,7 +728,6 @@ def run_experiments(
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,
@@ -731,7 +742,6 @@ def run_experiments(
verbose=verbose,
progress_reporter=progress_reporter,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,

View file

@@ -129,7 +129,6 @@ analysis = tune.run(
FailureInjectorCallback(time_between_checks=90),
ProgressCallback()
],
queue_trials=True,
stop={"training_iteration": 1} if args.smoke_test else None)
print(analysis.get_best_config(metric="val_loss", mode="min"))

View file

@@ -123,13 +123,6 @@ def create_parser(parser_creator=None):
help="Whether to attempt to enable tracing for eager mode.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
"--queue-trials",
action="store_true",
help=(
"Whether to queue trials when the cluster does not currently have "
"enough resources to launch one. This should be set to True when "
"running on an autoscaling cluster to enable automatic scale-up."))
parser.add_argument(
"-f",
"--config-file",
@@ -267,7 +260,6 @@ def run(args, parser):
experiments,
scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
resume=args.resume,
queue_trials=args.queue_trials,
verbose=verbose,
progress_reporter=progress_reporter,
concurrent=True)
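
With the --queue-trials flag gone from the rllib train parser, the closest replacement for scripts that relied on oversubscription is the environment variable named in the deprecation message above; a minimal sketch, with the exact semantics assumed rather than taken from this diff:

import os

# Assumed: must be set before tune.run()/run_experiments() builds its trial
# executor, e.g. at the top of the training script or exported in the shell
# that launches `rllib train`.
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"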