mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune][minor] Reduce checkpointing frequency (#4859)
This commit is contained in:
parent
4b56a5eb27
commit
c3e9d94b18
6 changed files with 45 additions and 41 deletions
|
@ -272,7 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
|||
cluster.wait_for_nodes()
|
||||
|
||||
dirpath = str(tmpdir)
|
||||
runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath)
|
||||
runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
|
@ -359,15 +359,16 @@ from ray import tune
|
|||
|
||||
ray.init(redis_address="{redis_address}")
|
||||
|
||||
kwargs = dict(
|
||||
run="PG",
|
||||
env="CartPole-v1",
|
||||
|
||||
tune.run(
|
||||
"PG",
|
||||
name="experiment",
|
||||
config=dict(env="CartPole-v1"),
|
||||
stop=dict(training_iteration=10),
|
||||
local_dir="{checkpoint_dir}",
|
||||
global_checkpoint_period=0,
|
||||
checkpoint_freq=1,
|
||||
max_failures=1)
|
||||
|
||||
tune.run_experiments(
|
||||
max_failures=1,
|
||||
dict(experiment=kwargs),
|
||||
raise_on_failed_trial=False)
|
||||
""".format(
|
||||
|
@ -449,15 +450,14 @@ ray.init(redis_address="{redis_address}")
|
|||
|
||||
{fail_class_code}
|
||||
|
||||
kwargs = dict(
|
||||
run={fail_class},
|
||||
tune.run(
|
||||
{fail_class},
|
||||
name="experiment",
|
||||
stop=dict(training_iteration=5),
|
||||
local_dir="{checkpoint_dir}",
|
||||
checkpoint_freq=1,
|
||||
max_failures=1)
|
||||
|
||||
tune.run_experiments(
|
||||
dict(experiment=kwargs),
|
||||
global_checkpoint_period=0,
|
||||
max_failures=1,
|
||||
raise_on_failed_trial=False)
|
||||
""".format(
|
||||
redis_address=cluster.redis_address,
|
||||
|
|
|
@ -67,16 +67,13 @@ def test_ls(start_ray, tmpdir):
|
|||
experiment_name = "test_ls"
|
||||
experiment_path = os.path.join(str(tmpdir), experiment_name)
|
||||
num_samples = 3
|
||||
tune.run_experiments({
|
||||
experiment_name: {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"num_samples": num_samples,
|
||||
"local_dir": str(tmpdir)
|
||||
}
|
||||
})
|
||||
tune.run(
|
||||
"__fake",
|
||||
name=experiment_name,
|
||||
stop={"training_iteration": 1},
|
||||
num_samples=num_samples,
|
||||
local_dir=str(tmpdir),
|
||||
global_checkpoint_period=0)
|
||||
|
||||
columns = ["status", "episode_reward_mean", "training_iteration"]
|
||||
limit = 2
|
||||
|
@ -104,16 +101,13 @@ def test_lsx(start_ray, tmpdir):
|
|||
num_experiments = 3
|
||||
for i in range(num_experiments):
|
||||
experiment_name = "test_lsx{}".format(i)
|
||||
tune.run_experiments({
|
||||
experiment_name: {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"num_samples": 1,
|
||||
"local_dir": project_path
|
||||
}
|
||||
})
|
||||
tune.run(
|
||||
"__fake",
|
||||
name=experiment_name,
|
||||
stop={"training_iteration": 1},
|
||||
num_samples=1,
|
||||
local_dir=project_path,
|
||||
global_checkpoint_period=0)
|
||||
|
||||
limit = 2
|
||||
with Capturing() as output:
|
||||
|
|
|
@ -32,6 +32,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
|||
def run_test_exp(self):
|
||||
self.ea = run(
|
||||
MyTrainableClass,
|
||||
global_checkpoint_period=0,
|
||||
name=self.test_name,
|
||||
local_dir=self.test_dir,
|
||||
return_trials=False,
|
||||
|
|
|
@ -2086,7 +2086,7 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
trials = [
|
||||
Trial(
|
||||
"__fake",
|
||||
|
@ -2145,8 +2145,7 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
"__fake",
|
||||
|
@ -2200,7 +2199,7 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
},
|
||||
checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
|
@ -2221,7 +2220,7 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
ray.init()
|
||||
trial = Trial("__fake", checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir)
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
|
|
|
@ -111,6 +111,7 @@ class TrialRunner(object):
|
|||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True,
|
||||
checkpoint_period=10,
|
||||
trial_executor=None):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
|
@ -174,6 +175,8 @@ class TrialRunner(object):
|
|||
logger.info("Starting a new experiment.")
|
||||
|
||||
self._start_time = time.time()
|
||||
self._last_checkpoint_time = -float("inf")
|
||||
self._checkpoint_period = checkpoint_period
|
||||
self._session_str = datetime.fromtimestamp(
|
||||
self._start_time).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
@ -235,18 +238,20 @@ class TrialRunner(object):
|
|||
"""Saves execution state to `self._local_checkpoint_dir`.
|
||||
|
||||
Overwrites the current session checkpoint, which starts when self
|
||||
is instantiated.
|
||||
is instantiated. Saves are throttled to at most one per self._checkpoint_period seconds.
|
||||
"""
|
||||
if not self._local_checkpoint_dir:
|
||||
return
|
||||
|
||||
if time.time() - self._last_checkpoint_time < self._checkpoint_period:
|
||||
return
|
||||
self._last_checkpoint_time = time.time()
|
||||
runner_state = {
|
||||
"checkpoints": list(
|
||||
self.trial_executor.get_checkpoints().values()),
|
||||
"runner_data": self.__getstate__(),
|
||||
"stats": {
|
||||
"start_time": self._start_time,
|
||||
"timestamp": time.time()
|
||||
"timestamp": self._last_checkpoint_time
|
||||
}
|
||||
}
|
||||
tmp_file_name = os.path.join(self._local_checkpoint_dir,
|
||||
|
|
|
@ -49,6 +49,7 @@ def run(run_or_experiment,
|
|||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
global_checkpoint_period=10,
|
||||
export_formats=None,
|
||||
max_failures=3,
|
||||
restore=None,
|
||||
|
@ -113,6 +114,9 @@ def run(run_or_experiment,
|
|||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
global_checkpoint_period (int): Seconds between global checkpointing.
|
||||
This does not affect `checkpoint_freq`, which specifies frequency
|
||||
for individual trials.
|
||||
export_formats (list): List of formats to be exported at the end of
|
||||
the experiment. Default is None.
|
||||
max_failures (int): Try to recover a trial from its last
|
||||
|
@ -212,6 +216,7 @@ def run(run_or_experiment,
|
|||
local_checkpoint_dir=experiment.checkpoint_dir,
|
||||
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
|
||||
sync_to_cloud=sync_to_cloud,
|
||||
checkpoint_period=global_checkpoint_period,
|
||||
resume=resume,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
|
|
Loading…
Add table
Reference in a new issue