[tune/rllib] Fix tune cloud tests for function and rllib trainables (#20536)

Fixes some race conditions (e.g. a trial that has already checkpointed while the driver has not processed its last result yet) and relaxes the checkpoint-count assertions in the cloud tests, so that skipped and not-yet-collected checkpoints no longer fail the run.
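
For context, the relaxed check in assert_checkpoint_count below roughly amounts to the following (a minimal sketch; the helper name checkpoint_count_ok is made up here, and observed/num_skipped/expected stand for the per-trial values the test harness computes):

    def checkpoint_count_ok(observed: int, num_skipped: int, expected: int) -> bool:
        # Exact match, or the skipped checkpoints account for the difference,
        # or at most one checkpoint the driver has not collected yet.
        return (observed == expected
                or observed + num_skipped == expected
                or observed == expected + 1)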
Kai Fricke 2021-11-24 09:29:12 +00:00 committed by GitHub
parent 414737b7c7
commit 7446269ac9
7 changed files with 77 additions and 54 deletions


@@ -283,7 +283,7 @@ class RayTrialExecutor(TrialExecutor):
def _setup_remote_runner(self, trial):
trial.init_logdir()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
self._trials_to_cache.add(trial)
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
if self._reuse_actors and len(self._cached_actor_pg) > 0:


@@ -510,8 +510,8 @@ class TrialRunnerTest3(unittest.TestCase):
runner2.step() # Process save
self.assertRaises(TuneError, runner2.step)
def testTrialNoSave(self):
"""Check that non-checkpointing trials are not saved."""
def testTrialNoCheckpointSave(self):
"""Check that non-checkpointing trials *are* saved."""
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
ray.init(num_cpus=3)
@@ -557,7 +557,7 @@ class TrialRunnerTest3(unittest.TestCase):
runner2.get_trial("checkpoint").status == Trial.TERMINATED)
self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
self.assertTrue(
not runner2.get_trial("pending").has_reported_at_least_once)
runner2.get_trial("pending").has_reported_at_least_once)
runner2.step()
def testCheckpointWithFunction(self):


@@ -151,25 +151,10 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
trial.status, status)
trial.set_status(status)
if status in [Trial.TERMINATED, Trial.ERROR]:
self.try_checkpoint_metadata(trial)
def try_checkpoint_metadata(self, trial: Trial) -> None:
"""Checkpoints trial metadata.
Args:
trial (Trial): Trial to checkpoint.
"""
if trial.checkpoint.storage == Checkpoint.MEMORY:
logger.debug("Trial %s: Not saving data for memory checkpoint.",
trial)
return
try:
logger.debug("Trial %s: Saving trial metadata.", trial)
# Lazy cache trials
self._trials_to_cache.add(trial)
except Exception:
logger.exception("Trial %s: Error checkpointing trial metadata.",
trial)
def mark_trial_to_checkpoint(self, trial: Trial) -> None:
self._trials_to_cache.add(trial)
def get_checkpoints(self) -> Dict[str, str]:
"""Returns a copy of mapping of the trial ID to pickled metadata."""


@@ -753,7 +753,7 @@ class TrialRunner:
self._live_trials.add(trial)
with warn_if_slow("scheduler.on_trial_add"):
self._scheduler_alg.on_trial_add(self, trial)
self.trial_executor.try_checkpoint_metadata(trial)
self.trial_executor.mark_trial_to_checkpoint(trial)
def debug_string(self, delim="\n"):
from ray.tune.progress_reporter import trial_progress_str
@@ -987,6 +987,8 @@ class TrialRunner:
if not is_duplicate:
trial.update_last_result(result)
# Include in next experiment checkpoint
self.trial_executor.mark_trial_to_checkpoint(trial)
# Checkpoints to disk. This should be checked even if
# the scheduler decision is STOP or PAUSE. Note that
@@ -1089,7 +1091,8 @@
trial=trial,
checkpoint=trial.saving_to)
trial.on_checkpoint(trial.saving_to)
self.trial_executor.try_checkpoint_metadata(trial)
if trial.checkpoint.storage != Checkpoint.MEMORY:
self.trial_executor.mark_trial_to_checkpoint(trial)
except Exception:
logger.exception("Trial %s: Error handling checkpoint %s",
trial, checkpoint_value)


@@ -160,8 +160,8 @@ NIGHTLY_TESTS = {
"aws_no_sync_down",
"aws_ssh_sync",
"aws_durable_upload",
# "aws_durable_upload_rllib_str",
# "aws_durable_upload_rllib_trainer",
"aws_durable_upload_rllib_str",
"aws_durable_upload_rllib_trainer",
"gcp_k8s_durable_upload",
],
"~/ray/release/tune_tests/scalability_tests/tune_tests.yaml": [


@@ -33,13 +33,8 @@ def fn_trainable(config, checkpoint_dir=None):
class RLLibCallback(DefaultCallbacks):
def __init__(self):
super(RLLibCallback, self).__init__()
self.internal_iter = 0
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
result["internal_iter"] = self.internal_iter
self.internal_iter += 1
result["internal_iter"] = result["training_iteration"]
class IndicatorCallback(tune.Callback):
@@ -60,7 +55,7 @@ def run_tune(no_syncer: bool,
if trainable == "function":
train = fn_trainable
config = {
"max_iterations": 30,
"max_iterations": 100,
"sleep_time": 5,
"checkpoint_freq": 2,
"score_multiplied": tune.randint(0, 100),
@@ -80,8 +75,10 @@
}
kwargs = {
"stop": {
"training_iteration": 10
"training_iteration": 100
},
"checkpoint_freq": 2,
"checkpoint_at_end": True,
}
else:
raise RuntimeError(f"Unknown trainable: {trainable}")


@@ -43,6 +43,7 @@ import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import ray
import ray.cloudpickle as pickle
from ray.tune.trial_runner import find_newest_experiment_checkpoint
from ray.tune.utils.serialization import TuneFunctionDecoder
@@ -118,6 +119,7 @@ class TrialCheckpointData:
results: List[Dict[str, Any]]
progress: List[Dict[str, Any]]
checkpoints: List[Tuple[str, Dict[Any, Any]]]
num_skipped: int
# Utility functions
@@ -584,6 +586,7 @@ def load_trial_checkpoint_data(trial_dir: str,
progress = []
checkpoints = []
num_skipped = 0
for cp_dir in sorted(os.listdir(trial_dir)):
if not cp_dir.startswith("checkpoint_"):
continue
@@ -601,20 +604,32 @@
print(f"Skipping unobserved checkpoint: {cp_full_dir} as "
f"{checkpoint_num} > "
f"{node_trial.last_result['internal_iter']}")
num_skipped += 1
continue
except ValueError:
# temporary checkpoint
continue
with open(os.path.join(cp_full_dir, "checkpoint.json"), "rt") as f:
checkpoint_data = json.load(f)
json_path = os.path.join(cp_full_dir, "checkpoint.json")
if os.path.exists(json_path):
with open(json_path, "rt") as f:
checkpoint_data = json.load(f)
else:
meta_path = os.path.join(
cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
with open(meta_path, "rb") as f:
checkpoint_meta = pickle.load(f)
checkpoint_data = {
"internal_iter": checkpoint_meta["iteration"]
}
checkpoints.append((cp_dir, checkpoint_data))
return TrialCheckpointData(
params=params,
results=results,
progress=progress,
checkpoints=checkpoints)
checkpoints=checkpoints,
num_skipped=num_skipped)
def load_data_from_trial_exp_checkpoints(
@@ -713,18 +728,30 @@ def assert_min_num_trials(trials: Iterable[TrialStub], on_driver: int,
def assert_checkpoint_count(experiment_dir_cp: ExperimentDirCheckpoint,
for_driver_trial: int, for_worker_trial: int):
# We relaxed the requirements here and also allow
# skipped checkpoints to count. This could be the case if e.g. the trial
# already checkpointed but the driver did not process the last result, yet.
# We also allow up to one un-collected checkpoint.
# Todo: Can we make this stricter?
for trial, trial_cp in experiment_dir_cp.trial_to_cps.items():
cps = len(trial_cp.checkpoints)
num_skipped = trial_cp.num_skipped
if trial.was_on_driver_node:
assert cps == for_driver_trial, (
f"Trial {trial.trial_id} was on driver, "
f"but did not observe the expected amount of checkpoints "
f"({cps} != {for_driver_trial}).")
assert (
cps == for_driver_trial
or cps + num_skipped == for_driver_trial
or cps == for_driver_trial + 1), (
f"Trial {trial.trial_id} was on driver, "
f"but did not observe the expected amount of checkpoints "
f"({cps} != {for_driver_trial}).")
else:
assert cps == for_worker_trial, (
f"Trial {trial.trial_id} was not on the driver, "
f"but did not observe the expected amount of checkpoints "
f"({cps} != {for_worker_trial}).")
assert (
cps == for_worker_trial
or cps + num_skipped == for_worker_trial
or cps == for_worker_trial + 1), (
f"Trial {trial.trial_id} was not on the driver, "
f"but did not observe the expected amount of checkpoints "
f"({cps} != {for_worker_trial}).")
def assert_trial_progressed_training(trial: TrialStub):
@@ -830,13 +857,15 @@ def test_no_sync_down():
f"errored, there is something wrong with restoration. If less, "
f"maybe cleanup has not worked, or syncing to driver took place.")
run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
run_resume_flow(
experiment_name=experiment_name,
indicator_file=indicator_file,
no_syncer=True,
upload_dir=None,
first_run_time=45,
second_run_time=45,
first_run_time=run_time,
second_run_time=run_time,
between_experiments_callback=between_experiments,
after_experiments_callback=after_experiments)
@@ -922,13 +951,15 @@ def test_ssh_sync():
for trial in experiment_state.trials:
assert_trial_progressed_training(trial)
run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
run_resume_flow(
experiment_name=experiment_name,
indicator_file=indicator_file,
no_syncer=False,
upload_dir=None,
first_run_time=55, # More time because of SSH syncing
second_run_time=55,
first_run_time=run_time + 10, # More time because of SSH syncing
second_run_time=run_time + 10,
between_experiments_callback=between_experiments,
after_experiments_callback=after_experiments)
@@ -1048,13 +1079,15 @@ def test_durable_upload(bucket: str):
clear_bucket_contents(bucket)
run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
run_resume_flow(
experiment_name=experiment_name,
indicator_file=indicator_file,
no_syncer=False,
upload_dir=bucket,
first_run_time=45,
second_run_time=45,
first_run_time=run_time,
second_run_time=run_time,
before_experiments_callback=before_experiments,
between_experiments_callback=between_experiments,
after_experiments_callback=after_experiments)
@@ -1094,6 +1127,7 @@ if __name__ == "__main__":
def _run_test(variant: str,
trainable: str = "function",
run_time: int = 180,
bucket: str = "",
cpus_per_trial: int = 2,
overwrite_tune_script: Optional[str] = None):
@@ -1103,6 +1137,7 @@ if __name__ == "__main__":
f"{cpus_per_trial} CPUs per trial.")
os.environ["TUNE_TRAINABLE"] = str(trainable)
os.environ["TUNE_RUN_TIME"] = str(run_time)
os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = str(cpus_per_trial)
if overwrite_tune_script:
@@ -1124,9 +1159,11 @@ if __name__ == "__main__":
with open(release_test_out, "wt") as f:
json.dump(result, f)
run_time = 180 if "rllib" in args.trainable else 90
if not uses_ray_client:
print("This test will *not* use Ray client.")
_run_test(args.variant, args.trainable, args.bucket,
_run_test(args.variant, args.trainable, run_time, args.bucket,
args.cpus_per_trial)
else:
print("This test will run using Ray client.")
@@ -1150,8 +1187,9 @@ if __name__ == "__main__":
_run_test_remote = ray.remote(
resources={f"node:{ip}": 0.01}, num_cpus=0)(_run_test)
ray.get(
_run_test_remote.remote(args.variant, args.trainable, args.bucket,
args.cpus_per_trial, remote_tune_script))
_run_test_remote.remote(args.variant, args.trainable, run_time,
args.bucket, args.cpus_per_trial,
remote_tune_script))
print(f"Fetching remote release test result file: {release_test_out}")
fetch_remote_file_to_local_file(release_test_out, ip, release_test_out)