[tune] Global checkpointing for tune at end (#5499)

Richard Liaw 2019-09-03 15:36:25 -07:00 committed by GitHub
parent 1711e202a3
commit 130b8f21da
5 changed files with 16 additions and 13 deletions

@@ -366,7 +366,6 @@ tune.run(
     config=dict(env="CartPole-v1"),
     stop=dict(training_iteration=10),
     local_dir="{checkpoint_dir}",
-    global_checkpoint_period=0,
     checkpoint_freq=1,
     max_failures=1,
 dict(experiment=kwargs),

@@ -74,8 +74,7 @@ def test_ls(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         num_samples=num_samples,
-        local_dir=str(tmpdir),
-        global_checkpoint_period=0)
+        local_dir=str(tmpdir))

     columns = ["episode_reward_mean", "training_iteration", "trial_id"]
     limit = 2
@@ -114,8 +113,7 @@ def test_ls_with_cfg(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         config={"test_variable": tune.grid_search(list(range(5)))},
-        local_dir=str(tmpdir),
-        global_checkpoint_period=0)
+        local_dir=str(tmpdir))

     columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
     limit = 4
@@ -138,8 +136,7 @@ def test_lsx(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         num_samples=1,
-        local_dir=project_path,
-        global_checkpoint_period=0)
+        local_dir=project_path)

     limit = 2
     with Capturing() as output:

@@ -31,7 +31,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def run_test_exp(self):
         self.ea = run(
             MyTrainableClass,
-            global_checkpoint_period=0,
             name=self.test_name,
             local_dir=self.test_dir,
             stop={"training_iteration": 1},
@@ -85,7 +84,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def testIgnoreOtherExperiment(self):
         analysis = run(
             MyTrainableClass,
-            global_checkpoint_period=0,
             name="test_example",
             local_dir=self.test_dir,
             return_trials=False,
@@ -111,7 +109,6 @@ class AnalysisSuite(unittest.TestCase):
     def run_test_exp(self, test_name=None):
         run(MyTrainableClass,
-            global_checkpoint_period=0,
             name=test_name,
             local_dir=self.test_dir,
             return_trials=False,

@@ -240,17 +240,22 @@ class TrialRunner(object):
         else:
             logger.info("TrialRunner resumed, ignoring new add_experiment.")

-    def checkpoint(self):
+    def checkpoint(self, force=False):
         """Saves execution state to `self._local_checkpoint_dir`.

         Overwrites the current session checkpoint, which starts when self
         is instantiated. Throttle depends on self._checkpoint_period.
+
+        Args:
+            force (bool): Forces a checkpoint despite checkpoint_period.
         """
         if not self._local_checkpoint_dir:
             return
-        if time.time() - self._last_checkpoint_time < self._checkpoint_period:
+        now = time.time()
+        if now - self._last_checkpoint_time < self._checkpoint_period and (
+                not force):
             return
-        self._last_checkpoint_time = time.time()
+        self._last_checkpoint_time = now
         runner_state = {
             "checkpoints": list(
                 self.trial_executor.get_checkpoints().values()),
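
For context on the change above: `checkpoint()` still skips writes while `self._checkpoint_period` has not elapsed, but a caller can now pass `force=True` to bypass the throttle. The standalone sketch below mirrors that control flow; the `ThrottledCheckpointer` class, its attribute names, and the example period are illustrative only, not part of the Tune API.

import time


class ThrottledCheckpointer(object):
    """Illustrative stand-in for the throttling in TrialRunner.checkpoint."""

    def __init__(self, checkpoint_period=10):
        self._checkpoint_period = checkpoint_period
        self._last_checkpoint_time = float("-inf")

    def checkpoint(self, force=False):
        """Returns True if a checkpoint was written, False if throttled.

        Args:
            force (bool): Write a checkpoint even if checkpoint_period
                has not elapsed since the last write.
        """
        now = time.time()
        if now - self._last_checkpoint_time < self._checkpoint_period and (
                not force):
            return False  # Throttled: a checkpoint was written recently.
        self._last_checkpoint_time = now
        # ... persist runner state to disk here ...
        return True


ckpt = ThrottledCheckpointer(checkpoint_period=10)
assert ckpt.checkpoint() is True            # First call always writes.
assert ckpt.checkpoint() is False           # Throttled by checkpoint_period.
assert ckpt.checkpoint(force=True) is True  # force bypasses the throttle.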

@@ -247,6 +247,11 @@ def run(run_or_experiment,
             print(runner.debug_string())
             last_debug = time.time()

+    try:
+        runner.checkpoint(force=True)
+    except Exception:
+        logger.exception("Trial Runner checkpointing failed.")
+
     if verbose:
         print(runner.debug_string(max_debug=99999))
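
Because `run()` now forces a final experiment checkpoint (and only logs, rather than raises, if that write fails), callers no longer need `global_checkpoint_period=0` just to guarantee that the experiment state is on disk when `run()` returns, which is why the tests above drop the argument. Below is a minimal sketch of the simplified call pattern; the trainable, experiment name, and output directory are hypothetical.

import ray
from ray import tune


def my_trainable(config, reporter):
    # Hypothetical one-iteration trainable used only for illustration.
    reporter(timesteps_total=1, done=True)


ray.init()
tune.run(
    my_trainable,
    name="end_checkpoint_demo",       # hypothetical experiment name
    stop={"training_iteration": 1},
    local_dir="/tmp/ray_results")     # hypothetical output directory
# No global_checkpoint_period=0 needed: the experiment-level checkpoint
# is written unconditionally before run() returns, so tools such as
# `tune ls` can read the experiment state right away.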