[tune] Global checkpointing for tune at end (#5499)

parent 1711e202a3
commit 130b8f21da

5 changed files with 16 additions and 13 deletions
@@ -366,7 +366,6 @@ tune.run(
     config=dict(env="CartPole-v1"),
     stop=dict(training_iteration=10),
     local_dir="{checkpoint_dir}",
-    global_checkpoint_period=0,
     checkpoint_freq=1,
     max_failures=1,
     dict(experiment=kwargs),
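Context for this documentation change: `global_checkpoint_period` throttles how often Tune writes the experiment-level state (the TrialRunner checkpoint) under `local_dir`. Because `tune.run` now forces one final experiment checkpoint when the run ends (see the last two hunks below), the example no longer needs `global_checkpoint_period=0` just to guarantee that state lands on disk after a short run. A minimal sketch, assuming the Ray Tune API of this era; the trainable, experiment name, and paths are illustrative, not taken from this diff:

import os

import ray
from ray import tune


class MyTrainable(tune.Trainable):
    """Tiny stand-in trainable so the sketch runs without RLlib."""

    def _train(self):
        return {"episode_reward_mean": 0.0}

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write("0")
        return path

    def _restore(self, checkpoint_path):
        pass


ray.init()
tune.run(
    MyTrainable,
    name="doc_example",                  # illustrative
    stop=dict(training_iteration=10),
    local_dir="/tmp/ray_results",        # illustrative
    global_checkpoint_period=60,         # throttle experiment-state writes mid-run
    checkpoint_freq=1,
    max_failures=1)
# With this commit, one extra experiment-state checkpoint is forced when the
# run finishes, regardless of global_checkpoint_period.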
@@ -74,8 +74,7 @@ def test_ls(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         num_samples=num_samples,
-        local_dir=str(tmpdir),
-        global_checkpoint_period=0)
+        local_dir=str(tmpdir))

     columns = ["episode_reward_mean", "training_iteration", "trial_id"]
     limit = 2
@@ -114,8 +113,7 @@ def test_ls_with_cfg(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         config={"test_variable": tune.grid_search(list(range(5)))},
-        local_dir=str(tmpdir),
-        global_checkpoint_period=0)
+        local_dir=str(tmpdir))

     columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
     limit = 4
@@ -138,8 +136,7 @@ def test_lsx(start_ray, tmpdir):
         name=experiment_name,
         stop={"training_iteration": 1},
         num_samples=1,
-        local_dir=project_path,
-        global_checkpoint_period=0)
+        local_dir=project_path)

     limit = 2
     with Capturing() as output:
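These three `tune ls` tests previously set `global_checkpoint_period=0` so that even a one-iteration run would leave an experiment-state file behind for the CLI to list; with the forced end-of-run checkpoint, the plain call is enough. A self-contained sketch of what the tests now rely on, assuming the function-trainable signature and the `experiment_state-*.json` naming used by Tune in this era; the experiment name and directory are illustrative:

import glob
import os

import ray
from ray import tune


def trainable(config, reporter):
    # One report satisfies stop={"training_iteration": 1}.
    reporter(episode_reward_mean=0.0)


ray.init()
tune.run(
    trainable,
    name="ls_example",                 # illustrative
    stop={"training_iteration": 1},
    local_dir="/tmp/ray_results")      # illustrative

# The forced final checkpoint should leave an experiment-state file behind
# even though checkpoint_period never elapsed during this short run.
assert glob.glob(
    os.path.join("/tmp/ray_results", "ls_example", "experiment_state-*.json"))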
@@ -31,7 +31,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def run_test_exp(self):
         self.ea = run(
             MyTrainableClass,
-            global_checkpoint_period=0,
             name=self.test_name,
             local_dir=self.test_dir,
             stop={"training_iteration": 1},
@@ -85,7 +84,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def testIgnoreOtherExperiment(self):
         analysis = run(
             MyTrainableClass,
-            global_checkpoint_period=0,
             name="test_example",
             local_dir=self.test_dir,
             return_trials=False,
@@ -111,7 +109,6 @@ class AnalysisSuite(unittest.TestCase):

     def run_test_exp(self, test_name=None):
         run(MyTrainableClass,
-            global_checkpoint_period=0,
             name=test_name,
             local_dir=self.test_dir,
             return_trials=False,
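In these analysis tests, `run(..., return_trials=False)` returns an `ExperimentAnalysis` built from the on-disk experiment state, so the forced final checkpoint removes the need to set `global_checkpoint_period=0` first. A hedged sketch of the pattern, assuming the Tune API of this era (including `ExperimentAnalysis.dataframe()`); the trainable, experiment name, and directory are placeholders:

import ray
from ray.tune import run


def trainable(config, reporter):
    reporter(episode_reward_mean=config.get("value", 0.0))


ray.init()
analysis = run(
    trainable,
    name="analysis_example",           # illustrative
    stop={"training_iteration": 1},
    local_dir="/tmp/ray_results",      # illustrative
    return_trials=False)               # return an ExperimentAnalysis, not a trial list

# The analysis object is backed by the experiment state that the forced final
# checkpoint just wrote, so even a one-iteration run is visible here.
print(analysis.dataframe().head())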
@@ -240,17 +240,22 @@ class TrialRunner(object):
         else:
             logger.info("TrialRunner resumed, ignoring new add_experiment.")

-    def checkpoint(self):
+    def checkpoint(self, force=False):
         """Saves execution state to `self._local_checkpoint_dir`.

         Overwrites the current session checkpoint, which starts when self
         is instantiated. Throttle depends on self._checkpoint_period.
+
+        Args:
+            force (bool): Forces a checkpoint despite checkpoint_period.
         """
         if not self._local_checkpoint_dir:
             return
-        if time.time() - self._last_checkpoint_time < self._checkpoint_period:
+        now = time.time()
+        if now - self._last_checkpoint_time < self._checkpoint_period and (
+                not force):
             return
-        self._last_checkpoint_time = time.time()
+        self._last_checkpoint_time = now
         runner_state = {
             "checkpoints": list(
                 self.trial_executor.get_checkpoints().values()),
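The core mechanism: `TrialRunner.checkpoint` keeps its time-based throttle (`self._checkpoint_period`), but a `force=True` call now bypasses it, and the timestamp is read once into `now` so the comparison and the update use the same value. A standalone sketch of that throttle-with-force pattern (not Tune's actual class; names are illustrative):

import time


class ThrottledCheckpointer:
    """Illustrative stand-in, not Tune's TrialRunner."""

    def __init__(self, checkpoint_period=10):
        self._checkpoint_period = checkpoint_period
        self._last_checkpoint_time = 0

    def checkpoint(self, force=False):
        """Write state unless the period has not elapsed, unless forced."""
        now = time.time()
        if now - self._last_checkpoint_time < self._checkpoint_period and (
                not force):
            return False  # throttled: nothing written
        self._last_checkpoint_time = now
        # ... the real implementation serializes experiment state here ...
        return True


ckpt = ThrottledCheckpointer(checkpoint_period=10)
assert ckpt.checkpoint()             # first call: period elapsed since time 0
assert not ckpt.checkpoint()         # immediately after: throttled
assert ckpt.checkpoint(force=True)   # force bypasses the throttle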
@@ -247,6 +247,11 @@ def run(run_or_experiment,
             print(runner.debug_string())
             last_debug = time.time()

+    try:
+        runner.checkpoint(force=True)
+    except Exception:
+        logger.exception("Trial Runner checkpointing failed.")
+
     if verbose:
         print(runner.debug_string(max_debug=99999))

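At the end of `tune.run`, one forced experiment checkpoint is now attempted after the trial loop exits, and any exception it raises is logged rather than propagated, so a failed state write cannot discard the results of an otherwise successful run. A hedged sketch of how the two changes fit together (illustrative run loop, not Tune's actual code; `runner` and `checkpointer` are stand-ins):

import logging

logger = logging.getLogger(__name__)


def run_loop(runner, checkpointer):
    """Illustrative driver: throttled checkpoints mid-run, one forced at the end."""
    while not runner.is_finished():
        runner.step()
        checkpointer.checkpoint()            # throttled by checkpoint_period
    try:
        checkpointer.checkpoint(force=True)  # always attempt one final write
    except Exception:
        # Logged, not raised: the run's results are already in hand.
        logger.exception("Trial Runner checkpointing failed.")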