[tune] Add leading zeros to checkpoint directory (#14152)

* [tune] Add leading zeros to checkpoint directory

* Fix experiment analysis tests and support string indices

* Fix tests

* Fix RLlib tests
This commit is contained in:
Kai Fricke 2021-03-01 12:12:19 +01:00 committed by GitHub
parent 8572774304
commit 7f9340bb2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 13 additions and 11 deletions

View file

@ -45,7 +45,7 @@ def train_ppo_model():
# Train for one iteration # Train for one iteration
trainer.train() trainer.train()
trainer.save("/tmp/rllib_checkpoint") trainer.save("/tmp/rllib_checkpoint")
return "/tmp/rllib_checkpoint/checkpoint_1/checkpoint-1" return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"
checkpoint_path = train_ppo_model() checkpoint_path = train_ppo_model()

View file

@ -108,14 +108,15 @@ class ExperimentAnalysisSuite(unittest.TestCase):
best_trial = self.ea.get_best_trial(self.metric, mode="max") best_trial = self.ea.get_best_trial(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial) checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
logdir = self.ea.get_best_logdir(self.metric, mode="max") logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint") expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert checkpoints_metrics[0][0] == expected_path assert checkpoints_metrics[0][0] == expected_path
assert checkpoints_metrics[0][1] == 1 assert checkpoints_metrics[0][1] == 1
def testGetTrialCheckpointsPathsByPath(self): def testGetTrialCheckpointsPathsByPath(self):
logdir = self.ea.get_best_logdir(self.metric, mode="max") logdir = self.ea.get_best_logdir(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir) checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
expected_path = os.path.join(logdir, "checkpoint_1/", "checkpoint") expected_path = os.path.join(logdir, "checkpoint_000001/",
"checkpoint")
assert checkpoints_metrics[0][0] == expected_path assert checkpoints_metrics[0][0] == expected_path
assert checkpoints_metrics[0][1] == 1 assert checkpoints_metrics[0][1] == 1
@ -123,7 +124,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
best_trial = self.ea.get_best_trial(self.metric, mode="max") best_trial = self.ea.get_best_trial(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric) paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
logdir = self.ea.get_best_logdir(self.metric, mode="max") logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint") expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path assert paths[0][0] == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"] assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
@ -131,7 +132,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
best_trial = self.ea.get_best_trial(self.metric, mode="max") best_trial = self.ea.get_best_trial(self.metric, mode="max")
logdir = self.ea.get_best_logdir(self.metric, mode="max") logdir = self.ea.get_best_logdir(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric) paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint") expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path assert paths[0][0] == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"] assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]

View file

@ -87,7 +87,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
self.assertTrue(os.path.isdir(abs_trial_dir)) self.assertTrue(os.path.isdir(abs_trial_dir))
self.assertTrue( self.assertTrue(
os.path.isfile( os.path.isfile(
os.path.join(abs_trial_dir, "checkpoint_1/checkpoint-1"))) os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint-1")))
def _restore(self, exp_name, local_dir, absolute_local_dir): def _restore(self, exp_name, local_dir, absolute_local_dir):
trial_name, abs_trial_dir = self._get_trial_dir( trial_name, abs_trial_dir = self._get_trial_dir(
@ -95,7 +95,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
checkpoint_path = os.path.join( checkpoint_path = os.path.join(
local_dir, exp_name, trial_name, local_dir, exp_name, trial_name,
"checkpoint_1/checkpoint-1") # Relative checkpoint path "checkpoint_000001/checkpoint-1") # Relative checkpoint path
# The file tune would find. The absolute checkpoint path. # The file tune would find. The absolute checkpoint path.
tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path)) tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path))

View file

@ -101,14 +101,15 @@ class TrainableUtil:
Args: Args:
checkpoint_dir (str): Path to checkpoint directory. checkpoint_dir (str): Path to checkpoint directory.
index (str): A subdirectory will be created index (int|str): A subdirectory will be created
at the checkpoint directory named 'checkpoint_{index}'. at the checkpoint directory named 'checkpoint_{index}'.
override (bool): Deletes checkpoint_dir before creating override (bool): Deletes checkpoint_dir before creating
a new one. a new one.
""" """
suffix = "checkpoint" suffix = "checkpoint"
if index is not None: if index is not None:
suffix += "_{}".format(index) suffix += f"_{index:06d}" if isinstance(index,
int) else f"_{index}"
checkpoint_dir = os.path.join(checkpoint_dir, suffix) checkpoint_dir = os.path.join(checkpoint_dir, suffix)
if override and os.path.exists(checkpoint_dir): if override and os.path.exists(checkpoint_dir):

View file

@ -10,7 +10,7 @@ Example usage for training:
rllib train --run DQN --env CartPole-v0 rllib train --run DQN --env CartPole-v0
Example usage for rollout: Example usage for rollout:
rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN rllib rollout /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
""" """

View file

@ -40,7 +40,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
"}' --stop='{\"training_iteration\": 1}'" + "}' --stop='{\"training_iteration\": 1}'" +
" --env={}".format(env)) " --env={}".format(env))
checkpoint_path = os.popen("ls {}/default/*/checkpoint_1/" checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
"checkpoint-1".format(tmp_dir)).read()[:-1] "checkpoint-1".format(tmp_dir)).read()[:-1]
if not os.path.exists(checkpoint_path): if not os.path.exists(checkpoint_path):
sys.exit(1) sys.exit(1)