mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Add leading zeros to checkpoint directory (#14152)
* [tune] Add leading zeros to checkpoint directory
* Fix exp analysis tests/support string indices
* Fix tests
* RLlib tests
This commit is contained in:
parent
8572774304
commit
7f9340bb2f
6 changed files with 13 additions and 11 deletions
|
@ -45,7 +45,7 @@ def train_ppo_model():
|
||||||
# Train for one iteration
|
# Train for one iteration
|
||||||
trainer.train()
|
trainer.train()
|
||||||
trainer.save("/tmp/rllib_checkpoint")
|
trainer.save("/tmp/rllib_checkpoint")
|
||||||
return "/tmp/rllib_checkpoint/checkpoint_1/checkpoint-1"
|
return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"
|
||||||
|
|
||||||
|
|
||||||
checkpoint_path = train_ppo_model()
|
checkpoint_path = train_ppo_model()
|
||||||
|
|
|
@ -108,14 +108,15 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
|
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
|
||||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||||
assert checkpoints_metrics[0][0] == expected_path
|
assert checkpoints_metrics[0][0] == expected_path
|
||||||
assert checkpoints_metrics[0][1] == 1
|
assert checkpoints_metrics[0][1] == 1
|
||||||
|
|
||||||
def testGetTrialCheckpointsPathsByPath(self):
|
def testGetTrialCheckpointsPathsByPath(self):
|
||||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||||
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
|
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
|
||||||
expected_path = os.path.join(logdir, "checkpoint_1/", "checkpoint")
|
expected_path = os.path.join(logdir, "checkpoint_000001/",
|
||||||
|
"checkpoint")
|
||||||
assert checkpoints_metrics[0][0] == expected_path
|
assert checkpoints_metrics[0][0] == expected_path
|
||||||
assert checkpoints_metrics[0][1] == 1
|
assert checkpoints_metrics[0][1] == 1
|
||||||
|
|
||||||
|
@ -123,7 +124,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||||
assert paths[0][0] == expected_path
|
assert paths[0][0] == expected_path
|
||||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||||
|
|
||||||
|
@ -131,7 +132,7 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||||
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
best_trial = self.ea.get_best_trial(self.metric, mode="max")
|
||||||
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
logdir = self.ea.get_best_logdir(self.metric, mode="max")
|
||||||
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
|
||||||
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
|
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
|
||||||
assert paths[0][0] == expected_path
|
assert paths[0][0] == expected_path
|
||||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||||
self.assertTrue(os.path.isdir(abs_trial_dir))
|
self.assertTrue(os.path.isdir(abs_trial_dir))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
os.path.isfile(
|
os.path.isfile(
|
||||||
os.path.join(abs_trial_dir, "checkpoint_1/checkpoint-1")))
|
os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint-1")))
|
||||||
|
|
||||||
def _restore(self, exp_name, local_dir, absolute_local_dir):
|
def _restore(self, exp_name, local_dir, absolute_local_dir):
|
||||||
trial_name, abs_trial_dir = self._get_trial_dir(
|
trial_name, abs_trial_dir = self._get_trial_dir(
|
||||||
|
@ -95,7 +95,7 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||||
|
|
||||||
checkpoint_path = os.path.join(
|
checkpoint_path = os.path.join(
|
||||||
local_dir, exp_name, trial_name,
|
local_dir, exp_name, trial_name,
|
||||||
"checkpoint_1/checkpoint-1") # Relative checkpoint path
|
"checkpoint_000001/checkpoint-1") # Relative checkpoint path
|
||||||
|
|
||||||
# The file tune would find. The absolute checkpoint path.
|
# The file tune would find. The absolute checkpoint path.
|
||||||
tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path))
|
tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path))
|
||||||
|
|
|
@ -101,14 +101,15 @@ class TrainableUtil:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_dir (str): Path to checkpoint directory.
|
checkpoint_dir (str): Path to checkpoint directory.
|
||||||
index (str): A subdirectory will be created
|
index (int|str): A subdirectory will be created
|
||||||
at the checkpoint directory named 'checkpoint_{index}'.
|
at the checkpoint directory named 'checkpoint_{index}'.
|
||||||
override (bool): Deletes checkpoint_dir before creating
|
override (bool): Deletes checkpoint_dir before creating
|
||||||
a new one.
|
a new one.
|
||||||
"""
|
"""
|
||||||
suffix = "checkpoint"
|
suffix = "checkpoint"
|
||||||
if index is not None:
|
if index is not None:
|
||||||
suffix += "_{}".format(index)
|
suffix += f"_{index:06d}" if isinstance(index,
|
||||||
|
int) else f"_{index}"
|
||||||
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
|
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
|
||||||
|
|
||||||
if override and os.path.exists(checkpoint_dir):
|
if override and os.path.exists(checkpoint_dir):
|
||||||
|
|
|
@ -10,7 +10,7 @@ Example usage for training:
|
||||||
rllib train --run DQN --env CartPole-v0
|
rllib train --run DQN --env CartPole-v0
|
||||||
|
|
||||||
Example usage for rollout:
|
Example usage for rollout:
|
||||||
rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN
|
rllib rollout /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
|
||||||
"}' --stop='{\"training_iteration\": 1}'" +
|
"}' --stop='{\"training_iteration\": 1}'" +
|
||||||
" --env={}".format(env))
|
" --env={}".format(env))
|
||||||
|
|
||||||
checkpoint_path = os.popen("ls {}/default/*/checkpoint_1/"
|
checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
|
||||||
"checkpoint-1".format(tmp_dir)).read()[:-1]
|
"checkpoint-1".format(tmp_dir)).read()[:-1]
|
||||||
if not os.path.exists(checkpoint_path):
|
if not os.path.exists(checkpoint_path):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue