mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Add leading zeros to checkpoint directory (#14152)
* [tune] Add leading zeros to checkpoint directory
* Fix exp analysis tests/support string indices
* Fix tests
* RLlib tests
This commit is contained in:
parent
8572774304
commit
7f9340bb2f
6 changed files with 13 additions and 11 deletions
|
@ -45,7 +45,7 @@ def train_ppo_model():
|
|||
# Train for one iteration
|
||||
trainer.train()
|
||||
trainer.save("/tmp/rllib_checkpoint")
|
||||
return "/tmp/rllib_checkpoint/checkpoint_1/checkpoint-1"
|
||||
return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"
|
||||
|
||||
|
||||
checkpoint_path = train_ppo_model()
|
||||
|
|
|
@ -108,14 +108,15 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
|||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
|
||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||
assert checkpoints_metrics[0][0] == expected_path
|
||||
assert checkpoints_metrics[0][1] == 1
|
||||
|
||||
def testGetTrialCheckpointsPathsByPath(self):
|
||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1/", "checkpoint")
|
||||
expected_path = os.path.join(logdir, "checkpoint_000001/",
|
||||
"checkpoint")
|
||||
assert checkpoints_metrics[0][0] == expected_path
|
||||
assert checkpoints_metrics[0][1] == 1
|
||||
|
||||
|
@ -123,7 +124,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
|||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||
assert paths[0][0] == expected_path
|
||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||
|
||||
|
@ -131,7 +132,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
|||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
||||
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||
assert paths[0][0] == expected_path
|
||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
|||
self.assertTrue(os.path.isdir(abs_trial_dir))
|
||||
self.assertTrue(
|
||||
os.path.isfile(
|
||||
os.path.join(abs_trial_dir, "checkpoint_1/checkpoint-1")))
|
||||
os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint-1")))
|
||||
|
||||
def _restore(self, exp_name, local_dir, absolute_local_dir):
|
||||
trial_name, abs_trial_dir = self._get_trial_dir(
|
||||
|
@ -95,7 +95,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
|||
|
||||
checkpoint_path = os.path.join(
|
||||
local_dir, exp_name, trial_name,
|
||||
"checkpoint_1/checkpoint-1") # Relative checkpoint path
|
||||
"checkpoint_000001/checkpoint-1") # Relative checkpoint path
|
||||
|
||||
# The file tune would find. The absolute checkpoint path.
|
||||
tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path))
|
||||
|
|
|
@ -101,14 +101,15 @@ class TrainableUtil:
|
|||
|
||||
Args:
|
||||
checkpoint_dir (str): Path to checkpoint directory.
|
||||
index (str): A subdirectory will be created
|
||||
index (int|str): A subdirectory will be created
|
||||
at the checkpoint directory named 'checkpoint_{index}'.
|
||||
override (bool): Deletes checkpoint_dir before creating
|
||||
a new one.
|
||||
"""
|
||||
suffix = "checkpoint"
|
||||
if index is not None:
|
||||
suffix += "_{}".format(index)
|
||||
suffix += f"_{index:06d}" if isinstance(index,
|
||||
int) else f"_{index}"
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
|
||||
|
||||
if override and os.path.exists(checkpoint_dir):
|
||||
|
|
|
@ -10,7 +10,7 @@ Example usage for training:
|
|||
rllib train --run DQN --env CartPole-v0
|
||||
|
||||
Example usage for rollout:
|
||||
rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN
|
||||
rllib rollout /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
|
|||
"}' --stop='{\"training_iteration\": 1}'" +
|
||||
" --env={}".format(env))
|
||||
|
||||
checkpoint_path = os.popen("ls {}/default/*/checkpoint_1/"
|
||||
checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
|
||||
"checkpoint-1".format(tmp_dir)).read()[:-1]
|
||||
if not os.path.exists(checkpoint_path):
|
||||
sys.exit(1)
|
||||
|
|
Loading…
Add table
Reference in a new issue