diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index 53416d2f4..8864a6107 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -283,7 +283,7 @@ class RayTrialExecutor(TrialExecutor):
     def _setup_remote_runner(self, trial):
         trial.init_logdir()
         # We checkpoint metadata here to try mitigating logdir duplication
-        self.try_checkpoint_metadata(trial)
+        self._trials_to_cache.add(trial)
         logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
 
         if self._reuse_actors and len(self._cached_actor_pg) > 0:
diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py
index 037a531b6..e3e392c94 100644
--- a/python/ray/tune/tests/test_trial_runner_3.py
+++ b/python/ray/tune/tests/test_trial_runner_3.py
@@ -510,8 +510,8 @@ class TrialRunnerTest3(unittest.TestCase):
         runner2.step()  # Process save
         self.assertRaises(TuneError, runner2.step)
 
-    def testTrialNoSave(self):
-        """Check that non-checkpointing trials are not saved."""
+    def testTrialNoCheckpointSave(self):
+        """Check that non-checkpointing trials *are* saved."""
         os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"
 
         ray.init(num_cpus=3)
@@ -557,7 +557,7 @@ class TrialRunnerTest3(unittest.TestCase):
             runner2.get_trial("checkpoint").status == Trial.TERMINATED)
         self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
         self.assertTrue(
-            not runner2.get_trial("pending").has_reported_at_least_once)
+            runner2.get_trial("pending").has_reported_at_least_once)
         runner2.step()
 
     def testCheckpointWithFunction(self):
diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py
index 840caa9d3..35379ea8b 100644
--- a/python/ray/tune/trial_executor.py
+++ b/python/ray/tune/trial_executor.py
@@ -151,25 +151,10 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
                      trial.status, status)
         trial.set_status(status)
         if status in [Trial.TERMINATED, Trial.ERROR]:
-            self.try_checkpoint_metadata(trial)
-
-    def try_checkpoint_metadata(self, trial: Trial) -> None:
-        """Checkpoints trial metadata.
-
-        Args:
-            trial (Trial): Trial to checkpoint.
-        """
-        if trial.checkpoint.storage == Checkpoint.MEMORY:
-            logger.debug("Trial %s: Not saving data for memory checkpoint.",
-                         trial)
-            return
-        try:
-            logger.debug("Trial %s: Saving trial metadata.", trial)
-            # Lazy cache trials
             self._trials_to_cache.add(trial)
-        except Exception:
-            logger.exception("Trial %s: Error checkpointing trial metadata.",
-                             trial)
+
+    def mark_trial_to_checkpoint(self, trial: Trial) -> None:
+        self._trials_to_cache.add(trial)
 
     def get_checkpoints(self) -> Dict[str, str]:
         """Returns a copy of mapping of the trial ID to pickled metadata."""
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 858ef820d..72ea8f01c 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -753,7 +753,7 @@ class TrialRunner:
         self._live_trials.add(trial)
         with warn_if_slow("scheduler.on_trial_add"):
             self._scheduler_alg.on_trial_add(self, trial)
-        self.trial_executor.try_checkpoint_metadata(trial)
+        self.trial_executor.mark_trial_to_checkpoint(trial)
 
     def debug_string(self, delim="\n"):
         from ray.tune.progress_reporter import trial_progress_str
@@ -987,6 +987,8 @@ class TrialRunner:
 
         if not is_duplicate:
             trial.update_last_result(result)
+            # Include in next experiment checkpoint
+            self.trial_executor.mark_trial_to_checkpoint(trial)
 
         # Checkpoints to disk. This should be checked even if
         # the scheduler decision is STOP or PAUSE. Note that
@@ -1089,7 +1091,8 @@ class TrialRunner:
                 trial=trial,
                 checkpoint=trial.saving_to)
             trial.on_checkpoint(trial.saving_to)
-            self.trial_executor.try_checkpoint_metadata(trial)
+            if trial.checkpoint.storage != Checkpoint.MEMORY:
+                self.trial_executor.mark_trial_to_checkpoint(trial)
         except Exception:
             logger.exception("Trial %s: Error handling checkpoint %s",
                              trial, checkpoint_value)
diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py
index ffcdc603b..6e6935f5f 100644
--- a/release/.buildkite/build_pipeline.py
+++ b/release/.buildkite/build_pipeline.py
@@ -160,8 +160,8 @@ NIGHTLY_TESTS = {
         "aws_no_sync_down",
         "aws_ssh_sync",
         "aws_durable_upload",
-        # "aws_durable_upload_rllib_str",
-        # "aws_durable_upload_rllib_trainer",
+        "aws_durable_upload_rllib_str",
+        "aws_durable_upload_rllib_trainer",
         "gcp_k8s_durable_upload",
     ],
     "~/ray/release/tune_tests/scalability_tests/tune_tests.yaml": [
diff --git a/release/tune_tests/cloud_tests/workloads/_tune_script.py b/release/tune_tests/cloud_tests/workloads/_tune_script.py
index e9dbd4182..b73a6a2e1 100644
--- a/release/tune_tests/cloud_tests/workloads/_tune_script.py
+++ b/release/tune_tests/cloud_tests/workloads/_tune_script.py
@@ -33,13 +33,8 @@ def fn_trainable(config, checkpoint_dir=None):
 
 
 class RLLibCallback(DefaultCallbacks):
-    def __init__(self):
-        super(RLLibCallback, self).__init__()
-        self.internal_iter = 0
-
     def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
-        result["internal_iter"] = self.internal_iter
-        self.internal_iter += 1
+        result["internal_iter"] = result["training_iteration"]
 
 
 class IndicatorCallback(tune.Callback):
@@ -60,7 +55,7 @@ def run_tune(no_syncer: bool,
     if trainable == "function":
         train = fn_trainable
         config = {
-            "max_iterations": 30,
+            "max_iterations": 100,
             "sleep_time": 5,
             "checkpoint_freq": 2,
             "score_multiplied": tune.randint(0, 100),
@@ -80,8 +75,10 @@ def run_tune(no_syncer: bool,
         }
         kwargs = {
             "stop": {
-                "training_iteration": 10
+                "training_iteration": 100
             },
+            "checkpoint_freq": 2,
+            "checkpoint_at_end": True,
         }
     else:
         raise RuntimeError(f"Unknown trainable: {trainable}")
diff --git a/release/tune_tests/cloud_tests/workloads/run_cloud_test.py b/release/tune_tests/cloud_tests/workloads/run_cloud_test.py
index ca3f76bfb..7ecc7cede 100644
--- a/release/tune_tests/cloud_tests/workloads/run_cloud_test.py
+++ b/release/tune_tests/cloud_tests/workloads/run_cloud_test.py
@@ -43,6 +43,7 @@ import time
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import ray
+import ray.cloudpickle as pickle
 from ray.tune.trial_runner import find_newest_experiment_checkpoint
 from ray.tune.utils.serialization import TuneFunctionDecoder
 
@@ -118,6 +119,7 @@ class TrialCheckpointData:
     results: List[Dict[str, Any]]
     progress: List[Dict[str, Any]]
     checkpoints: List[Tuple[str, Dict[Any, Any]]]
+    num_skipped: int
 
 
 # Utility functions
@@ -584,6 +586,7 @@ def load_trial_checkpoint_data(trial_dir: str,
 
     progress = []
     checkpoints = []
+    num_skipped = 0
    for cp_dir in sorted(os.listdir(trial_dir)):
         if not cp_dir.startswith("checkpoint_"):
             continue
@@ -601,20 +604,32 @@ def load_trial_checkpoint_data(trial_dir: str,
                 print(f"Skipping unobserved checkpoint: {cp_full_dir} as "
                       f"{checkpoint_num} > "
                       f"{node_trial.last_result['internal_iter']}")
+                num_skipped += 1
                 continue
         except ValueError:
             # temporary checkpoint
             continue
 
-        with open(os.path.join(cp_full_dir, "checkpoint.json"), "rt") as f:
-            checkpoint_data = json.load(f)
+        json_path = os.path.join(cp_full_dir, "checkpoint.json")
+        if os.path.exists(json_path):
+            with open(json_path, "rt") as f:
+                checkpoint_data = json.load(f)
+        else:
+            meta_path = os.path.join(
+                cp_full_dir, f"checkpoint-{checkpoint_num}.tune_metadata")
+            with open(meta_path, "rb") as f:
+                checkpoint_meta = pickle.load(f)
+            checkpoint_data = {
+                "internal_iter": checkpoint_meta["iteration"]
+            }
         checkpoints.append((cp_dir, checkpoint_data))
 
     return TrialCheckpointData(
         params=params,
         results=results,
         progress=progress,
-        checkpoints=checkpoints)
+        checkpoints=checkpoints,
+        num_skipped=num_skipped)
 
 
 def load_data_from_trial_exp_checkpoints(
@@ -713,18 +728,30 @@ def assert_min_num_trials(trials: Iterable[TrialStub], on_driver: int,
 
 def assert_checkpoint_count(experiment_dir_cp: ExperimentDirCheckpoint,
                             for_driver_trial: int, for_worker_trial: int):
+    # We relaxed the requirements here and also allow
+    # skipped checkpoints to count. This could be the case if e.g. the trial
+    # already checkpointed but the driver did not process the last result, yet.
+    # We also allow up to one un-collected checkpoint.
+    # Todo: Can we make this stricter?
     for trial, trial_cp in experiment_dir_cp.trial_to_cps.items():
         cps = len(trial_cp.checkpoints)
+        num_skipped = trial_cp.num_skipped
         if trial.was_on_driver_node:
-            assert cps == for_driver_trial, (
-                f"Trial {trial.trial_id} was on driver, "
-                f"but did not observe the expected amount of checkpoints "
-                f"({cps} != {for_driver_trial}).")
+            assert (
+                cps == for_driver_trial
+                or cps + num_skipped == for_driver_trial
+                or cps == for_driver_trial + 1), (
+                    f"Trial {trial.trial_id} was on driver, "
+                    f"but did not observe the expected amount of checkpoints "
+                    f"({cps} != {for_driver_trial}).")
         else:
-            assert cps == for_worker_trial, (
-                f"Trial {trial.trial_id} was not on the driver, "
-                f"but did not observe the expected amount of checkpoints "
-                f"({cps} != {for_worker_trial}).")
+            assert (
+                cps == for_worker_trial
+                or cps + num_skipped == for_worker_trial
+                or cps == for_worker_trial + 1), (
+                    f"Trial {trial.trial_id} was not on the driver, "
+                    f"but did not observe the expected amount of checkpoints "
+                    f"({cps} != {for_worker_trial}).")
 
 
 def assert_trial_progressed_training(trial: TrialStub):
@@ -830,13 +857,15 @@ def test_no_sync_down():
         f"errored, there is something wrong with restoration. If less, "
         f"maybe cleanup has not worked, or syncing to driver took place.")
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=True,
         upload_dir=None,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
 
@@ -922,13 +951,15 @@ def test_ssh_sync():
         for trial in experiment_state.trials:
             assert_trial_progressed_training(trial)
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=None,
-        first_run_time=55,  # More time because of SSH syncing
-        second_run_time=55,
+        first_run_time=run_time + 10,  # More time because of SSH syncing
+        second_run_time=run_time + 10,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
 
@@ -1048,13 +1079,15 @@ def test_durable_upload(bucket: str):
 
         clear_bucket_contents(bucket)
 
+    run_time = int(os.getenv("TUNE_RUN_TIME", "180")) or 180
+
     run_resume_flow(
         experiment_name=experiment_name,
         indicator_file=indicator_file,
         no_syncer=False,
         upload_dir=bucket,
-        first_run_time=45,
-        second_run_time=45,
+        first_run_time=run_time,
+        second_run_time=run_time,
         before_experiments_callback=before_experiments,
         between_experiments_callback=between_experiments,
         after_experiments_callback=after_experiments)
@@ -1094,6 +1127,7 @@ if __name__ == "__main__":
 
     def _run_test(variant: str,
                   trainable: str = "function",
+                  run_time: int = 180,
                   bucket: str = "",
                   cpus_per_trial: int = 2,
                   overwrite_tune_script: Optional[str] = None):
@@ -1103,6 +1137,7 @@ if __name__ == "__main__":
               f"{cpus_per_trial} CPUs per trial.")
 
         os.environ["TUNE_TRAINABLE"] = str(trainable)
+        os.environ["TUNE_RUN_TIME"] = str(run_time)
         os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = str(cpus_per_trial)
 
         if overwrite_tune_script:
@@ -1124,9 +1159,11 @@ if __name__ == "__main__":
         with open(release_test_out, "wt") as f:
             json.dump(result, f)
 
+    run_time = 180 if "rllib" in args.trainable else 90
+
     if not uses_ray_client:
         print("This test will *not* use Ray client.")
-        _run_test(args.variant, args.trainable, args.bucket,
+        _run_test(args.variant, args.trainable, run_time, args.bucket,
                   args.cpus_per_trial)
     else:
         print("This test will run using Ray client.")
@@ -1150,8 +1187,9 @@ if __name__ == "__main__":
         _run_test_remote = ray.remote(
             resources={f"node:{ip}": 0.01}, num_cpus=0)(_run_test)
         ray.get(
-            _run_test_remote.remote(args.variant, args.trainable, args.bucket,
-                                    args.cpus_per_trial, remote_tune_script))
+            _run_test_remote.remote(args.variant, args.trainable, run_time,
+                                    args.bucket, args.cpus_per_trial,
+                                    remote_tune_script))
 
     print(f"Fetching remote release test result file: {release_test_out}")
     fetch_remote_file_to_local_file(release_test_out, ip, release_test_out)
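
Note on the core change: call sites no longer serialize trial metadata eagerly through try_checkpoint_metadata; they only mark a trial as dirty via mark_trial_to_checkpoint, and the executor serializes the marked trials the next time get_checkpoints builds the experiment checkpoint. Below is a minimal, self-contained sketch of that lazy-caching pattern. _DemoTrial, _DemoExecutor, and get_json_state are illustrative stand-ins with simplified state, not the actual Ray Tune classes.

import json
from typing import Dict


class _DemoTrial:
    """Hypothetical stand-in for a Tune trial (illustration only)."""

    def __init__(self, trial_id: str):
        self.trial_id = trial_id
        self.last_result = {}

    def get_json_state(self) -> str:
        # The real trial serializes far more state than this.
        return json.dumps({"trial_id": self.trial_id,
                           "last_result": self.last_result})


class _DemoExecutor:
    """Sketch of the lazy metadata caching behind mark_trial_to_checkpoint."""

    def __init__(self):
        # Trials whose metadata changed since the last experiment checkpoint.
        self._trials_to_cache = set()
        self._cached_trial_state: Dict[str, str] = {}

    def mark_trial_to_checkpoint(self, trial: _DemoTrial) -> None:
        # Cheap at the call site: only remember that this trial is dirty.
        self._trials_to_cache.add(trial)

    def get_checkpoints(self) -> Dict[str, str]:
        # Serialize only the trials marked since the last call, then return
        # a copy of the full trial-ID -> metadata mapping.
        for trial in self._trials_to_cache:
            self._cached_trial_state[trial.trial_id] = trial.get_json_state()
        self._trials_to_cache.clear()
        return dict(self._cached_trial_state)


if __name__ == "__main__":
    executor = _DemoExecutor()
    trial = _DemoTrial("trial_1")
    trial.last_result = {"training_iteration": 2}
    executor.mark_trial_to_checkpoint(trial)
    print(executor.get_checkpoints())  # {'trial_1': '{"trial_id": ...}'}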