mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Fix checkpoint manager with nan checkpoints (#24349)
Fixes checkpoints not being recorded in Tune's checkpoint manager if the first checkpoint has None value. This also deflakes `test_checkpoint_manager.py::CheckpointManagerTest`.
This commit is contained in:
parent
b2b1c95aa5
commit
87eaf55d82
2 changed files with 39 additions and 29 deletions
|
@ -106,14 +106,22 @@ class CheckpointManager:
|
|||
self._checkpoint_score_attr = checkpoint_score_attr
|
||||
|
||||
self.delete = delete_fn
|
||||
self.newest_persistent_checkpoint = _TuneCheckpoint(
|
||||
_TuneCheckpoint.PERSISTENT, None
|
||||
)
|
||||
self._newest_persistent_checkpoint = None
|
||||
self._newest_memory_checkpoint = _TuneCheckpoint(_TuneCheckpoint.MEMORY, None)
|
||||
self._best_checkpoints = []
|
||||
self._membership = set()
|
||||
self._cur_order = 0
|
||||
|
||||
@property
|
||||
def newest_persistent_checkpoint(self):
|
||||
return self._newest_persistent_checkpoint or _TuneCheckpoint(
|
||||
_TuneCheckpoint.PERSISTENT, None
|
||||
)
|
||||
|
||||
@newest_persistent_checkpoint.setter
|
||||
def newest_persistent_checkpoint(self, value):
|
||||
self._newest_persistent_checkpoint = value
|
||||
|
||||
@property
|
||||
def newest_checkpoint(self):
|
||||
"""Returns the newest checkpoint (based on training iteration)."""
|
||||
|
@ -154,9 +162,9 @@ class CheckpointManager:
|
|||
self.replace_newest_memory_checkpoint(checkpoint)
|
||||
return
|
||||
|
||||
old_checkpoint = self.newest_persistent_checkpoint
|
||||
old_checkpoint = self._newest_persistent_checkpoint
|
||||
|
||||
if old_checkpoint.value == checkpoint.value:
|
||||
if old_checkpoint and old_checkpoint.value == checkpoint.value:
|
||||
# Overwrite the order of the checkpoint.
|
||||
old_checkpoint.order = checkpoint.order
|
||||
return
|
||||
|
@ -164,7 +172,11 @@ class CheckpointManager:
|
|||
self.newest_persistent_checkpoint = checkpoint
|
||||
|
||||
# Remove the old checkpoint if it isn't one of the best ones.
|
||||
if old_checkpoint.value and old_checkpoint not in self._membership:
|
||||
if (
|
||||
old_checkpoint
|
||||
and old_checkpoint.value
|
||||
and old_checkpoint not in self._membership
|
||||
):
|
||||
self.delete(old_checkpoint)
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# coding: utf-8
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
@ -89,46 +89,44 @@ class CheckpointManagerTest(unittest.TestCase):
|
|||
Tests that the best checkpoints are tracked and ordered correctly.
|
||||
"""
|
||||
keep_checkpoints_num = 4
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i))
|
||||
for i in range(16)
|
||||
for i in range(8)
|
||||
]
|
||||
random.shuffle(checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
checkpoint_manager.on_checkpoint(checkpoint)
|
||||
for permutation in itertools.permutations(checkpoints):
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
for i in range(len(best_checkpoints)):
|
||||
self.assertEqual(best_checkpoints[i].value, i + 12)
|
||||
for checkpoint in permutation:
|
||||
checkpoint_manager.on_checkpoint(checkpoint)
|
||||
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
for i in range(len(best_checkpoints)):
|
||||
self.assertEqual(best_checkpoints[i].value, i + 4)
|
||||
|
||||
def testBestCheckpointsWithNan(self):
|
||||
"""
|
||||
Tests that checkpoints with nan priority are handled correctly.
|
||||
"""
|
||||
keep_checkpoints_num = 2
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
checkpoints = [
|
||||
_TuneCheckpoint(
|
||||
_TuneCheckpoint.PERSISTENT, None, self.mock_result(float("nan"), i)
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
checkpoints += [
|
||||
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))
|
||||
]
|
||||
random.shuffle(checkpoints)
|
||||
] + [_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))]
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
checkpoint_manager.on_checkpoint(checkpoint)
|
||||
for permutation in itertools.permutations(checkpoints):
|
||||
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
|
||||
for checkpoint in permutation:
|
||||
checkpoint_manager.on_checkpoint(checkpoint)
|
||||
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
# best_checkpoints is sorted from worst to best
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
self.assertEqual(best_checkpoints[0].value, None)
|
||||
self.assertEqual(best_checkpoints[1].value, 3)
|
||||
best_checkpoints = checkpoint_manager.best_checkpoints()
|
||||
# best_checkpoints is sorted from worst to best
|
||||
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
|
||||
self.assertEqual(best_checkpoints[0].value, None)
|
||||
self.assertEqual(best_checkpoints[1].value, 3)
|
||||
|
||||
def testBestCheckpointsOnlyNan(self):
|
||||
"""
|
||||
|
|
Loading…
Add table
Reference in a new issue