[Tune] Remove pg caching with fixes to pbt scheduler (#20403)

This reverts commit f13c2a5350.

Re-lands the removal of the PG caching logic.

As a result, the PBT scheduler can no longer stop and start a trial within itself for weight transfer and perturbation. The PBT scheduler therefore changes as follows:

1. The trial being perturbed is always left in a PAUSED state upon exiting on_trial_result. Instead of maintaining two separate paths for replacing a trial, we consolidate on always "stop" and "restore", relying on reuse_actors as an optimization when available (see 2).
2. Consolidates PBT's trial replacement with reuse_actors.
3. Introduces a NOOP scheduler decision to indicate that the (PBT) scheduler has finished its interaction with the executor and no further decision is needed in the Tune loop (a sketch of the new flow follows below).

Long term, we should tighten the interface between scheduler and executor. For example, on_trial_result taking in the whole runner exposes too much API surface, which we want to remove.
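To make the consolidated flow concrete, here is a minimal, self-contained sketch. It uses toy stand-in classes, not Tune's real API; the names Executor, exploit and process_action are illustrative only and simply mirror the pattern in the diff below (stop the trial, leave it PAUSED with the clone's config and checkpoint, return NOOP so the Tune loop takes no further action).

    # Toy stand-in classes (not Tune's real API) illustrating the consolidated
    # "stop + leave PAUSED" replacement path and the new NOOP decision.

    class Trial:
        RUNNING = "RUNNING"
        PAUSED = "PAUSED"
        TERMINATED = "TERMINATED"

        def __init__(self, config):
            self.status = Trial.RUNNING
            self.config = config
            self.checkpoint = None

        def on_checkpoint(self, checkpoint):
            self.checkpoint = checkpoint

    class Executor:
        def stop_trial(self, trial):
            trial.status = Trial.TERMINATED

        def set_status(self, trial, status):
            trial.status = status

    def exploit(executor, trial, clone_config, clone_checkpoint):
        # Single replacement path: stop, then leave the trial PAUSED with the
        # clone's config/checkpoint; the runner restarts it later (cheaply, if
        # reuse_actors is enabled).
        executor.stop_trial(trial)
        executor.set_status(trial, Trial.PAUSED)
        trial.config = clone_config
        trial.on_checkpoint(clone_checkpoint)
        return "NOOP"  # the scheduler already acted through the executor

    def process_action(executor, trial, decision):
        # Mirrors the new TrialRunner branch: NOOP means nothing is left to do.
        if decision == "NOOP":
            pass
        elif decision == "PAUSE":
            executor.set_status(trial, Trial.PAUSED)
        # CONTINUE / STOP handling omitted.

    executor, trial = Executor(), Trial({"lr": 0.01})
    process_action(executor, trial, exploit(executor, trial, {"lr": 0.02}, "ckpt"))
    assert trial.status == Trial.PAUSED and trial.config == {"lr": 0.02}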
xwjiang2010 2021-11-26 07:54:45 -08:00 committed by GitHub
parent 97b4490401
commit 96b44adf67
12 changed files with 256 additions and 283 deletions


@ -312,7 +312,7 @@ py_test(
py_test(
name = "test_trial_scheduler_pbt",
size = "medium",
size = "large",
srcs = ["tests/test_trial_scheduler_pbt.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "flaky", "tests_dir_T"],
@ -405,7 +405,7 @@ py_test(
py_test(
name = "blendsearch_example",
size = "small",
size = "medium",
srcs = ["examples/blendsearch_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example"],


@ -220,7 +220,7 @@ def tune_xgboost(use_class_trainable=True):
total_available_cpus // len(trial_runner.get_live_trials()))
# Assign new CPUs to the trial in a PlacementGroupFactory
return PlacementGroupFactory([{"CPU": cpu_to_use}])
return PlacementGroupFactory([{"CPU": cpu_to_use, "GPU": 0}])
# You can either define your own resources_allocation_function, or
# use the default one - evenly_distribute_cpus_gpus


@ -502,21 +502,13 @@ class RayTrialExecutor(TrialExecutor):
logger.exception(
"Trial %s: updating resources timed out.", trial)
def _stop_trial(self,
trial: Trial,
error=False,
error_msg=None,
destroy_pg_if_cannot_replace=True):
def _stop_trial(self, trial: Trial, error=False, error_msg=None):
"""Stops this trial.
Stops this trial, releasing all allocated resources. If stopping the
trial fails, the run will be marked as terminated in error, but no
exception will be thrown.
If the placement group will be used right away
(destroy_pg_if_cannot_replace=False), we do not remove its placement
group (or a surrogate placement group).
Args:
error (bool): Whether to mark this trial as terminated in error.
error_msg (str): Optional error message.
@ -555,8 +547,7 @@ class RayTrialExecutor(TrialExecutor):
logger.debug("Trial %s: Destroying actor.", trial)
# Try to return the placement group for other trials to use
self._pg_manager.return_pg(trial,
destroy_pg_if_cannot_replace)
self._pg_manager.return_pg(trial)
with self._change_working_directory(trial):
self._trial_cleanup.add(trial, actor=trial.runner)
@ -614,18 +605,9 @@ class RayTrialExecutor(TrialExecutor):
def stop_trial(self,
trial: Trial,
error: bool = False,
error_msg: Optional[str] = None,
destroy_pg_if_cannot_replace: bool = True) -> None:
"""Only returns resources if resources allocated.
If destroy_pg_if_cannot_replace is False, the Trial placement group
will not be removed if it can't replace any staging ones."""
error_msg: Optional[str] = None) -> None:
prior_status = trial.status
self._stop_trial(
trial,
error=error,
error_msg=error_msg,
destroy_pg_if_cannot_replace=destroy_pg_if_cannot_replace)
self._stop_trial(trial, error=error, error_msg=error_msg)
if prior_status == Trial.RUNNING:
logger.debug("Trial %s: Returning resources.", trial)
out = self._find_item(self._running, trial)


@ -1,6 +1,6 @@
import copy
import logging
import json
import logging
import math
import os
import random
@ -397,13 +397,15 @@ class PopulationBasedTraining(FIFOScheduler):
if not self._synch:
state.last_perturbation_time = time
lower_quantile, upper_quantile = self._quantiles()
self._perturb_trial(trial, trial_runner, upper_quantile,
lower_quantile)
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED]:
return TrialScheduler.PAUSE # yield time to other trials
return TrialScheduler.CONTINUE
decision = TrialScheduler.CONTINUE
for other_trial in trial_runner.get_trials():
if other_trial.status in [Trial.PENDING, Trial.PAUSED]:
decision = TrialScheduler.PAUSE
break
self._checkpoint_or_exploit(trial, trial_runner.trial_executor,
upper_quantile, lower_quantile)
return (TrialScheduler.NOOP
if trial.status == Trial.PAUSED else decision)
else:
# Synchronous mode.
if any(self._trial_state[t].last_train_time <
@ -425,12 +427,12 @@ class PopulationBasedTraining(FIFOScheduler):
for t in all_trials:
logger.debug("Perturbing Trial {}".format(t))
self._trial_state[t].last_perturbation_time = time
self._perturb_trial(t, trial_runner, upper_quantile,
lower_quantile)
self._checkpoint_or_exploit(t, trial_runner.trial_executor,
upper_quantile, lower_quantile)
all_train_times = [
self._trial_state[trial].last_train_time
for trial in trial_runner.get_trials()
self._trial_state[t].last_train_time
for t in trial_runner.get_trials()
]
max_last_train_time = max(all_train_times)
self._next_perturbation_sync = max(
@ -441,7 +443,8 @@ class PopulationBasedTraining(FIFOScheduler):
# still all be paused.
# choose_trial_to_run will then pick the next trial to run out of
# the paused trials.
return TrialScheduler.PAUSE
return (TrialScheduler.NOOP
if trial.status == Trial.PAUSED else TrialScheduler.PAUSE)
def _save_trial_state(self, state: PBTTrialState, time: int, result: Dict,
trial: Trial):
@ -462,9 +465,10 @@ class PopulationBasedTraining(FIFOScheduler):
return score
def _perturb_trial(
self, trial: Trial, trial_runner: "trial_runner.TrialRunner",
upper_quantile: List[Trial], lower_quantile: List[Trial]):
def _checkpoint_or_exploit(self, trial: Trial,
trial_executor: "trial_runner.RayTrialExecutor",
upper_quantile: List[Trial],
lower_quantile: List[Trial]):
"""Checkpoint if in upper quantile, exploits if in lower."""
state = self._trial_state[trial]
if trial in upper_quantile:
@ -476,7 +480,7 @@ class PopulationBasedTraining(FIFOScheduler):
# Paused trial will always have an in-memory checkpoint.
state.last_checkpoint = trial.checkpoint
else:
state.last_checkpoint = trial_runner.trial_executor.save(
state.last_checkpoint = trial_executor.save(
trial, Checkpoint.MEMORY, result=state.last_result)
self._num_checkpoints += 1
else:
@ -490,7 +494,7 @@ class PopulationBasedTraining(FIFOScheduler):
logger.info("[pbt]: no checkpoint for trial."
" Skip exploit for Trial {}".format(trial))
return
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
self._exploit(trial_executor, trial, trial_to_clone)
def _log_config_on_step(self, trial_state: PBTTrialState,
new_state: PBTTrialState, trial: Trial,
@ -571,36 +575,12 @@ class PopulationBasedTraining(FIFOScheduler):
raise TuneError("Trials should be paused here only if in "
"synchronous mode. If you encounter this error"
" please raise an issue on Ray Github.")
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial.on_checkpoint(new_state.last_checkpoint)
else:
# If trial is running, we first try to reset it.
# If that is unsuccessful, then we have to stop it and start it
# again with a new checkpoint.
reset_successful = trial_executor.reset_trial(
trial, new_config, new_tag)
# TODO(ujvl): Refactor Scheduler abstraction to abstract
# mechanism for trial restart away. We block on restore
# and suppress train on start as a stop-gap fix to
# https://github.com/ray-project/ray/issues/7258.
if reset_successful:
trial_executor.restore(
trial, new_state.last_checkpoint, block=True)
else:
# Stop trial, but do not free resources (so we can use them
# again right away)
trial_executor.stop_trial(
trial, destroy_pg_if_cannot_replace=False)
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
if not trial_executor.start_trial(
trial, new_state.last_checkpoint, train=False):
logger.warning(
f"Trial couldn't be reset: {trial}. Terminating "
f"instead.")
trial_executor.stop_trial(trial, error=True)
trial_executor.stop_trial(trial)
trial_executor.set_status(trial, Trial.PAUSED)
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial.on_checkpoint(new_state.last_checkpoint)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
@ -651,10 +631,12 @@ class PopulationBasedTraining(FIFOScheduler):
key=lambda trial: self._trial_state[trial].last_train_time)
return candidates[0] if candidates else None
# Unit test only. TODO(xwjiang): Remove test-specific APIs.
def reset_stats(self):
self._num_perturbations = 0
self._num_checkpoints = 0
# Unit test only. TODO(xwjiang): Remove test-specific APIs.
def last_scores(self, trials: List[Trial]) -> List[float]:
scores = []
for trial in trials:
@ -815,23 +797,17 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
new_config)
trial_executor = trial_runner.trial_executor
reset_successful = trial_executor.reset_trial(trial, new_config,
new_tag)
if reset_successful:
trial_executor.restore(trial, checkpoint, block=True)
else:
trial_executor.stop_trial(
trial, destroy_pg_if_cannot_replace=False)
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial_executor.start_trial(trial, checkpoint, train=False)
trial_executor.stop_trial(trial)
trial_executor.set_status(trial, Trial.PAUSED)
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial.on_checkpoint(checkpoint)
self.current_config = new_config
self._num_perturbations += 1
self._next_policy = next(self._policy_iter, None)
return TrialScheduler.CONTINUE
return TrialScheduler.NOOP
def debug_string(self) -> str:
return "PopulationBasedTraining replay: Step {}, perturb {}".format(


@ -11,6 +11,11 @@ class TrialScheduler:
CONTINUE = "CONTINUE" #: Status for continuing trial execution
PAUSE = "PAUSE" #: Status for pausing trial execution
STOP = "STOP" #: Status for stopping trial execution
# Caution: Temporary and anti-pattern! This means Scheduler calls
# into Executor directly without going through TrialRunner.
# TODO(xwjiang): Deprecate this after we control the interaction
# between schedulers and executor.
NOOP = "NOOP"
_metric = None


@ -314,49 +314,48 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
runner.step()
# TODO(xwjiang): Uncomment this when pg caching is removed.
# @pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
# def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
# """Removing a node in full cluster causes Trial to be requeued."""
# os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
#
# cluster = start_connected_emptyhead_cluster
# node = cluster.add_node(num_cpus=1)
# cluster.wait_for_nodes()
#
# syncer_callback = _PerTrialSyncerCallback(
# lambda trial: trial.trainable_name == "__fake")
# runner = TrialRunner(
# BasicVariantGenerator(), callbacks=[syncer_callback]) # noqa
# kwargs = {
# "stopping_criterion": {
# "training_iteration": 5
# },
# "checkpoint_freq": 1,
# "max_failures": 1,
# }
#
# if trainable_id == "__fake_durable":
# kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
#
# trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
# for t in trials:
# runner.add_trial(t)
#
# runner.step() # Start trial
# runner.step() # Process result, dispatch save
# runner.step() # Process save
#
# running_trials = _get_running_trials(runner)
# assert len(running_trials) == 1
# assert _check_trial_running(running_trials[0])
# cluster.remove_node(node)
# cluster.wait_for_nodes()
# time.sleep(0.1) # Sleep so that next step() refreshes cluster resources
# runner.step() # Process result, dispatch save
# runner.step() # Process save (detect error), requeue trial
# assert all(
# t.status == Trial.PENDING for t in trials), runner.debug_string()
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
"""Removing a node in full cluster causes Trial to be requeued."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake")
runner = TrialRunner(
BasicVariantGenerator(), callbacks=[syncer_callback]) # noqa
kwargs = {
"stopping_criterion": {
"training_iteration": 5
},
"checkpoint_freq": 1,
"max_failures": 1,
}
if trainable_id == "__fake_durable":
kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
running_trials = _get_running_trials(runner)
assert len(running_trials) == 1
assert _check_trial_running(running_trials[0])
cluster.remove_node(node)
cluster.wait_for_nodes()
time.sleep(0.1) # Sleep so that next step() refreshes cluster resources
runner.step() # Process result, dispatch save
runner.step() # Process save (detect error), requeue trial
assert all(
t.status == Trial.PENDING for t in trials), runner.debug_string()
@pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])


@ -225,11 +225,7 @@ class _MockTrialExecutor(TrialExecutor):
trial.status = Trial.RUNNING
return True
def stop_trial(self,
trial,
error=False,
error_msg=None,
destroy_pg_if_cannot_replace=True):
def stop_trial(self, trial, error=False, error_msg=None):
trial.status = Trial.ERROR if error else Trial.TERMINATED
def restore(self, trial, checkpoint=None, block=False):
@ -856,6 +852,26 @@ class PopulationBasedTestingSuite(unittest.TestCase):
ray.shutdown()
_register_all() # re-register the evicted objects
# Helper function to call pbt.on_trial_result and assert decision,
# or trial status upon exiting.
# Need to have the `trial` in `RUNNING` status first.
def on_trial_result(self,
pbt,
runner,
trial,
result,
expected_decision=None):
trial.status = Trial.RUNNING
decision = pbt.on_trial_result(runner, trial, result)
if expected_decision is None:
pass
elif expected_decision == TrialScheduler.PAUSE:
self.assertTrue(trial.status == Trial.PAUSED
or decision == expected_decision)
elif expected_decision == TrialScheduler.CONTINUE:
self.assertEqual(decision, expected_decision)
return decision
def basicSetup(self,
num_trials=5,
resample_prob=0.0,
@ -900,13 +916,19 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trial = runner.trials[i]
if step_once:
if synch:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.PAUSE)
self.on_trial_result(
pbt,
runner,
trial,
result(10, 50 * i),
expected_decision=TrialScheduler.PAUSE)
else:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
self.on_trial_result(
pbt,
runner,
trial,
result(10, 50 * i),
expected_decision=TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
@ -929,12 +951,13 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Should error if training_iteration not in result dict.
with self.assertRaises(RuntimeError):
pbt.on_trial_result(
runner, trials[0], result={"episode_reward_mean": 4})
self.on_trial_result(
pbt, runner, trials[0], result={"episode_reward_mean": 4})
# Should error if episode_reward_mean not in result dict.
with self.assertRaises(RuntimeError):
pbt.on_trial_result(
self.on_trial_result(
pbt,
runner,
trials[0],
result={
@ -948,12 +971,13 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Should not error if training_iteration not in result dict
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
pbt.on_trial_result(
runner, trials[0], result={"episode_reward_mean": 4})
self.on_trial_result(
pbt, runner, trials[0], result={"episode_reward_mean": 4})
# Should not error if episode_reward_mean not in result dict.
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
pbt.on_trial_result(
self.on_trial_result(
pbt,
runner,
trials[0],
result={
@ -967,28 +991,24 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(15, 200),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# checkpoint: both past interval and upper quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, 200)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, 200),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 1)
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(30, 201)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[1], result(30, 201),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 2)
# not upper quantile any more
self.assertEqual(
pbt.on_trial_result(runner, trials[4], result(30, 199)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[4], result(30, 199),
TrialScheduler.CONTINUE)
self.assertEqual(pbt._num_checkpoints, 2)
self.assertEqual(pbt._num_perturbations, 0)
@ -998,24 +1018,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(15, 200),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# trials should be paused until all trials are synced.
for i in range(len(trials) - 1):
self.assertEqual(
pbt.on_trial_result(runner, trials[i], result(20, 200 + i)),
TrialScheduler.PAUSE)
self.on_trial_result(pbt, runner, trials[i], result(20, 200 + i),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200])
self.assertEqual(pbt._num_checkpoints, 0)
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(20, 204)),
TrialScheduler.PAUSE)
self.on_trial_result(pbt, runner, trials[-1], result(20, 204),
TrialScheduler.PAUSE)
self.assertEqual(pbt._num_checkpoints, 2)
def testPerturbsLowPerformingTrials(self):
@ -1023,26 +1040,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials = runner.get_trials()
# no perturbation: haven't hit next perturbation interval
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(15, -100),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
# perturb since it's lower quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertTrue("@perturbed" in trials[0].experiment_tag)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(pbt._num_perturbations, 1)
# also perturbed
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(20, 40)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[2], result(20, 40),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt._num_perturbations, 2)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
@ -1053,25 +1067,22 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials = runner.get_trials()
# no perturbation: haven't hit next perturbation interval
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(15, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[-1], result(15, -100),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
# Don't perturb until all trials are synched.
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(20, -100)),
TrialScheduler.PAUSE)
self.on_trial_result(pbt, runner, trials[-1], result(20, -100),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100])
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
# Synch all trials.
for i in range(len(trials) - 1):
self.assertEqual(
pbt.on_trial_result(runner, trials[i], result(20, -10 * i)),
TrialScheduler.PAUSE)
self.on_trial_result(pbt, runner, trials[i], result(20, -10 * i),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100])
self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"])
self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"])
@ -1080,9 +1091,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
def testPerturbWithoutResample(self):
pbt, runner = self.basicSetup(resample_prob=0.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertIn(trials[0].config["id_factor"], [100])
self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
@ -1094,9 +1104,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
def testPerturbWithResample(self):
pbt, runner = self.basicSetup(resample_prob=1.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertEqual(trials[0].status, Trial.PAUSED)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["id_factor"], 100)
self.assertEqual(trials[0].config["float_factor"], 100.0)
@ -1114,9 +1124,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
"id_factor": tune.choice([100])
})
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["id_factor"], 100)
self.assertEqual(trials[0].config["float_factor"], 100.0)
@ -1277,9 +1286,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
"int_factor": lambda: 10,
})
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["float_factor"], 100.0)
self.assertIsInstance(trials[0].config["float_factor"], float)
@ -1293,21 +1301,19 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
trials[0].status = Trial.PENDING # simulate not enough resources
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
TrialScheduler.PAUSE)
self.on_trial_result(pbt, runner, trials[1], result(20, 1000),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
def testSchedulesMostBehindTrialToRun(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(800, 1000))
pbt.on_trial_result(runner, trials[1], result(700, 1001))
pbt.on_trial_result(runner, trials[2], result(600, 1002))
pbt.on_trial_result(runner, trials[3], result(500, 1003))
pbt.on_trial_result(runner, trials[4], result(700, 1004))
self.on_trial_result(pbt, runner, trials[0], result(800, 1000))
self.on_trial_result(pbt, runner, trials[1], result(700, 1001))
self.on_trial_result(pbt, runner, trials[2], result(600, 1002))
self.on_trial_result(pbt, runner, trials[3], result(500, 1003))
self.on_trial_result(pbt, runner, trials[4], result(700, 1004))
self.assertEqual(pbt.choose_trial_to_run(runner), None)
for i in range(5):
trials[i].status = Trial.PENDING
@ -1317,35 +1323,35 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt, runner = self.basicSetup(synch=True)
trials = runner.get_trials()
runner.process_action(
trials[0], pbt.on_trial_result(runner, trials[0], result(
800, 1000)))
trials[0],
self.on_trial_result(pbt, runner, trials[0], result(800, 1000)))
runner.process_action(
trials[1], pbt.on_trial_result(runner, trials[1], result(
700, 1001)))
trials[1],
self.on_trial_result(pbt, runner, trials[1], result(700, 1001)))
runner.process_action(
trials[2], pbt.on_trial_result(runner, trials[2], result(
600, 1002)))
trials[2],
self.on_trial_result(pbt, runner, trials[2], result(600, 1002)))
runner.process_action(
trials[3], pbt.on_trial_result(runner, trials[3], result(
500, 1003)))
trials[3],
self.on_trial_result(pbt, runner, trials[3], result(500, 1003)))
runner.process_action(
trials[4], pbt.on_trial_result(runner, trials[4], result(
700, 1004)))
trials[4],
self.on_trial_result(pbt, runner, trials[4], result(700, 1004)))
self.assertIn(
pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]])
def testPerturbationResetsLastPerturbTime(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(10000, 1005))
pbt.on_trial_result(runner, trials[1], result(10000, 1004))
pbt.on_trial_result(runner, trials[2], result(600, 1003))
self.on_trial_result(pbt, runner, trials[0], result(10000, 1005))
self.on_trial_result(pbt, runner, trials[1], result(10000, 1004))
self.on_trial_result(pbt, runner, trials[2], result(600, 1003))
self.assertEqual(pbt._num_perturbations, 0)
pbt.on_trial_result(runner, trials[3], result(500, 1002))
self.on_trial_result(pbt, runner, trials[3], result(500, 1002))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(600, 100))
self.on_trial_result(pbt, runner, trials[3], result(600, 100))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(11000, 100))
self.on_trial_result(pbt, runner, trials[3], result(11000, 100))
self.assertEqual(pbt._num_perturbations, 2)
def testLogConfig(self):
@ -1376,9 +1382,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {TRAINING_ITERATION: i}
pbt.on_trial_result(runner, trials[0], result(15, -100))
pbt.on_trial_result(runner, trials[0], result(20, -100))
pbt.on_trial_result(runner, trials[2], result(20, 40))
self.on_trial_result(pbt, runner, trials[0], result(15, -100))
self.on_trial_result(pbt, runner, trials[0], result(20, -100))
self.on_trial_result(pbt, runner, trials[2], result(20, 40))
log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_2.txt"]
for log_file in log_files:
self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
@ -1416,7 +1422,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {TRAINING_ITERATION: i}
pbt.on_trial_result(runner, trials[i], result(10, i))
self.on_trial_result(pbt, runner, trials[i], result(10, i))
log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"]
for log_file in log_files:
self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
@ -1475,7 +1481,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trial_state[k].forward(res[TRAINING_ITERATION])
old_config = trials[k].config
pbt.on_trial_result(runner, trials[k], res)
self.on_trial_result(pbt, runner, trials[k], res)
new_config = trials[k].config
trial_state[k].config = new_config.copy()
@ -1538,15 +1544,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):
TRAINING_ITERATION: self.iter
}
def reset_config(self, new_config):
self.config = new_config
return True
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(
json.dumps({
"iter": self.iter,
"replayed": self.replayed
}))
return path
def save_checkpoint(self, tmp_checkpoint_dir):
return tmp_checkpoint_dir
def load_checkpoint(self, checkpoint):
pass
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
checkpoint_json = json.loads(f.read())
self.iter = checkpoint_json["iter"]
self.replayed = checkpoint_json["replayed"]
# Loop through all trials and check if PBT history is the
# same as the playback history
@ -1626,15 +1638,14 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials[k].last_result = res
trial_state[k].forward(res[TRAINING_ITERATION])
trials[k].status = Trial.RUNNING
if not synced:
action = pbt.on_trial_result(runner, trials[k], res)
action = self.on_trial_result(pbt, runner, trials[k], res)
runner.process_action(trials[k], action)
return
else:
# Reached synchronization point
old_configs = [trial.config for trial in trials]
action = pbt.on_trial_result(runner, trials[k], res)
action = self.on_trial_result(pbt, runner, trials[k], res)
runner.process_action(trials[k], action)
new_configs = [trial.config for trial in trials]
@ -1708,15 +1719,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):
TRAINING_ITERATION: self.iter
}
def reset_config(self, new_config):
self.config = new_config
return True
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(
json.dumps({
"iter": self.iter,
"replayed": self.replayed
}))
return path
def save_checkpoint(self, tmp_checkpoint_dir):
return tmp_checkpoint_dir
def load_checkpoint(self, checkpoint):
pass
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
checkpoint_json = json.loads(f.read())
self.iter = checkpoint_json["iter"]
self.replayed = checkpoint_json["replayed"]
# Loop through all trials and check if PBT history is the
# same as the playback history
@ -1755,9 +1772,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(20, -100),
TrialScheduler.PAUSE)
self.assertEqual(trials[0].config["id_factor"], 42)
self.assertEqual(trials[0].config["float_factor"], 43)
@ -1770,10 +1786,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {}
pbt.on_trial_result(runner, trials[0], result(1, 10))
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(1, 200)),
TrialScheduler.CONTINUE)
self.on_trial_result(pbt, runner, trials[0], result(1, 10))
self.on_trial_result(pbt, runner, trials[2], result(1, 200),
TrialScheduler.CONTINUE)
self.assertEqual(pbt._num_checkpoints, 1)
pbt._exploit(runner.trial_executor, trials[1], trials[2])
@ -1879,6 +1894,12 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
f.write("OK")
return checkpoint
def reset_config(self, config):
return True
def load_checkpoint(self, checkpoint):
pass
trial_hyperparams = {
"float_factor": 2.0,
"const_factor": 3,
@ -1914,6 +1935,9 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
def load_checkpoint(self, state):
self.state = state
def reset_config(self, config):
return True
trial_hyperparams = {
"float_factor": 2.0,
"const_factor": 3,


@ -174,7 +174,6 @@ class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
class PopulationBasedTrainingSynchTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
ray.init(num_cpus=2)
def MockTrainingFuncSync(config, checkpoint_dir=None):
@ -210,7 +209,7 @@ class PopulationBasedTrainingSynchTest(unittest.TestCase):
def synchSetup(self, synch, param=None):
if param is None:
param = [10, 20, 30]
param = [10, 20, 40]
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
@ -244,14 +243,14 @@ class PopulationBasedTrainingSynchTest(unittest.TestCase):
self.assertTrue(
any(
analysis.dataframe(metric="mean_accuracy", mode="max")
["mean_accuracy"] != 33))
["mean_accuracy"] != 43))
def testSynchPass(self):
analysis = self.synchSetup(True)
self.assertTrue(
all(
analysis.dataframe(metric="mean_accuracy", mode="max")[
"mean_accuracy"] == 33))
"mean_accuracy"] == 43))
def testSynchPassLast(self):
analysis = self.synchSetup(True, param=[30, 20, 10])


@ -190,8 +190,7 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
def stop_trial(self,
trial: Trial,
error: bool = False,
error_msg: Optional[str] = None,
destroy_pg_if_cannot_replace: bool = True) -> None:
error_msg: Optional[str] = None) -> None:
"""Stops the trial.
Stops this trial, releasing all allocated resources.
@ -201,8 +200,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
Args:
error (bool): Whether to mark this trial as terminated in error.
error_msg (str): Optional error message.
destroy_pg_if_cannot_replace (bool): Whether the trial's placement
group should be destroyed if it cannot replace any staged ones.
"""
pass


@ -1176,6 +1176,8 @@ class TrialRunner:
elif decision == TrialScheduler.STOP:
self.trial_executor.export_trial_if_needed(trial)
self.trial_executor.stop_trial(trial)
elif decision == TrialScheduler.NOOP:
pass
else:
raise ValueError("Invalid decision: {}".format(decision))


@ -23,6 +23,8 @@ from ray.tune.progress_reporter import (detect_reporter, ProgressReporter,
JupyterNotebookReporter)
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls
from ray.tune.schedulers import (PopulationBasedTraining,
PopulationBasedTrainingReplay)
from ray.tune.stopper import Stopper
from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm, \
SearchGenerator
@ -413,6 +415,13 @@ def run(
f"to 1 instead.")
result_buffer_length = 1
if isinstance(scheduler,
(PopulationBasedTraining,
PopulationBasedTrainingReplay)) and not reuse_actors:
warnings.warn(
"Consider boosting PBT performance by enabling `reuse_actors` as "
"well as implementing `reset_config` for Trainable.")
trial_executor = trial_executor or RayTrialExecutor(
reuse_actors=reuse_actors, result_buffer_length=result_buffer_length)
if isinstance(run_or_experiment, list):
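For context, a hedged usage sketch of the pattern this new warning nudges users toward: a class-based Trainable that implements reset_config, run with reuse_actors=True. The class name, hyperparameters, and stopping criterion below are illustrative, not part of this commit.

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    class MyTrainable(tune.Trainable):
        # Minimal stand-in trainable for illustration.
        def setup(self, config):
            self.lr = config["lr"]

        def step(self):
            return {"mean_accuracy": self.lr * self.iteration}

        def save_checkpoint(self, checkpoint_dir):
            return checkpoint_dir

        def load_checkpoint(self, checkpoint):
            pass

        def reset_config(self, new_config):
            # Implementing reset_config lets reuse_actors skip actor restarts
            # when PBT replaces a trial's config.
            self.lr = new_config["lr"]
            return True

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=4,
        hyperparam_mutations={"lr": tune.loguniform(1e-4, 1e-1)},
    )

    tune.run(
        MyTrainable,
        config={"lr": tune.loguniform(1e-4, 1e-1)},
        stop={"training_iteration": 20},
        num_samples=4,
        scheduler=pbt,
        reuse_actors=True,  # avoids the warning added in this commit
    )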


@ -607,37 +607,17 @@ class PlacementGroupManager:
self._ready[pgf].add(pg)
return True
def return_pg(self,
trial: "Trial",
destroy_pg_if_cannot_replace: bool = True):
"""Return pg, making it available for other trials to use.
If destroy_pg_if_cannot_replace is True, this will only return
a placement group if a staged placement group can be replaced
by it. If not, it will destroy the placement group.
def return_pg(self, trial: "Trial"):
"""Return pg back to Core scheduling.
Args:
trial (Trial): Return placement group of this trial.
Returns:
Boolean indicating if the placement group was returned.
"""
pgf = trial.placement_group_factory
pg = self._in_use_trials.pop(trial)
self._in_use_pgs.pop(pg)
if destroy_pg_if_cannot_replace:
staged_pg = self._unstage_unused_pg(pgf)
# Could not replace
if not staged_pg:
self.remove_pg(pg)
return False
self.remove_pg(staged_pg)
self._ready[pgf].add(pg)
return True
self.remove_pg(pg)
def _unstage_unused_pg(
self, pgf: PlacementGroupFactory) -> Optional[PlacementGroup]: