[tune][minor] Reduce checkpointing frequency (#4859)

Richard Liaw, 2019-07-06 00:54:24 -07:00 (committed by GitHub)
parent 4b56a5eb27
commit c3e9d94b18
6 changed files with 45 additions and 41 deletions


@@ -272,7 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
     cluster.wait_for_nodes()

     dirpath = str(tmpdir)
-    runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath)
+    runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 2
@@ -359,15 +359,16 @@ from ray import tune
 ray.init(redis_address="{redis_address}")
-kwargs = dict(
-    run="PG",
-    env="CartPole-v1",
+tune.run(
+    "PG",
+    name="experiment",
+    config=dict(env="CartPole-v1"),
     stop=dict(training_iteration=10),
     local_dir="{checkpoint_dir}",
+    global_checkpoint_period=0,
     checkpoint_freq=1,
-    max_failures=1)
-tune.run_experiments(
-    dict(experiment=kwargs),
+    max_failures=1,
     raise_on_failed_trial=False)
 """.format(
@@ -449,15 +450,14 @@ ray.init(redis_address="{redis_address}")
 {fail_class_code}
-kwargs = dict(
-    run={fail_class},
+tune.run(
+    {fail_class},
+    name="experiment",
     stop=dict(training_iteration=5),
     local_dir="{checkpoint_dir}",
     checkpoint_freq=1,
-    max_failures=1)
-tune.run_experiments(
-    dict(experiment=kwargs),
+    global_checkpoint_period=0,
+    max_failures=1,
     raise_on_failed_trial=False)
 """.format(
     redis_address=cluster.redis_address,
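For reference, the API migration these test edits perform, sketched with illustrative values (requires ray[rllib]; the removed form is shown in comments since both styles launch the same experiment):

    from ray import tune

    # Old style (removed by this diff): a dict-of-dicts experiment spec.
    # tune.run_experiments(
    #     {"experiment": dict(run="PG", env="CartPole-v1",
    #                         stop=dict(training_iteration=10), checkpoint_freq=1)},
    #     raise_on_failed_trial=False)

    # New style: the trainable is passed directly, env moves under config=,
    # and the run_experiments-level kwargs are inlined into tune.run.
    tune.run(
        "PG",
        name="experiment",
        config=dict(env="CartPole-v1"),
        stop=dict(training_iteration=10),
        checkpoint_freq=1,
        global_checkpoint_period=0,
        raise_on_failed_trial=False)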


@@ -67,16 +67,13 @@ def test_ls(start_ray, tmpdir):
     experiment_name = "test_ls"
     experiment_path = os.path.join(str(tmpdir), experiment_name)
     num_samples = 3
-    tune.run_experiments({
-        experiment_name: {
-            "run": "__fake",
-            "stop": {
-                "training_iteration": 1
-            },
-            "num_samples": num_samples,
-            "local_dir": str(tmpdir)
-        }
-    })
+    tune.run(
+        "__fake",
+        name=experiment_name,
+        stop={"training_iteration": 1},
+        num_samples=num_samples,
+        local_dir=str(tmpdir),
+        global_checkpoint_period=0)

     columns = ["status", "episode_reward_mean", "training_iteration"]
     limit = 2
@@ -104,16 +101,13 @@ def test_lsx(start_ray, tmpdir):
     num_experiments = 3
     for i in range(num_experiments):
         experiment_name = "test_lsx{}".format(i)
-        tune.run_experiments({
-            experiment_name: {
-                "run": "__fake",
-                "stop": {
-                    "training_iteration": 1
-                },
-                "num_samples": 1,
-                "local_dir": project_path
-            }
-        })
+        tune.run(
+            "__fake",
+            name=experiment_name,
+            stop={"training_iteration": 1},
+            num_samples=1,
+            local_dir=project_path,
+            global_checkpoint_period=0)

     limit = 2
     with Capturing() as output:


@@ -32,6 +32,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def run_test_exp(self):
         self.ea = run(
             MyTrainableClass,
+            global_checkpoint_period=0,
             name=self.test_name,
             local_dir=self.test_dir,
             return_trials=False,


@@ -2086,7 +2086,7 @@ class TrialRunnerTest(unittest.TestCase):
         ray.init(num_cpus=3)

         tmpdir = tempfile.mkdtemp()
-        runner = TrialRunner(local_checkpoint_dir=tmpdir)
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
         trials = [
             Trial(
                 "__fake",
@@ -2145,8 +2145,7 @@ class TrialRunnerTest(unittest.TestCase):
         ray.init(num_cpus=3)

         tmpdir = tempfile.mkdtemp()
-        runner = TrialRunner(local_checkpoint_dir=tmpdir)
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
         runner.add_trial(
             Trial(
                 "__fake",
@@ -2200,7 +2199,7 @@ class TrialRunnerTest(unittest.TestCase):
             },
             checkpoint_freq=1)
         tmpdir = tempfile.mkdtemp()
-        runner = TrialRunner(local_checkpoint_dir=tmpdir)
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
         runner.add_trial(trial)
         for i in range(5):
             runner.step()
@@ -2221,7 +2220,7 @@ class TrialRunnerTest(unittest.TestCase):
         ray.init()
         trial = Trial("__fake", checkpoint_freq=1)
         tmpdir = tempfile.mkdtemp()
-        runner = TrialRunner(local_checkpoint_dir=tmpdir)
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
         runner.add_trial(trial)
         for i in range(5):
             runner.step()


@@ -111,6 +111,7 @@ class TrialRunner(object):
                  resume=False,
                  server_port=TuneServer.DEFAULT_PORT,
                  verbose=True,
+                 checkpoint_period=10,
                  trial_executor=None):
         """Initializes a new TrialRunner.
@@ -174,6 +175,8 @@
             logger.info("Starting a new experiment.")

         self._start_time = time.time()
+        self._last_checkpoint_time = -float("inf")
+        self._checkpoint_period = checkpoint_period
         self._session_str = datetime.fromtimestamp(
             self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
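Initializing _last_checkpoint_time to negative infinity guarantees the first checkpoint() call is never throttled: the measured elapsed time is infinite, so it exceeds any finite period. A quick illustration:

    import time

    last_checkpoint_time = -float("inf")  # as initialized above
    checkpoint_period = 10
    # Elapsed time is +inf, so the early return below cannot trigger:
    print(time.time() - last_checkpoint_time < checkpoint_period)  # False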
@@ -235,18 +238,20 @@
         """Saves execution state to `self._local_checkpoint_dir`.

         Overwrites the current session checkpoint, which starts when self
-        is instantiated.
+        is instantiated. Throttle depends on self._checkpoint_period.
         """
         if not self._local_checkpoint_dir:
             return
+        if time.time() - self._last_checkpoint_time < self._checkpoint_period:
+            return
+        self._last_checkpoint_time = time.time()
         runner_state = {
             "checkpoints": list(
                 self.trial_executor.get_checkpoints().values()),
             "runner_data": self.__getstate__(),
             "stats": {
                 "start_time": self._start_time,
-                "timestamp": time.time()
+                "timestamp": self._last_checkpoint_time
             }
         }
         tmp_file_name = os.path.join(self._local_checkpoint_dir,
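Putting the pieces together, a minimal standalone model of the throttling this hunk introduces (illustrative class, not Tune's TrialRunner):

    import time

    class CheckpointThrottler:
        """Allows at most one save per `period` seconds (0 disables throttling)."""

        def __init__(self, period=10):
            self._period = period
            self._last = -float("inf")

        def maybe_checkpoint(self, save_fn):
            if time.time() - self._last < self._period:
                return False  # throttled: still within the checkpoint period
            self._last = time.time()
            save_fn()
            return True

    throttler = CheckpointThrottler(period=10)
    assert throttler.maybe_checkpoint(lambda: None)      # first call always saves
    assert not throttler.maybe_checkpoint(lambda: None)  # immediate retry is throttled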


@@ -49,6 +49,7 @@ def run(run_or_experiment,
         sync_to_driver=None,
         checkpoint_freq=0,
         checkpoint_at_end=False,
+        global_checkpoint_period=10,
         export_formats=None,
         max_failures=3,
         restore=None,
@@ -113,6 +114,9 @@
             checkpoints. A value of 0 (default) disables checkpointing.
         checkpoint_at_end (bool): Whether to checkpoint at the end of the
             experiment regardless of the checkpoint_freq. Default is False.
+        global_checkpoint_period (int): Seconds between global checkpointing.
+            This does not affect `checkpoint_freq`, which specifies frequency
+            for individual trials.
         export_formats (list): List of formats that exported at the end of
             the experiment. Default is None.
         max_failures (int): Try to recover a trial from its last
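A usage sketch of the new parameter with illustrative values: per-trial checkpoints stay frequent while the driver's global experiment checkpoint is throttled to cut checkpointing overhead.

    from ray import tune

    tune.run(
        "PG",
        config=dict(env="CartPole-v1"),
        checkpoint_freq=1,            # per-trial: checkpoint every iteration
        global_checkpoint_period=30)  # global: experiment state at most every 30s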
@@ -212,6 +216,7 @@
         local_checkpoint_dir=experiment.checkpoint_dir,
         remote_checkpoint_dir=experiment.remote_checkpoint_dir,
         sync_to_cloud=sync_to_cloud,
+        checkpoint_period=global_checkpoint_period,
         resume=resume,
         launch_web_server=with_server,
         server_port=server_port,