[tune] Fix checkpoint manager with nan checkpoints (#24349)

Fixes checkpoints not being recorded in Tune's checkpoint manager if the first checkpoint has None value. This also deflakes `test_checkpoint_manager.py::CheckpointManagerTest`.
This commit is contained in:
Antoni Baum 2022-04-30 10:23:57 +02:00 committed by GitHub
parent b2b1c95aa5
commit 87eaf55d82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 29 deletions

View file

@ -106,14 +106,22 @@ class CheckpointManager:
self._checkpoint_score_attr = checkpoint_score_attr
self.delete = delete_fn
self.newest_persistent_checkpoint = _TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, None
)
self._newest_persistent_checkpoint = None
self._newest_memory_checkpoint = _TuneCheckpoint(_TuneCheckpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
self._cur_order = 0
@property
def newest_persistent_checkpoint(self):
return self._newest_persistent_checkpoint or _TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, None
)
@newest_persistent_checkpoint.setter
def newest_persistent_checkpoint(self, value):
self._newest_persistent_checkpoint = value
@property
def newest_checkpoint(self):
"""Returns the newest checkpoint (based on training iteration)."""
@ -154,9 +162,9 @@ class CheckpointManager:
self.replace_newest_memory_checkpoint(checkpoint)
return
old_checkpoint = self.newest_persistent_checkpoint
old_checkpoint = self._newest_persistent_checkpoint
if old_checkpoint.value == checkpoint.value:
if old_checkpoint and old_checkpoint.value == checkpoint.value:
# Overwrite the order of the checkpoint.
old_checkpoint.order = checkpoint.order
return
@ -164,7 +172,11 @@ class CheckpointManager:
self.newest_persistent_checkpoint = checkpoint
# Remove the old checkpoint if it isn't one of the best ones.
if old_checkpoint.value and old_checkpoint not in self._membership:
if (
old_checkpoint
and old_checkpoint.value
and old_checkpoint not in self._membership
):
self.delete(old_checkpoint)
try:

View file

@ -1,6 +1,6 @@
# coding: utf-8
import itertools
import os
import random
import sys
import tempfile
import unittest
@ -89,46 +89,44 @@ class CheckpointManagerTest(unittest.TestCase):
Tests that the best checkpoints are tracked and ordered correctly.
"""
keep_checkpoints_num = 4
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i))
for i in range(16)
for i in range(8)
]
random.shuffle(checkpoints)
for checkpoint in checkpoints:
checkpoint_manager.on_checkpoint(checkpoint)
for permutation in itertools.permutations(checkpoints):
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
best_checkpoints = checkpoint_manager.best_checkpoints()
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
for i in range(len(best_checkpoints)):
self.assertEqual(best_checkpoints[i].value, i + 12)
for checkpoint in permutation:
checkpoint_manager.on_checkpoint(checkpoint)
best_checkpoints = checkpoint_manager.best_checkpoints()
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
for i in range(len(best_checkpoints)):
self.assertEqual(best_checkpoints[i].value, i + 4)
def testBestCheckpointsWithNan(self):
"""
Tests that checkpoints with nan priority are handled correctly.
"""
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, None, self.mock_result(float("nan"), i)
)
for i in range(2)
]
checkpoints += [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))
]
random.shuffle(checkpoints)
] + [_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))]
for checkpoint in checkpoints:
checkpoint_manager.on_checkpoint(checkpoint)
for permutation in itertools.permutations(checkpoints):
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
for checkpoint in permutation:
checkpoint_manager.on_checkpoint(checkpoint)
best_checkpoints = checkpoint_manager.best_checkpoints()
# best_checkpoints is sorted from worst to best
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
self.assertEqual(best_checkpoints[0].value, None)
self.assertEqual(best_checkpoints[1].value, 3)
best_checkpoints = checkpoint_manager.best_checkpoints()
# best_checkpoints is sorted from worst to best
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
self.assertEqual(best_checkpoints[0].value, None)
self.assertEqual(best_checkpoints[1].value, 3)
def testBestCheckpointsOnlyNan(self):
"""