[tune/rllib] Fix tune cloud tests for function and rllib trainables (#20536)

Fixes some race conditions in the Tune cloud tests and softens the constraints on expected checkpoint counts.
Kai Fricke, 2021-11-24 09:29:12 +00:00 (committed by GitHub)
parent 414737b7c7
commit 7446269ac9
7 changed files with 77 additions and 54 deletions


@@ -283,7 +283,7 @@ class RayTrialExecutor(TrialExecutor):
     def _setup_remote_runner(self, trial):
         trial.init_logdir()
         # We checkpoint metadata here to try mitigating logdir duplication
-        self.try_checkpoint_metadata(trial)
+        self._trials_to_cache.add(trial)
         logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
 
         if self._reuse_actors and len(self._cached_actor_pg) > 0:


@@ -510,8 +510,8 @@ class TrialRunnerTest3(unittest.TestCase):
         runner2.step()  # Process save
         self.assertRaises(TuneError, runner2.step)
 
-    def testTrialNoSave(self):
-        """Check that non-checkpointing trials are not saved."""
+    def testTrialNoCheckpointSave(self):
+        """Check that non-checkpointing trials *are* saved."""
         os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
 
         ray.init(num_cpus=3)
@@ -557,7 +557,7 @@ class TrialRunnerTest3(unittest.TestCase):
             runner2.get_trial("checkpoint").status == Trial.TERMINATED)
         self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
         self.assertTrue(
-            not runner2.get_trial("pending").has_reported_at_least_once)
+            runner2.get_trial("pending").has_reported_at_least_once)
         runner2.step()
 
     def testCheckpointWithFunction(self):


@@ -151,25 +151,10 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
                      trial.status, status)
         trial.set_status(status)
         if status in [Trial.TERMINATED, Trial.ERROR]:
-            self.try_checkpoint_metadata(trial)
-
-    def try_checkpoint_metadata(self, trial: Trial) -> None:
-        """Checkpoints trial metadata.
-
-        Args:
-            trial (Trial): Trial to checkpoint.
-        """
-        if trial.checkpoint.storage == Checkpoint.MEMORY:
-            logger.debug("Trial %s: Not saving data for memory checkpoint.",
-                         trial)
-            return
-        try:
-            logger.debug("Trial %s: Saving trial metadata.", trial)
-            # Lazy cache trials
             self._trials_to_cache.add(trial)
-        except Exception:
-            logger.exception("Trial %s: Error checkpointing trial metadata.",
-                             trial)
+
+    def mark_trial_to_checkpoint(self, trial: Trial) -> None:
+        self._trials_to_cache.add(trial)
 
     def get_checkpoints(self) -> Dict[str, str]:
         """Returns a copy of mapping of the trial ID to pickled metadata."""


@@ -753,7 +753,7 @@ class TrialRunner:
         self._live_trials.add(trial)
         with warn_if_slow("scheduler.on_trial_add"):
             self._scheduler_alg.on_trial_add(self, trial)
-        self.trial_executor.try_checkpoint_metadata(trial)
+        self.trial_executor.mark_trial_to_checkpoint(trial)
 
     def debug_string(self, delim="\n"):
         from ray.tune.progress_reporter import trial_progress_str
@@ -987,6 +987,8 @@ class TrialRunner:
             if not is_duplicate:
                 trial.update_last_result(result)
+                # Include in next experiment checkpoint
+                self.trial_executor.mark_trial_to_checkpoint(trial)
 
             # Checkpoints to disk. This should be checked even if
             # the scheduler decision is STOP or PAUSE. Note that
@@ -1089,7 +1091,8 @@ class TrialRunner:
                     trial=trial,
                     checkpoint=trial.saving_to)
                 trial.on_checkpoint(trial.saving_to)
-                self.trial_executor.try_checkpoint_metadata(trial)
+                if trial.checkpoint.storage != Checkpoint.MEMORY:
+                    self.trial_executor.mark_trial_to_checkpoint(trial)
             except Exception:
                 logger.exception("Trial %s: Error handling checkpoint %s",
                                  trial, checkpoint_value)


@@ -160,8 +160,8 @@ NIGHTLY_TESTS = {
         "aws_no_sync_down",
         "aws_ssh_sync",
         "aws_durable_upload",
-        # "aws_durable_upload_rllib_str",
-        # "aws_durable_upload_rllib_trainer",
+        "aws_durable_upload_rllib_str",
+        "aws_durable_upload_rllib_trainer",
         "gcp_k8s_durable_upload",
     ],
     "~/ray/release/tune_tests/scalability_tests/tune_tests.yaml": [


@@ -33,13 +33,8 @@ def fn_trainable(config, checkpoint_dir=None):
 
 
 class RLLibCallback(DefaultCallbacks):
-    def __init__(self):
-        super(RLLibCallback, self).__init__()
-        self.internal_iter = 0
-
     def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
-        result["internal_iter"] = self.internal_iter
-        self.internal_iter += 1
+        result["internal_iter"] = result["training_iteration"]
 
 
 class IndicatorCallback(tune.Callback):
@@ -60,7 +55,7 @@ def run_tune(no_syncer: bool,
     if trainable == "function":
         train = fn_trainable
         config = {
-            "max_iterations": 30,
+            "max_iterations": 100,
             "sleep_time": 5,
             "checkpoint_freq": 2,
             "score_multiplied": tune.randint(0, 100),
@@ -80,8 +75,10 @@ def run_tune(no_syncer: bool,
         }
         kwargs = {
             "stop": {
-                "training_iteration": 10
+                "training_iteration": 100
             },
+            "checkpoint_freq": 2,
+            "checkpoint_at_end": True,
         }
     else:
         raise RuntimeError(f"Unknown trainable: {trainable}")
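For context, the kwargs built above are passed through to tune.run(). A minimal sketch of how an RLlib trainable picks up these checkpoint settings (the CartPole environment and config values are illustrative, not the release-test configuration):

# Illustrative sketch; assumes Ray Tune ~1.9 with RLlib installed.
from ray import tune

analysis = tune.run(
    "PPO",                              # RLlib trainer registered by name
    config={"env": "CartPole-v0", "num_workers": 1},
    stop={"training_iteration": 100},   # matches the "stop" kwarg above
    checkpoint_freq=2,                  # checkpoint every 2nd training iteration
    checkpoint_at_end=True,             # always write a final checkpoint
)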


@@ -43,6 +43,7 @@ import time
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import ray
+import ray.cloudpickle as pickle
 from ray.tune.trial_runner import find_newest_experiment_checkpoint
 from ray.tune.utils.serialization import TuneFunctionDecoder
@@ -118,6 +119,7 @@ class TrialCheckpointData:
     results: List[Dict[str, Any]]
     progress: List[Dict[str, Any]]
     checkpoints: List[Tuple[str, Dict[Any, Any]]]
+    num_skipped: int
 
 
 # Utility functions
@@ -584,6 +586,7 @@ def load_trial_checkpoint_data(trial_dir: str,
     progress = []
 
     checkpoints = []
+    num_skipped = 0
     for cp_dir in sorted(os.listdir(trial_dir)):
         if not cp_dir.startswith("checkpoint_"):
             continue
@@ -601,20 +604,32 @@ def load_trial_checkpoint_data(trial_dir: str,
                 print(f"Skipping unobserved checkpoint: {cp_full_dir} as "
                       f"{checkpoint_num} > "
                       f"{node_trial.last_result['internal_iter']}")
+                num_skipped += 1
                 continue
         except ValueError:
             # temporary checkpoint
             continue
 
-        with open(os.path.join(cp_full_dir, "checkpoint.json"), "rt") as f:
-            checkpoint_data = json.load(f)
+        json_path = os.path.join(cp_full_dir, "checkpoint.json")
+        if os.path.exists(json_path):
+            with open(json_path, "rt") as f:
+                checkpoint_data = json.load(f)
+        else:
+            meta_path = os.path.join(
+                cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
+            with open(meta_path, "rb") as f:
+                checkpoint_meta = pickle.load(f)
+            checkpoint_data = {
+                "internal_iter": checkpoint_meta["iteration"]
+            }
         checkpoints.append((cp_dir, checkpoint_data))
 
     return TrialCheckpointData(
         params=params,
         results=results,
         progress=progress,
-        checkpoints=checkpoints)
+        checkpoints=checkpoints,
+        num_skipped=num_skipped)
 
 
 def load_data_from_trial_exp_checkpoints(
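The fallback above can be summarized as a small helper. This is a hedged sketch (the helper name is invented; the release script inlines this logic and imports ray.cloudpickle as pickle, while stdlib pickle is used here for self-containment):

# Hypothetical helper mirroring the branch above.
import json
import os
import pickle


def read_internal_iter(cp_full_dir: str, checkpoint_num: int) -> int:
    """Return the training iteration recorded for one checkpoint directory."""
    json_path = os.path.join(cp_full_dir, "checkpoint.json")
    if os.path.exists(json_path):
        # Function trainable: a small JSON file is written per checkpoint.
        with open(json_path, "rt") as f:
            return json.load(f)["internal_iter"]
    # RLlib trainable: fall back to the .tune_metadata file written by Tune.
    meta_path = os.path.join(
        cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
    with open(meta_path, "rb") as f:
        return pickle.load(f)["iteration"]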
@@ -713,15 +728,27 @@ def assert_min_num_trials(trials: Iterable[TrialStub], on_driver: int,
 
 def assert_checkpoint_count(experiment_dir_cp: ExperimentDirCheckpoint,
                             for_driver_trial: int, for_worker_trial: int):
+    # We relaxed the requirements here and also allow
+    # skipped checkpoints to count. This could be the case if e.g. the trial
+    # already checkpointed but the driver did not process the last result, yet.
+    # We also allow up to one un-collected checkpoint.
+    # Todo: Can we make this stricter?
     for trial, trial_cp in experiment_dir_cp.trial_to_cps.items():
         cps = len(trial_cp.checkpoints)
+        num_skipped = trial_cp.num_skipped
         if trial.was_on_driver_node:
-            assert cps == for_driver_trial, (
+            assert (
+                cps == for_driver_trial
+                or cps + num_skipped == for_driver_trial
+                or cps == for_driver_trial + 1), (
                 f"Trial {trial.trial_id} was on driver, "
                 f"but did not observe the expected amount of checkpoints "
                 f"({cps} != {for_driver_trial}).")
         else:
-            assert cps == for_worker_trial, (
+            assert (
+                cps == for_worker_trial
+                or cps + num_skipped == for_worker_trial
+                or cps == for_worker_trial + 1), (
                 f"Trial {trial.trial_id} was not on the driver, "
                 f"but did not observe the expected amount of checkpoints "
                 f"({cps} != {for_worker_trial}).")
@@ -830,13 +857,15 @@ def test_no_sync_down():
             f"errored, there is something wrong with restoration. If less, "
             f"maybe cleanup has not worked, or syncing to driver took place.")
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=True,
         upload_dir=None,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
 
@@ -922,13 +951,15 @@ def test_ssh_sync():
     for trial in experiment_state.trials:
         assert_trial_progressed_training(trial)
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=None,
-        first_run_time=55,  # More time because of SSH syncing
-        second_run_time=55,
+        first_run_time=run_time + 10,  # More time because of SSH syncing
+        second_run_time=run_time + 10,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
 
@@ -1048,13 +1079,15 @@ def test_durable_upload(bucket: str):
         clear_bucket_contents(bucket)
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=bucket,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         before_experiments_callback=before_experiments,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
 
@@ -1094,6 +1127,7 @@ if __name__ == "__main__":
 
     def _run_test(variant: str,
                   trainable: str = "function",
+                  run_time: int = 180,
                   bucket: str = "",
                   cpus_per_trial: int = 2,
                   overwrite_tune_script: Optional[str] = None):
@@ -1103,6 +1137,7 @@ if __name__ == "__main__":
               f"{cpus_per_trial} CPUs per trial.")
 
         os.environ["TUNE_TRAINABLE"] = str(trainable)
+        os.environ["TUNE_RUN_TIME"] = str(run_time)
         os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = str(cpus_per_trial)
 
         if overwrite_tune_script:
@@ -1124,9 +1159,11 @@ if __name__ == "__main__":
         with open(release_test_out, "wt") as f:
             json.dump(result, f)
 
+    run_time = 180 if "rllib" in args.trainable else 90
+
     if not uses_ray_client:
         print("This test will *not* use Ray client.")
-        _run_test(args.variant, args.trainable, args.bucket,
-                  args.cpus_per_trial)
+        _run_test(args.variant, args.trainable, run_time, args.bucket,
+                  args.cpus_per_trial)
     else:
         print("This test will run using Ray client.")
@@ -1150,8 +1187,9 @@ if __name__ == "__main__":
         _run_test_remote = ray.remote(
             resources={f"node:{ip}": 0.01}, num_cpus=0)(_run_test)
         ray.get(
-            _run_test_remote.remote(args.variant, args.trainable, args.bucket,
-                                    args.cpus_per_trial, remote_tune_script))
+            _run_test_remote.remote(args.variant, args.trainable, run_time,
                                     args.bucket, args.cpus_per_trial,
+                                    remote_tune_script))
 
         print(f"Fetching remote release test result file: {release_test_out}")
         fetch_remote_file_to_local_file(release_test_out, ip, release_test_out)