[Tune] Remove pg caching with fixes to pbt scheduler (#20403)
This reverts commit f13c2a5350.
Re-lands the removal of the PG caching logic.
As a result, the PBT scheduler can no longer stop and start a trial within itself for weight transfer and perturbation. This requires the following changes to the PBT scheduler:
1. The trial being perturbed is always left in a PAUSED state upon exiting on_trial_result. Instead of maintaining two separate paths for replacing a trial, we consolidate on always "stop" and "restore", relying on `reuse_actors` as an optimization when available (see 2).
2. Consolidates PBT's trial replacement with the `reuse_actors` path.
3. Introduces a NOOP scheduler decision to indicate that the (PBT) scheduler has finished its interaction with the executor, so no further decision is needed in the Tune loop.
Long term, we should tighten the interface between the scheduler and the executor. For example, on_trial_result taking in the whole runner is more API exposure than we want, and should be removed.
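Since the new fast path hinges on actor reuse, a minimal usage sketch follows for context. This is not part of the diff: `MyTrainable` is a placeholder for a user-defined Trainable, and the scheduler arguments simply follow the public Ray Tune API.

# Sketch only: enabling actor reuse so PBT perturbations avoid paying a
# full actor stop/start. `MyTrainable` is a hypothetical user Trainable
# that implements `reset_config`.
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    perturbation_interval=4,
    hyperparam_mutations={"lr": tune.loguniform(1e-4, 1e-1)},
)

tune.run(
    MyTrainable,        # placeholder Trainable with `reset_config` implemented
    scheduler=pbt,
    reuse_actors=True,  # reuse the existing actor instead of stop/restore
    num_samples=4,
)

With `reuse_actors=True` and `reset_config` implemented, the executor can hand a perturbed trial's new config to the existing actor rather than tearing it down and starting a fresh one.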
Parent: 97b4490401
Commit: 96b44adf67
12 changed files with 256 additions and 283 deletions
@@ -312,7 +312,7 @@ py_test(
 py_test(
     name = "test_trial_scheduler_pbt",
-    size = "medium",
+    size = "large",
     srcs = ["tests/test_trial_scheduler_pbt.py"],
     deps = [":tune_lib"],
     tags = ["team:ml", "exclusive", "flaky", "tests_dir_T"],
 )
@@ -405,7 +405,7 @@ py_test(
 py_test(
     name = "blendsearch_example",
-    size = "small",
+    size = "medium",
     srcs = ["examples/blendsearch_example.py"],
     deps = [":tune_lib"],
     tags = ["team:ml", "exclusive", "example"],
 )
@@ -220,7 +220,7 @@ def tune_xgboost(use_class_trainable=True):
             total_available_cpus // len(trial_runner.get_live_trials()))

         # Assign new CPUs to the trial in a PlacementGroupFactory
-        return PlacementGroupFactory([{"CPU": cpu_to_use}])
+        return PlacementGroupFactory([{"CPU": cpu_to_use, "GPU": 0}])

     # You can either define your own resources_allocation_function, or
     # use the default one - evenly_distribute_cpus_gpus
@@ -502,21 +502,13 @@ class RayTrialExecutor(TrialExecutor):
             logger.exception(
                 "Trial %s: updating resources timed out.", trial)

-    def _stop_trial(self,
-                    trial: Trial,
-                    error=False,
-                    error_msg=None,
-                    destroy_pg_if_cannot_replace=True):
+    def _stop_trial(self, trial: Trial, error=False, error_msg=None):
         """Stops this trial.

         Stops this trial, releasing all allocating resources. If stopping the
         trial fails, the run will be marked as terminated in error, but no
         exception will be thrown.

-        If the placement group will be used right away
-        (destroy_pg_if_cannot_replace=False), we do not remove its placement
-        group (or a surrogate placement group).
-
         Args:
             error (bool): Whether to mark this trial as terminated in error.
             error_msg (str): Optional error message.
@@ -555,8 +547,7 @@ class RayTrialExecutor(TrialExecutor):
             logger.debug("Trial %s: Destroying actor.", trial)

             # Try to return the placement group for other trials to use
-            self._pg_manager.return_pg(trial,
-                                       destroy_pg_if_cannot_replace)
+            self._pg_manager.return_pg(trial)

             with self._change_working_directory(trial):
                 self._trial_cleanup.add(trial, actor=trial.runner)
@@ -614,18 +605,9 @@ class RayTrialExecutor(TrialExecutor):
     def stop_trial(self,
                    trial: Trial,
                    error: bool = False,
-                   error_msg: Optional[str] = None,
-                   destroy_pg_if_cannot_replace: bool = True) -> None:
-        """Only returns resources if resources allocated.
-
-        If destroy_pg_if_cannot_replace is False, the Trial placement group
-        will not be removed if it can't replace any staging ones."""
+                   error_msg: Optional[str] = None) -> None:
         prior_status = trial.status
-        self._stop_trial(
-            trial,
-            error=error,
-            error_msg=error_msg,
-            destroy_pg_if_cannot_replace=destroy_pg_if_cannot_replace)
+        self._stop_trial(trial, error=error, error_msg=error_msg)
         if prior_status == Trial.RUNNING:
             logger.debug("Trial %s: Returning resources.", trial)
             out = self._find_item(self._running, trial)
@@ -1,6 +1,6 @@
 import copy
-import logging
 import json
+import logging
 import math
 import os
 import random
@@ -397,13 +397,15 @@ class PopulationBasedTraining(FIFOScheduler):
         if not self._synch:
             state.last_perturbation_time = time
             lower_quantile, upper_quantile = self._quantiles()
-            self._perturb_trial(trial, trial_runner, upper_quantile,
-                                lower_quantile)
-            for trial in trial_runner.get_trials():
-                if trial.status in [Trial.PENDING, Trial.PAUSED]:
-                    return TrialScheduler.PAUSE  # yield time to other trials
-
-            return TrialScheduler.CONTINUE
+            decision = TrialScheduler.CONTINUE
+            for other_trial in trial_runner.get_trials():
+                if other_trial.status in [Trial.PENDING, Trial.PAUSED]:
+                    decision = TrialScheduler.PAUSE
+                    break
+            self._checkpoint_or_exploit(trial, trial_runner.trial_executor,
+                                        upper_quantile, lower_quantile)
+            return (TrialScheduler.NOOP
+                    if trial.status == Trial.PAUSED else decision)
         else:
             # Synchronous mode.
             if any(self._trial_state[t].last_train_time <
@@ -425,12 +427,12 @@ class PopulationBasedTraining(FIFOScheduler):
                 for t in all_trials:
                     logger.debug("Perturbing Trial {}".format(t))
                     self._trial_state[t].last_perturbation_time = time
-                    self._perturb_trial(t, trial_runner, upper_quantile,
-                                        lower_quantile)
+                    self._checkpoint_or_exploit(t, trial_runner.trial_executor,
+                                                upper_quantile, lower_quantile)

                 all_train_times = [
-                    self._trial_state[trial].last_train_time
-                    for trial in trial_runner.get_trials()
+                    self._trial_state[t].last_train_time
+                    for t in trial_runner.get_trials()
                 ]
                 max_last_train_time = max(all_train_times)
                 self._next_perturbation_sync = max(
@@ -441,7 +443,8 @@ class PopulationBasedTraining(FIFOScheduler):
             # still all be paused.
             # choose_trial_to_run will then pick the next trial to run out of
             # the paused trials.
-            return TrialScheduler.PAUSE
+            return (TrialScheduler.NOOP
+                    if trial.status == Trial.PAUSED else TrialScheduler.PAUSE)

     def _save_trial_state(self, state: PBTTrialState, time: int, result: Dict,
                           trial: Trial):
@@ -462,9 +465,10 @@ class PopulationBasedTraining(FIFOScheduler):

         return score

-    def _perturb_trial(
-            self, trial: Trial, trial_runner: "trial_runner.TrialRunner",
-            upper_quantile: List[Trial], lower_quantile: List[Trial]):
+    def _checkpoint_or_exploit(self, trial: Trial,
+                               trial_executor: "trial_runner.RayTrialExecutor",
+                               upper_quantile: List[Trial],
+                               lower_quantile: List[Trial]):
         """Checkpoint if in upper quantile, exploits if in lower."""
         state = self._trial_state[trial]
         if trial in upper_quantile:
@@ -476,7 +480,7 @@ class PopulationBasedTraining(FIFOScheduler):
                 # Paused trial will always have an in-memory checkpoint.
                 state.last_checkpoint = trial.checkpoint
             else:
-                state.last_checkpoint = trial_runner.trial_executor.save(
+                state.last_checkpoint = trial_executor.save(
                     trial, Checkpoint.MEMORY, result=state.last_result)
             self._num_checkpoints += 1
         else:
@@ -490,7 +494,7 @@ class PopulationBasedTraining(FIFOScheduler):
                 logger.info("[pbt]: no checkpoint for trial."
                             " Skip exploit for Trial {}".format(trial))
                 return
-            self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
+            self._exploit(trial_executor, trial, trial_to_clone)

     def _log_config_on_step(self, trial_state: PBTTrialState,
                             new_state: PBTTrialState, trial: Trial,
@@ -571,36 +575,12 @@ class PopulationBasedTraining(FIFOScheduler):
                 raise TuneError("Trials should be paused here only if in "
                                 "synchronous mode. If you encounter this error"
                                 " please raise an issue on Ray Github.")
-            trial.set_experiment_tag(new_tag)
-            trial.set_config(new_config)
-            trial.on_checkpoint(new_state.last_checkpoint)
         else:
-            # If trial is running, we first try to reset it.
-            # If that is unsuccessful, then we have to stop it and start it
-            # again with a new checkpoint.
-            reset_successful = trial_executor.reset_trial(
-                trial, new_config, new_tag)
-            # TODO(ujvl): Refactor Scheduler abstraction to abstract
-            # mechanism for trial restart away. We block on restore
-            # and suppress train on start as a stop-gap fix to
-            # https://github.com/ray-project/ray/issues/7258.
-            if reset_successful:
-                trial_executor.restore(
-                    trial, new_state.last_checkpoint, block=True)
-            else:
-                # Stop trial, but do not free resources (so we can use them
-                # again right away)
-                trial_executor.stop_trial(
-                    trial, destroy_pg_if_cannot_replace=False)
-                trial.set_experiment_tag(new_tag)
-                trial.set_config(new_config)
-
-                if not trial_executor.start_trial(
-                        trial, new_state.last_checkpoint, train=False):
-                    logger.warning(
-                        f"Trial couldn't be reset: {trial}. Terminating "
-                        f"instead.")
-                    trial_executor.stop_trial(trial, error=True)
+            trial_executor.stop_trial(trial)
+            trial_executor.set_status(trial, Trial.PAUSED)
+        trial.set_experiment_tag(new_tag)
+        trial.set_config(new_config)
+        trial.on_checkpoint(new_state.last_checkpoint)

         self._num_perturbations += 1
         # Transfer over the last perturbation time as well
@@ -651,10 +631,12 @@ class PopulationBasedTraining(FIFOScheduler):
             key=lambda trial: self._trial_state[trial].last_train_time)
         return candidates[0] if candidates else None

+    # Unit test only. TODO(xwjiang): Remove test-specific APIs.
     def reset_stats(self):
         self._num_perturbations = 0
         self._num_checkpoints = 0

+    # Unit test only. TODO(xwjiang): Remove test-specific APIs.
     def last_scores(self, trials: List[Trial]) -> List[float]:
         scores = []
         for trial in trials:
@@ -815,23 +797,17 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
                                  new_config)

         trial_executor = trial_runner.trial_executor
-        reset_successful = trial_executor.reset_trial(trial, new_config,
-                                                      new_tag)
-
-        if reset_successful:
-            trial_executor.restore(trial, checkpoint, block=True)
-        else:
-            trial_executor.stop_trial(
-                trial, destroy_pg_if_cannot_replace=False)
-            trial.set_experiment_tag(new_tag)
-            trial.set_config(new_config)
-            trial_executor.start_trial(trial, checkpoint, train=False)
+        trial_executor.stop_trial(trial)
+        trial_executor.set_status(trial, Trial.PAUSED)
+        trial.set_experiment_tag(new_tag)
+        trial.set_config(new_config)
+        trial.on_checkpoint(checkpoint)

         self.current_config = new_config
         self._num_perturbations += 1
         self._next_policy = next(self._policy_iter, None)

-        return TrialScheduler.CONTINUE
+        return TrialScheduler.NOOP

     def debug_string(self) -> str:
         return "PopulationBasedTraining replay: Step {}, perturb {}".format(
@@ -11,6 +11,11 @@ class TrialScheduler:
     CONTINUE = "CONTINUE"  #: Status for continuing trial execution
     PAUSE = "PAUSE"  #: Status for pausing trial execution
     STOP = "STOP"  #: Status for stopping trial execution
+    # Caution: Temporary and anti-pattern! This means Scheduler calls
+    # into Executor directly without going through TrialRunner.
+    # TODO(xwjiang): Deprecate this after we control the interaction
+    # between schedulers and executor.
+    NOOP = "NOOP"

     _metric = None
@@ -314,49 +314,48 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
     runner.step()


-# TODO(xwjiang): Uncomment this when pg caching is removed.
-# @pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
-# def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
-#     """Removing a node in full cluster causes Trial to be requeued."""
-#     os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
-#
-#     cluster = start_connected_emptyhead_cluster
-#     node = cluster.add_node(num_cpus=1)
-#     cluster.wait_for_nodes()
-#
-#     syncer_callback = _PerTrialSyncerCallback(
-#         lambda trial: trial.trainable_name == "__fake")
-#     runner = TrialRunner(
-#         BasicVariantGenerator(), callbacks=[syncer_callback])  # noqa
-#     kwargs = {
-#         "stopping_criterion": {
-#             "training_iteration": 5
-#         },
-#         "checkpoint_freq": 1,
-#         "max_failures": 1,
-#     }
-#
-#     if trainable_id == "__fake_durable":
-#         kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
-#
-#     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
-#     for t in trials:
-#         runner.add_trial(t)
-#
-#     runner.step()  # Start trial
-#     runner.step()  # Process result, dispatch save
-#     runner.step()  # Process save
-#
-#     running_trials = _get_running_trials(runner)
-#     assert len(running_trials) == 1
-#     assert _check_trial_running(running_trials[0])
-#     cluster.remove_node(node)
-#     cluster.wait_for_nodes()
-#     time.sleep(0.1)  # Sleep so that next step() refreshes cluster resources
-#     runner.step()  # Process result, dispatch save
-#     runner.step()  # Process save (detect error), requeue trial
-#     assert all(
-#         t.status == Trial.PENDING for t in trials), runner.debug_string()
+@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
+def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
+    """Removing a node in full cluster causes Trial to be requeued."""
+    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
+
+    cluster = start_connected_emptyhead_cluster
+    node = cluster.add_node(num_cpus=1)
+    cluster.wait_for_nodes()
+
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(
+        BasicVariantGenerator(), callbacks=[syncer_callback])  # noqa
+    kwargs = {
+        "stopping_criterion": {
+            "training_iteration": 5
+        },
+        "checkpoint_freq": 1,
+        "max_failures": 1,
+    }
+
+    if trainable_id == "__fake_durable":
+        kwargs["remote_checkpoint_dir"] = MOCK_REMOTE_DIR
+
+    trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
+    for t in trials:
+        runner.add_trial(t)
+
+    runner.step()  # Start trial
+    runner.step()  # Process result, dispatch save
+    runner.step()  # Process save
+
+    running_trials = _get_running_trials(runner)
+    assert len(running_trials) == 1
+    assert _check_trial_running(running_trials[0])
+    cluster.remove_node(node)
+    cluster.wait_for_nodes()
+    time.sleep(0.1)  # Sleep so that next step() refreshes cluster resources
+    runner.step()  # Process result, dispatch save
+    runner.step()  # Process save (detect error), requeue trial
+    assert all(
+        t.status == Trial.PENDING for t in trials), runner.debug_string()


 @pytest.mark.parametrize("trainable_id", ["__fake_remote", "__fake_durable"])
@@ -225,11 +225,7 @@ class _MockTrialExecutor(TrialExecutor):
         trial.status = Trial.RUNNING
         return True

-    def stop_trial(self,
-                   trial,
-                   error=False,
-                   error_msg=None,
-                   destroy_pg_if_cannot_replace=True):
+    def stop_trial(self, trial, error=False, error_msg=None):
         trial.status = Trial.ERROR if error else Trial.TERMINATED

     def restore(self, trial, checkpoint=None, block=False):
@@ -856,6 +852,26 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         ray.shutdown()
         _register_all()  # re-register the evicted objects

+    # Helper function to call pbt.on_trial_result and assert the decision,
+    # or the trial status upon exiting.
+    # Need to have the `trial` in `RUNNING` status first.
+    def on_trial_result(self,
+                        pbt,
+                        runner,
+                        trial,
+                        result,
+                        expected_decision=None):
+        trial.status = Trial.RUNNING
+        decision = pbt.on_trial_result(runner, trial, result)
+        if expected_decision is None:
+            pass
+        elif expected_decision == TrialScheduler.PAUSE:
+            self.assertTrue(trial.status == Trial.PAUSED
+                            or decision == expected_decision)
+        elif expected_decision == TrialScheduler.CONTINUE:
+            self.assertEqual(decision, expected_decision)
+        return decision
+
     def basicSetup(self,
                    num_trials=5,
                    resample_prob=0.0,
@@ -900,13 +916,19 @@ class PopulationBasedTestingSuite(unittest.TestCase):
             trial = runner.trials[i]
             if step_once:
                 if synch:
-                    self.assertEqual(
-                        pbt.on_trial_result(runner, trial, result(10, 50 * i)),
-                        TrialScheduler.PAUSE)
+                    self.on_trial_result(
+                        pbt,
+                        runner,
+                        trial,
+                        result(10, 50 * i),
+                        expected_decision=TrialScheduler.PAUSE)
                 else:
-                    self.assertEqual(
-                        pbt.on_trial_result(runner, trial, result(10, 50 * i)),
-                        TrialScheduler.CONTINUE)
+                    self.on_trial_result(
+                        pbt,
+                        runner,
+                        trial,
+                        result(10, 50 * i),
+                        expected_decision=TrialScheduler.CONTINUE)
         pbt.reset_stats()
         return pbt, runner
@@ -929,12 +951,13 @@ class PopulationBasedTestingSuite(unittest.TestCase):

         # Should error if training_iteration not in result dict.
         with self.assertRaises(RuntimeError):
-            pbt.on_trial_result(
-                runner, trials[0], result={"episode_reward_mean": 4})
+            self.on_trial_result(
+                pbt, runner, trials[0], result={"episode_reward_mean": 4})

         # Should error if episode_reward_mean not in result dict.
         with self.assertRaises(RuntimeError):
-            pbt.on_trial_result(
+            self.on_trial_result(
+                pbt,
                 runner,
                 trials[0],
                 result={
@@ -948,12 +971,13 @@ class PopulationBasedTestingSuite(unittest.TestCase):

         # Should not error if training_iteration not in result dict
         with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
-            pbt.on_trial_result(
-                runner, trials[0], result={"episode_reward_mean": 4})
+            self.on_trial_result(
+                pbt, runner, trials[0], result={"episode_reward_mean": 4})

         # Should not error if episode_reward_mean not in result dict.
         with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
-            pbt.on_trial_result(
+            self.on_trial_result(
+                pbt,
                 runner,
                 trials[0],
                 result={
@@ -967,28 +991,24 @@ class PopulationBasedTestingSuite(unittest.TestCase):

         # no checkpoint: haven't hit next perturbation interval yet
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(15, 200)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(15, 200),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
         self.assertEqual(pbt._num_checkpoints, 0)

         # checkpoint: both past interval and upper quantile
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, 200)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, 200),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
         self.assertEqual(pbt._num_checkpoints, 1)
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[1], result(30, 201)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[1], result(30, 201),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
         self.assertEqual(pbt._num_checkpoints, 2)

         # not upper quantile any more
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[4], result(30, 199)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[4], result(30, 199),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt._num_checkpoints, 2)
         self.assertEqual(pbt._num_perturbations, 0)
@@ -998,24 +1018,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):

         # no checkpoint: haven't hit next perturbation interval yet
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(15, 200)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(15, 200),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
         self.assertEqual(pbt._num_checkpoints, 0)

         # trials should be paused until all trials are synced.
         for i in range(len(trials) - 1):
-            self.assertEqual(
-                pbt.on_trial_result(runner, trials[i], result(20, 200 + i)),
-                TrialScheduler.PAUSE)
+            self.on_trial_result(pbt, runner, trials[i], result(20, 200 + i),
+                                 TrialScheduler.PAUSE)

         self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200])
         self.assertEqual(pbt._num_checkpoints, 0)

-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[-1], result(20, 204)),
-            TrialScheduler.PAUSE)
+        self.on_trial_result(pbt, runner, trials[-1], result(20, 204),
+                             TrialScheduler.PAUSE)
         self.assertEqual(pbt._num_checkpoints, 2)

     def testPerturbsLowPerformingTrials(self):
@@ -1023,26 +1040,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         trials = runner.get_trials()

         # no perturbation: haven't hit next perturbation interval
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(15, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(15, -100),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
         self.assertTrue("@perturbed" not in trials[0].experiment_tag)
         self.assertEqual(pbt._num_perturbations, 0)

         # perturb since it's lower quantile
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
         self.assertTrue("@perturbed" in trials[0].experiment_tag)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
         self.assertEqual(pbt._num_perturbations, 1)

         # also perturbed
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[2], result(20, 40)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[2], result(20, 40),
+                             TrialScheduler.PAUSE)
         self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
         self.assertEqual(pbt._num_perturbations, 2)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
@@ -1053,25 +1067,22 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         trials = runner.get_trials()

         # no perturbation: haven't hit next perturbation interval
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[-1], result(15, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[-1], result(15, -100),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
         self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
         self.assertEqual(pbt._num_perturbations, 0)

         # Don't perturb until all trials are synched.
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[-1], result(20, -100)),
-            TrialScheduler.PAUSE)
+        self.on_trial_result(pbt, runner, trials[-1], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100])
         self.assertTrue("@perturbed" not in trials[-1].experiment_tag)

         # Synch all trials.
         for i in range(len(trials) - 1):
-            self.assertEqual(
-                pbt.on_trial_result(runner, trials[i], result(20, -10 * i)),
-                TrialScheduler.PAUSE)
+            self.on_trial_result(pbt, runner, trials[i], result(20, -10 * i),
+                                 TrialScheduler.PAUSE)
         self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100])
         self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"])
         self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"])
@@ -1080,9 +1091,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
     def testPerturbWithoutResample(self):
         pbt, runner = self.basicSetup(resample_prob=0.0)
         trials = runner.get_trials()
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
         self.assertIn(trials[0].config["id_factor"], [100])
         self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
@@ -1094,9 +1104,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
     def testPerturbWithResample(self):
        pbt, runner = self.basicSetup(resample_prob=1.0)
         trials = runner.get_trials()
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
+        self.assertEqual(trials[0].status, Trial.PAUSED)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
         self.assertEqual(trials[0].config["id_factor"], 100)
         self.assertEqual(trials[0].config["float_factor"], 100.0)
@@ -1114,9 +1124,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                 "id_factor": tune.choice([100])
             })
         trials = runner.get_trials()
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
         self.assertEqual(trials[0].config["id_factor"], 100)
         self.assertEqual(trials[0].config["float_factor"], 100.0)
@@ -1277,9 +1286,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                 "int_factor": lambda: 10,
             })
         trials = runner.get_trials()
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
         self.assertEqual(trials[0].config["float_factor"], 100.0)
         self.assertIsInstance(trials[0].config["float_factor"], float)
@@ -1293,21 +1301,19 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         pbt, runner = self.basicSetup()
         trials = runner.get_trials()
         trials[0].status = Trial.PENDING  # simulate not enough resources

-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[1], result(20, 1000)),
-            TrialScheduler.PAUSE)
+        self.on_trial_result(pbt, runner, trials[1], result(20, 1000),
+                             TrialScheduler.PAUSE)
         self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
         self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])

     def testSchedulesMostBehindTrialToRun(self):
         pbt, runner = self.basicSetup()
         trials = runner.get_trials()
-        pbt.on_trial_result(runner, trials[0], result(800, 1000))
-        pbt.on_trial_result(runner, trials[1], result(700, 1001))
-        pbt.on_trial_result(runner, trials[2], result(600, 1002))
-        pbt.on_trial_result(runner, trials[3], result(500, 1003))
-        pbt.on_trial_result(runner, trials[4], result(700, 1004))
+        self.on_trial_result(pbt, runner, trials[0], result(800, 1000))
+        self.on_trial_result(pbt, runner, trials[1], result(700, 1001))
+        self.on_trial_result(pbt, runner, trials[2], result(600, 1002))
+        self.on_trial_result(pbt, runner, trials[3], result(500, 1003))
+        self.on_trial_result(pbt, runner, trials[4], result(700, 1004))
         self.assertEqual(pbt.choose_trial_to_run(runner), None)
         for i in range(5):
             trials[i].status = Trial.PENDING
@@ -1317,35 +1323,35 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         pbt, runner = self.basicSetup(synch=True)
         trials = runner.get_trials()
         runner.process_action(
-            trials[0], pbt.on_trial_result(runner, trials[0], result(
-                800, 1000)))
+            trials[0],
+            self.on_trial_result(pbt, runner, trials[0], result(800, 1000)))
         runner.process_action(
-            trials[1], pbt.on_trial_result(runner, trials[1], result(
-                700, 1001)))
+            trials[1],
+            self.on_trial_result(pbt, runner, trials[1], result(700, 1001)))
         runner.process_action(
-            trials[2], pbt.on_trial_result(runner, trials[2], result(
-                600, 1002)))
+            trials[2],
+            self.on_trial_result(pbt, runner, trials[2], result(600, 1002)))
         runner.process_action(
-            trials[3], pbt.on_trial_result(runner, trials[3], result(
-                500, 1003)))
+            trials[3],
+            self.on_trial_result(pbt, runner, trials[3], result(500, 1003)))
         runner.process_action(
-            trials[4], pbt.on_trial_result(runner, trials[4], result(
-                700, 1004)))
+            trials[4],
+            self.on_trial_result(pbt, runner, trials[4], result(700, 1004)))
         self.assertIn(
             pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]])

     def testPerturbationResetsLastPerturbTime(self):
         pbt, runner = self.basicSetup()
         trials = runner.get_trials()
-        pbt.on_trial_result(runner, trials[0], result(10000, 1005))
-        pbt.on_trial_result(runner, trials[1], result(10000, 1004))
-        pbt.on_trial_result(runner, trials[2], result(600, 1003))
+        self.on_trial_result(pbt, runner, trials[0], result(10000, 1005))
+        self.on_trial_result(pbt, runner, trials[1], result(10000, 1004))
+        self.on_trial_result(pbt, runner, trials[2], result(600, 1003))
         self.assertEqual(pbt._num_perturbations, 0)
-        pbt.on_trial_result(runner, trials[3], result(500, 1002))
+        self.on_trial_result(pbt, runner, trials[3], result(500, 1002))
         self.assertEqual(pbt._num_perturbations, 1)
-        pbt.on_trial_result(runner, trials[3], result(600, 100))
+        self.on_trial_result(pbt, runner, trials[3], result(600, 100))
         self.assertEqual(pbt._num_perturbations, 1)
-        pbt.on_trial_result(runner, trials[3], result(11000, 100))
+        self.on_trial_result(pbt, runner, trials[3], result(11000, 100))
         self.assertEqual(pbt._num_perturbations, 2)

     def testLogConfig(self):
@@ -1376,9 +1382,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         for i, trial in enumerate(trials):
             trial.local_dir = tmpdir
             trial.last_result = {TRAINING_ITERATION: i}
-        pbt.on_trial_result(runner, trials[0], result(15, -100))
-        pbt.on_trial_result(runner, trials[0], result(20, -100))
-        pbt.on_trial_result(runner, trials[2], result(20, 40))
+        self.on_trial_result(pbt, runner, trials[0], result(15, -100))
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100))
+        self.on_trial_result(pbt, runner, trials[2], result(20, 40))
         log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_2.txt"]
         for log_file in log_files:
             self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
@@ -1416,7 +1422,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         for i, trial in enumerate(trials):
             trial.local_dir = tmpdir
             trial.last_result = {TRAINING_ITERATION: i}
-            pbt.on_trial_result(runner, trials[i], result(10, i))
+            self.on_trial_result(pbt, runner, trials[i], result(10, i))
         log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"]
         for log_file in log_files:
             self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
@@ -1475,7 +1481,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                 trial_state[k].forward(res[TRAINING_ITERATION])

                 old_config = trials[k].config
-                pbt.on_trial_result(runner, trials[k], res)
+                self.on_trial_result(pbt, runner, trials[k], res)
                 new_config = trials[k].config
                 trial_state[k].config = new_config.copy()
@@ -1538,15 +1544,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                     TRAINING_ITERATION: self.iter
                 }

-            def reset_config(self, new_config):
-                self.config = new_config
-                return True
+            def save_checkpoint(self, checkpoint_dir):
+                path = os.path.join(checkpoint_dir, "checkpoint")
+                with open(path, "w") as f:
+                    f.write(
+                        json.dumps({
+                            "iter": self.iter,
+                            "replayed": self.replayed
+                        }))
+                return path

-            def save_checkpoint(self, tmp_checkpoint_dir):
-                return tmp_checkpoint_dir
-
-            def load_checkpoint(self, checkpoint):
-                pass
+            def load_checkpoint(self, checkpoint_path):
+                with open(checkpoint_path) as f:
+                    checkpoint_json = json.loads(f.read())
+                self.iter = checkpoint_json["iter"]
+                self.replayed = checkpoint_json["replayed"]

         # Loop through all trials and check if PBT history is the
         # same as the playback history
@@ -1626,15 +1638,14 @@ class PopulationBasedTestingSuite(unittest.TestCase):
             trials[k].last_result = res
             trial_state[k].forward(res[TRAINING_ITERATION])

-            trials[k].status = Trial.RUNNING
             if not synced:
-                action = pbt.on_trial_result(runner, trials[k], res)
+                action = self.on_trial_result(pbt, runner, trials[k], res)
                 runner.process_action(trials[k], action)
                 return
             else:
                 # Reached synchronization point
                 old_configs = [trial.config for trial in trials]
-                action = pbt.on_trial_result(runner, trials[k], res)
+                action = self.on_trial_result(pbt, runner, trials[k], res)
                 runner.process_action(trials[k], action)
                 new_configs = [trial.config for trial in trials]
@@ -1708,15 +1719,21 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                     TRAINING_ITERATION: self.iter
                 }

-            def reset_config(self, new_config):
-                self.config = new_config
-                return True
+            def save_checkpoint(self, checkpoint_dir):
+                path = os.path.join(checkpoint_dir, "checkpoint")
+                with open(path, "w") as f:
+                    f.write(
+                        json.dumps({
+                            "iter": self.iter,
+                            "replayed": self.replayed
+                        }))
+                return path

-            def save_checkpoint(self, tmp_checkpoint_dir):
-                return tmp_checkpoint_dir
-
-            def load_checkpoint(self, checkpoint):
-                pass
+            def load_checkpoint(self, checkpoint_path):
+                with open(checkpoint_path) as f:
+                    checkpoint_json = json.loads(f.read())
+                self.iter = checkpoint_json["iter"]
+                self.replayed = checkpoint_json["replayed"]

         # Loop through all trials and check if PBT history is the
         # same as the playback history
@@ -1755,9 +1772,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):

         pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
         trials = runner.get_trials()
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[0], result(20, -100)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(20, -100),
+                             TrialScheduler.PAUSE)
         self.assertEqual(trials[0].config["id_factor"], 42)
         self.assertEqual(trials[0].config["float_factor"], 43)
@@ -1770,10 +1786,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
         for i, trial in enumerate(trials):
             trial.local_dir = tmpdir
             trial.last_result = {}
-        pbt.on_trial_result(runner, trials[0], result(1, 10))
-        self.assertEqual(
-            pbt.on_trial_result(runner, trials[2], result(1, 200)),
-            TrialScheduler.CONTINUE)
+        self.on_trial_result(pbt, runner, trials[0], result(1, 10))
+        self.on_trial_result(pbt, runner, trials[2], result(1, 200),
+                             TrialScheduler.CONTINUE)
         self.assertEqual(pbt._num_checkpoints, 1)

         pbt._exploit(runner.trial_executor, trials[1], trials[2])
@@ -1879,6 +1894,12 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
                     f.write("OK")
                 return checkpoint

+            def reset_config(self, config):
+                return True
+
+            def load_checkpoint(self, checkpoint):
+                pass
+
         trial_hyperparams = {
             "float_factor": 2.0,
             "const_factor": 3,
@@ -1914,6 +1935,9 @@ class E2EPopulationBasedTestingSuite(unittest.TestCase):
             def load_checkpoint(self, state):
                 self.state = state

+            def reset_config(self, config):
+                return True
+
         trial_hyperparams = {
             "float_factor": 2.0,
             "const_factor": 3,
@@ -174,7 +174,6 @@ class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
 class PopulationBasedTrainingSynchTest(unittest.TestCase):
     def setUp(self):
-        os.environ["TUNE_TRIAL_STARTUP_GRACE_PERIOD"] = "0"
         os.environ["TUNE_TRIAL_RESULT_WAIT_TIME_S"] = "99999"
         ray.init(num_cpus=2)

     def MockTrainingFuncSync(config, checkpoint_dir=None):
@@ -210,7 +209,7 @@ class PopulationBasedTrainingSynchTest(unittest.TestCase):

     def synchSetup(self, synch, param=None):
         if param is None:
-            param = [10, 20, 30]
+            param = [10, 20, 40]

         scheduler = PopulationBasedTraining(
             time_attr="training_iteration",
@@ -244,14 +243,14 @@ class PopulationBasedTrainingSynchTest(unittest.TestCase):
         self.assertTrue(
             any(
                 analysis.dataframe(metric="mean_accuracy", mode="max")
-                ["mean_accuracy"] != 33))
+                ["mean_accuracy"] != 43))

     def testSynchPass(self):
         analysis = self.synchSetup(True)
         self.assertTrue(
             all(
                 analysis.dataframe(metric="mean_accuracy", mode="max")[
-                    "mean_accuracy"] == 33))
+                    "mean_accuracy"] == 43))

     def testSynchPassLast(self):
         analysis = self.synchSetup(True, param=[30, 20, 10])
@@ -190,8 +190,7 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
     def stop_trial(self,
                    trial: Trial,
                    error: bool = False,
-                   error_msg: Optional[str] = None,
-                   destroy_pg_if_cannot_replace: bool = True) -> None:
+                   error_msg: Optional[str] = None) -> None:
         """Stops the trial.

         Stops this trial, releasing all allocating resources.
@@ -201,8 +200,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         Args:
             error (bool): Whether to mark this trial as terminated in error.
             error_msg (str): Optional error message.
-            destroy_pg_if_cannot_replace (bool): Whether the trial's placement
-                group should be destroyed if it cannot replace any staged ones.

         """
         pass
@@ -1176,6 +1176,8 @@ class TrialRunner:
         elif decision == TrialScheduler.STOP:
             self.trial_executor.export_trial_if_needed(trial)
             self.trial_executor.stop_trial(trial)
+        elif decision == TrialScheduler.NOOP:
+            pass
         else:
             raise ValueError("Invalid decision: {}".format(decision))
@@ -23,6 +23,8 @@ from ray.tune.progress_reporter import (detect_reporter, ProgressReporter,
                                         JupyterNotebookReporter)
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.registry import get_trainable_cls
+from ray.tune.schedulers import (PopulationBasedTraining,
+                                 PopulationBasedTrainingReplay)
 from ray.tune.stopper import Stopper
 from ray.tune.suggest import BasicVariantGenerator, SearchAlgorithm, \
     SearchGenerator
@@ -413,6 +415,13 @@ def run(
                 f"to 1 instead.")
             result_buffer_length = 1

+    if isinstance(scheduler,
+                  (PopulationBasedTraining,
+                   PopulationBasedTrainingReplay)) and not reuse_actors:
+        warnings.warn(
+            "Consider boosting PBT performance by enabling `reuse_actors` as "
+            "well as implementing `reset_config` for Trainable.")
+
     trial_executor = trial_executor or RayTrialExecutor(
         reuse_actors=reuse_actors, result_buffer_length=result_buffer_length)
     if isinstance(run_or_experiment, list):
@@ -607,37 +607,17 @@ class PlacementGroupManager:
             self._ready[pgf].add(pg)
             return True

-    def return_pg(self,
-                  trial: "Trial",
-                  destroy_pg_if_cannot_replace: bool = True):
-        """Return pg, making it available for other trials to use.
-
-        If destroy_pg_if_cannot_replace is True, this will only return
-        a placement group if a staged placement group can be replaced
-        by it. If not, it will destroy the placement group.
+    def return_pg(self, trial: "Trial"):
+        """Return pg back to Core scheduling.

         Args:
             trial (Trial): Return placement group of this trial.
-
-        Returns:
-            Boolean indicating if the placement group was returned.
         """
-        pgf = trial.placement_group_factory
-
         pg = self._in_use_trials.pop(trial)
         self._in_use_pgs.pop(pg)

-        if destroy_pg_if_cannot_replace:
-            staged_pg = self._unstage_unused_pg(pgf)
-
-            # Could not replace
-            if not staged_pg:
-                self.remove_pg(pg)
-                return False
-
-            self.remove_pg(staged_pg)
-        self._ready[pgf].add(pg)
-
-        return True
+        self.remove_pg(pg)

     def _unstage_unused_pg(
             self, pgf: PlacementGroupFactory) -> Optional[PlacementGroup]: