[tune/rllib] Fix tune cloud tests for function and rllib trainables (#20536)
Fixes some race conditions and softens some constraints around checkpoint numbers.
parent 414737b7c7
commit 7446269ac9
7 changed files with 77 additions and 54 deletions
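At the core of the Tune changes is a rename with a semantic shift: the eager try_checkpoint_metadata() (pickle trial metadata immediately, swallow errors) is replaced by mark_trial_to_checkpoint(), which only adds the trial to a cache set; serialization is deferred to get_checkpoints() at experiment-checkpoint time, which presumably avoids the races the commit message mentions. A minimal sketch of that deferred-caching pattern, with simplified types and serialization (the real methods live on TrialExecutor, shown in the hunks below):

    import pickle
    from typing import Any, Dict, Set


    class LazyTrialCache:
        """Remember which trials changed; serialize them in one batch."""

        def __init__(self) -> None:
            self._trials_to_cache: Set[Any] = set()
            self._cached_trial_state: Dict[str, bytes] = {}

        def mark_trial_to_checkpoint(self, trial: Any) -> None:
            # Cheap and safe to call from anywhere: no I/O, no pickling yet.
            self._trials_to_cache.add(trial)

        def get_checkpoints(self) -> Dict[str, bytes]:
            # All pickling happens here, at experiment-checkpoint time.
            for trial in self._trials_to_cache:
                self._cached_trial_state[trial.trial_id] = pickle.dumps(trial)
            self._trials_to_cache.clear()
            return dict(self._cached_trial_state)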

@@ -283,7 +283,7 @@ class RayTrialExecutor(TrialExecutor):
     def _setup_remote_runner(self, trial):
         trial.init_logdir()
         # We checkpoint metadata here to try mitigating logdir duplication
-        self.try_checkpoint_metadata(trial)
+        self._trials_to_cache.add(trial)
         logger_creator = partial(noop_logger_creator, logdir=trial.logdir)

         if self._reuse_actors and len(self._cached_actor_pg) > 0:

@@ -510,8 +510,8 @@ class TrialRunnerTest3(unittest.TestCase):
         runner2.step()  # Process save
         self.assertRaises(TuneError, runner2.step)

-    def testTrialNoSave(self):
-        """Check that non-checkpointing trials are not saved."""
+    def testTrialNoCheckpointSave(self):
+        """Check that non-checkpointing trials *are* saved."""
         os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

         ray.init(num_cpus=3)

@@ -557,7 +557,7 @@ class TrialRunnerTest3(unittest.TestCase):
             runner2.get_trial("checkpoint").status == Trial.TERMINATED)
         self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
         self.assertTrue(
-            not runner2.get_trial("pending").has_reported_at_least_once)
+            runner2.get_trial("pending").has_reported_at_least_once)
         runner2.step()

     def testCheckpointWithFunction(self):

@@ -151,25 +151,10 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
                          trial.status, status)
         trial.set_status(status)
         if status in [Trial.TERMINATED, Trial.ERROR]:
-            self.try_checkpoint_metadata(trial)
-
-    def try_checkpoint_metadata(self, trial: Trial) -> None:
-        """Checkpoints trial metadata.
-
-        Args:
-            trial (Trial): Trial to checkpoint.
-        """
-        if trial.checkpoint.storage == Checkpoint.MEMORY:
-            logger.debug("Trial %s: Not saving data for memory checkpoint.",
-                         trial)
-            return
-        try:
-            logger.debug("Trial %s: Saving trial metadata.", trial)
-            # Lazy cache trials
-            self._trials_to_cache.add(trial)
-        except Exception:
-            logger.exception("Trial %s: Error checkpointing trial metadata.",
-                             trial)
+            self._trials_to_cache.add(trial)
+
+    def mark_trial_to_checkpoint(self, trial: Trial) -> None:
+        self._trials_to_cache.add(trial)

     def get_checkpoints(self) -> Dict[str, str]:
         """Returns a copy of mapping of the trial ID to pickled metadata."""

@@ -753,7 +753,7 @@ class TrialRunner:
         self._live_trials.add(trial)
         with warn_if_slow("scheduler.on_trial_add"):
             self._scheduler_alg.on_trial_add(self, trial)
-        self.trial_executor.try_checkpoint_metadata(trial)
+        self.trial_executor.mark_trial_to_checkpoint(trial)

     def debug_string(self, delim="\n"):
         from ray.tune.progress_reporter import trial_progress_str

@@ -987,6 +987,8 @@ class TrialRunner:

         if not is_duplicate:
             trial.update_last_result(result)
+            # Include in next experiment checkpoint
+            self.trial_executor.mark_trial_to_checkpoint(trial)

         # Checkpoints to disk. This should be checked even if
         # the scheduler decision is STOP or PAUSE. Note that

@@ -1089,7 +1091,8 @@ class TrialRunner:
                     trial=trial,
                     checkpoint=trial.saving_to)
                 trial.on_checkpoint(trial.saving_to)
-                self.trial_executor.try_checkpoint_metadata(trial)
+                if trial.checkpoint.storage != Checkpoint.MEMORY:
+                    self.trial_executor.mark_trial_to_checkpoint(trial)
             except Exception:
                 logger.exception("Trial %s: Error handling checkpoint %s",
                                  trial, checkpoint_value)

@@ -160,8 +160,8 @@ NIGHTLY_TESTS = {
         "aws_no_sync_down",
         "aws_ssh_sync",
         "aws_durable_upload",
-        # "aws_durable_upload_rllib_str",
-        # "aws_durable_upload_rllib_trainer",
+        "aws_durable_upload_rllib_str",
+        "aws_durable_upload_rllib_trainer",
         "gcp_k8s_durable_upload",
     ],
     "~/ray/release/tune_tests/scalability_tests/tune_tests.yaml": [

@@ -33,13 +33,8 @@ def fn_trainable(config, checkpoint_dir=None):


 class RLLibCallback(DefaultCallbacks):
-    def __init__(self):
-        super(RLLibCallback, self).__init__()
-        self.internal_iter = 0
-
     def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
-        result["internal_iter"] = self.internal_iter
-        self.internal_iter += 1
+        result["internal_iter"] = result["training_iteration"]


 class IndicatorCallback(tune.Callback):
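Dropping the stateful counter makes internal_iter a pure function of the reported training_iteration. That presumably matters across a stop/resume cycle: the callback object is re-created on restore, so a counter would restart at 0 while training_iteration continues from the checkpoint. A hypothetical before/after illustration (standalone, not the test's actual classes):

    class StatefulCounter:
        def __init__(self) -> None:
            self.internal_iter = 0  # re-created (and reset) on restore

        def on_train_result(self, result: dict) -> None:
            result["internal_iter"] = self.internal_iter
            self.internal_iter += 1


    class StatelessCounter:
        def on_train_result(self, result: dict) -> None:
            # training_iteration is restored with the checkpoint, so this
            # stays consistent across stop/resume.
            result["internal_iter"] = result["training_iteration"]


    result = {"training_iteration": 7}  # e.g. first result after a resume
    StatefulCounter().on_train_result(result)   # internal_iter == 0 (stale)
    StatelessCounter().on_train_result(result)  # internal_iter == 7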

@@ -60,7 +55,7 @@ def run_tune(no_syncer: bool,
     if trainable == "function":
         train = fn_trainable
         config = {
-            "max_iterations": 30,
+            "max_iterations": 100,
            "sleep_time": 5,
             "checkpoint_freq": 2,
             "score_multiplied": tune.randint(0, 100),

@@ -80,8 +75,10 @@ def run_tune(no_syncer: bool,
         }
         kwargs = {
             "stop": {
-                "training_iteration": 10
+                "training_iteration": 100
             },
+            "checkpoint_freq": 2,
+            "checkpoint_at_end": True,
         }
     else:
         raise RuntimeError(f"Unknown trainable: {trainable}")
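checkpoint_freq and checkpoint_at_end are standard tune.run() arguments, so enabling them for the RLlib trainable is what gives the checkpoint-count assertions something to count. A minimal usage sketch (the "PPO" trainable and stop criterion here are illustrative, not taken from this script):

    from ray import tune

    analysis = tune.run(
        "PPO",                   # illustrative trainable
        stop={"training_iteration": 100},
        checkpoint_freq=2,       # checkpoint every 2nd training iteration
        checkpoint_at_end=True,  # plus a final checkpoint when stopping
    )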

@@ -43,6 +43,7 @@ import time
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

 import ray
+import ray.cloudpickle as pickle
 from ray.tune.trial_runner import find_newest_experiment_checkpoint
 from ray.tune.utils.serialization import TuneFunctionDecoder

@@ -118,6 +119,7 @@ class TrialCheckpointData:
     results: List[Dict[str, Any]]
     progress: List[Dict[str, Any]]
     checkpoints: List[Tuple[str, Dict[Any, Any]]]
+    num_skipped: int


 # Utility functions

@@ -584,6 +586,7 @@ def load_trial_checkpoint_data(trial_dir: str,
     progress = []

     checkpoints = []
+    num_skipped = 0
     for cp_dir in sorted(os.listdir(trial_dir)):
         if not cp_dir.startswith("checkpoint_"):
             continue

@@ -601,20 +604,32 @@ def load_trial_checkpoint_data(trial_dir: str,
                 print(f"Skipping unobserved checkpoint: {cp_full_dir} as "
                       f"{checkpoint_num} > "
                       f"{node_trial.last_result['internal_iter']}")
+                num_skipped += 1
                 continue
         except ValueError:
             # temporary checkpoint
             continue

-        with open(os.path.join(cp_full_dir, "checkpoint.json"), "rt") as f:
-            checkpoint_data = json.load(f)
+        json_path = os.path.join(cp_full_dir, "checkpoint.json")
+        if os.path.exists(json_path):
+            with open(json_path, "rt") as f:
+                checkpoint_data = json.load(f)
+        else:
+            meta_path = os.path.join(
+                cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
+            with open(meta_path, "rb") as f:
+                checkpoint_meta = pickle.load(f)
+            checkpoint_data = {
+                "internal_iter": checkpoint_meta["iteration"]
+            }
         checkpoints.append((cp_dir, checkpoint_data))

     return TrialCheckpointData(
         params=params,
         results=results,
         progress=progress,
-        checkpoints=checkpoints)
+        checkpoints=checkpoints,
+        num_skipped=num_skipped)


 def load_data_from_trial_exp_checkpoints(
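The new branch handles the two on-disk metadata formats the test now sees: function trainables write a checkpoint.json into each checkpoint directory, while class-based/RLlib trainables write a pickled checkpoint-<n>.tune_metadata file whose "iteration" field serves the same purpose. A standalone restatement of the dispatch, assuming those layouts (plain pickle here, where the test imports ray.cloudpickle):

    import json
    import os
    import pickle
    from typing import Any, Dict


    def load_checkpoint_metadata(cp_full_dir: str,
                                 checkpoint_num: int) -> Dict[str, Any]:
        # Function trainable: metadata is a JSON file.
        json_path = os.path.join(cp_full_dir, "checkpoint.json")
        if os.path.exists(json_path):
            with open(json_path, "rt") as f:
                return json.load(f)
        # Class/RLlib trainable: metadata is pickled next to the checkpoint.
        meta_path = os.path.join(
            cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
        with open(meta_path, "rb") as f:
            return {"internal_iter": pickle.load(f)["iteration"]}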

@@ -713,18 +728,30 @@ def assert_min_num_trials(trials: Iterable[TrialStub], on_driver: int,

 def assert_checkpoint_count(experiment_dir_cp: ExperimentDirCheckpoint,
                             for_driver_trial: int, for_worker_trial: int):
+    # We relaxed the requirements here and also allow
+    # skipped checkpoints to count. This could be the case if e.g. the trial
+    # already checkpointed but the driver did not process the last result, yet.
+    # We also allow up to one un-collected checkpoint.
+    # Todo: Can we make this stricter?
     for trial, trial_cp in experiment_dir_cp.trial_to_cps.items():
         cps = len(trial_cp.checkpoints)
+        num_skipped = trial_cp.num_skipped
         if trial.was_on_driver_node:
-            assert cps == for_driver_trial, (
-                f"Trial {trial.trial_id} was on driver, "
-                f"but did not observe the expected amount of checkpoints "
-                f"({cps} != {for_driver_trial}).")
+            assert (
+                cps == for_driver_trial
+                or cps + num_skipped == for_driver_trial
+                or cps == for_driver_trial + 1), (
+                    f"Trial {trial.trial_id} was on driver, "
+                    f"but did not observe the expected amount of checkpoints "
+                    f"({cps} != {for_driver_trial}).")
         else:
-            assert cps == for_worker_trial, (
-                f"Trial {trial.trial_id} was not on the driver, "
-                f"but did not observe the expected amount of checkpoints "
-                f"({cps} != {for_worker_trial}).")
+            assert (
+                cps == for_worker_trial
+                or cps + num_skipped == for_worker_trial
+                or cps == for_worker_trial + 1), (
+                    f"Trial {trial.trial_id} was not on the driver, "
+                    f"but did not observe the expected amount of checkpoints "
+                    f"({cps} != {for_worker_trial}).")


 def assert_trial_progressed_training(trial: TrialStub):
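The relaxed predicate accepts exactly three outcomes per trial: the expected count; fewer on-disk checkpoints when the shortfall is explained by checkpoints skipped as unobserved; or one surplus checkpoint the driver has not collected yet. With an expected count of 4, for instance: cps=4 passes, cps=3 with num_skipped=1 passes, cps=5 passes, and cps=2 with num_skipped=1 fails. A compact restatement of the check:

    def checkpoint_count_ok(cps: int, num_skipped: int, expected: int) -> bool:
        """Mirrors the relaxed assertion in assert_checkpoint_count."""
        return (cps == expected                    # exact match
                or cps + num_skipped == expected   # shortfall fully explained
                or cps == expected + 1)            # one un-collected checkpoint


    assert checkpoint_count_ok(4, 0, 4)
    assert checkpoint_count_ok(3, 1, 4)
    assert checkpoint_count_ok(5, 0, 4)
    assert not checkpoint_count_ok(2, 1, 4)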

@@ -830,13 +857,15 @@ def test_no_sync_down():
         f"errored, there is something wrong with restoration. If less, "
         f"maybe cleanup has not worked, or syncing to driver took place.")

+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=True,
         upload_dir=None,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)

@@ -922,13 +951,15 @@ def test_ssh_sync():
     for trial in experiment_state.trials:
         assert_trial_progressed_training(trial)

+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=None,
-        first_run_time=55,  # More time because of SSH syncing
-        second_run_time=55,
+        first_run_time=run_time + 10,  # More time because of SSH syncing
+        second_run_time=run_time + 10,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)

@@ -1048,13 +1079,15 @@ def test_durable_upload(bucket: str):

     clear_bucket_contents(bucket)

+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=bucket,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         before_experiments_callback=before_experiments,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)

@@ -1094,6 +1127,7 @@ if __name__ == "__main__":

     def _run_test(variant: str,
                   trainable: str = "function",
+                  run_time: int = 180,
                   bucket: str = "",
                   cpus_per_trial: int = 2,
                   overwrite_tune_script: Optional[str] = None):

@@ -1103,6 +1137,7 @@ if __name__ == "__main__":
           f"{cpus_per_trial} CPUs per trial.")

     os.environ["TUNE_TRAINABLE"] = str(trainable)
+    os.environ["TUNE_RUN_TIME"] = str(run_time)
     os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = str(cpus_per_trial)

     if overwrite_tune_script:
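The run time is threaded from the harness to the tune script via the TUNE_RUN_TIME environment variable: _run_test exports it here, and each test_* function reads it back with int(os.getenv("TUNE_RUN_TIME", "180")) or 180, where the trailing `or 180` also maps an explicit 0 back to the default. A round-trip sketch (hypothetical helper names):

    import os


    def export_run_time(run_time: int) -> None:
        # Producer side (the harness).
        os.environ["TUNE_RUN_TIME"] = str(run_time)


    def read_run_time() -> int:
        # Consumer side (the test): default 180; treat 0 as "use default".
        return int(os.getenv("TUNE_RUN_TIME", "180")) or 180


    export_run_time(90)
    assert read_run_time() == 90
    export_run_time(0)
    assert read_run_time() == 180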

@@ -1124,9 +1159,11 @@ if __name__ == "__main__":
         with open(release_test_out, "wt") as f:
             json.dump(result, f)

+    run_time = 180 if "rllib" in args.trainable else 90
+
     if not uses_ray_client:
         print("This test will *not* use Ray client.")
-        _run_test(args.variant, args.trainable, args.bucket,
+        _run_test(args.variant, args.trainable, run_time, args.bucket,
                   args.cpus_per_trial)
     else:
         print("This test will run using Ray client.")

@@ -1150,8 +1187,9 @@ if __name__ == "__main__":
         _run_test_remote = ray.remote(
             resources={f"node:{ip}": 0.01}, num_cpus=0)(_run_test)
         ray.get(
-            _run_test_remote.remote(args.variant, args.trainable, args.bucket,
-                                    args.cpus_per_trial, remote_tune_script))
+            _run_test_remote.remote(args.variant, args.trainable, run_time,
+                                    args.bucket, args.cpus_per_trial,
+                                    remote_tune_script))

         print(f"Fetching remote release test result file: {release_test_out}")
         fetch_remote_file_to_local_file(release_test_out, ip, release_test_out)