From 130b8f21da4fb5383b079493faaea5d81065b772 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 3 Sep 2019 15:36:25 -0700 Subject: [PATCH] [tune] Global checkpointing for tune at end (#5499) --- python/ray/tune/tests/test_cluster.py | 1 - python/ray/tune/tests/test_commands.py | 9 +++------ python/ray/tune/tests/test_experiment_analysis.py | 3 --- python/ray/tune/trial_runner.py | 11 ++++++++--- python/ray/tune/tune.py | 5 +++++ 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index c11afa509..ad18079f8 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -366,7 +366,6 @@ tune.run( config=dict(env="CartPole-v1"), stop=dict(training_iteration=10), local_dir="{checkpoint_dir}", - global_checkpoint_period=0, checkpoint_freq=1, max_failures=1, dict(experiment=kwargs), diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index 94322dba0..f640260be 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -74,8 +74,7 @@ def test_ls(start_ray, tmpdir): name=experiment_name, stop={"training_iteration": 1}, num_samples=num_samples, - local_dir=str(tmpdir), - global_checkpoint_period=0) + local_dir=str(tmpdir)) columns = ["episode_reward_mean", "training_iteration", "trial_id"] limit = 2 @@ -114,8 +113,7 @@ def test_ls_with_cfg(start_ray, tmpdir): name=experiment_name, stop={"training_iteration": 1}, config={"test_variable": tune.grid_search(list(range(5)))}, - local_dir=str(tmpdir), - global_checkpoint_period=0) + local_dir=str(tmpdir)) columns = [CONFIG_PREFIX + "test_variable", "trial_id"] limit = 4 @@ -138,8 +136,7 @@ def test_lsx(start_ray, tmpdir): name=experiment_name, stop={"training_iteration": 1}, num_samples=1, - local_dir=project_path, - global_checkpoint_period=0) + local_dir=project_path) limit = 2 with Capturing() as output: diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 99d9c8e1c..d9f7e766e 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -31,7 +31,6 @@ class ExperimentAnalysisSuite(unittest.TestCase): def run_test_exp(self): self.ea = run( MyTrainableClass, - global_checkpoint_period=0, name=self.test_name, local_dir=self.test_dir, stop={"training_iteration": 1}, @@ -85,7 +84,6 @@ class ExperimentAnalysisSuite(unittest.TestCase): def testIgnoreOtherExperiment(self): analysis = run( MyTrainableClass, - global_checkpoint_period=0, name="test_example", local_dir=self.test_dir, return_trials=False, @@ -111,7 +109,6 @@ class AnalysisSuite(unittest.TestCase): def run_test_exp(self, test_name=None): run(MyTrainableClass, - global_checkpoint_period=0, name=test_name, local_dir=self.test_dir, return_trials=False, diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 94666985c..0b03957a0 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -240,17 +240,22 @@ class TrialRunner(object): else: logger.info("TrialRunner resumed, ignoring new add_experiment.") - def checkpoint(self): + def checkpoint(self, force=False): """Saves execution state to `self._local_checkpoint_dir`. Overwrites the current session checkpoint, which starts when self is instantiated. Throttle depends on self._checkpoint_period. + + Args: + force (bool): Forces a checkpoint despite checkpoint_period. """ if not self._local_checkpoint_dir: return - if time.time() - self._last_checkpoint_time < self._checkpoint_period: + now = time.time() + if now - self._last_checkpoint_time < self._checkpoint_period and ( + not force): return - self._last_checkpoint_time = time.time() + self._last_checkpoint_time = now runner_state = { "checkpoints": list( self.trial_executor.get_checkpoints().values()), diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 049306305..881751823 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -247,6 +247,11 @@ def run(run_or_experiment, print(runner.debug_string()) last_debug = time.time() + try: + runner.checkpoint(force=True) + except Exception: + logger.exception("Trial Runner checkpointing failed.") + if verbose: print(runner.debug_string(max_debug=99999))