[Train] Add support for trainer.best_checkpoint and Trainer.load_checkpoint_path (#22306)

Closes #22226
This commit is contained in:
Amog Kamsetty 2022-02-11 22:29:37 -08:00 committed by GitHub
parent 640d92c385
commit 4cbbc81f4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 9 deletions

View file

@@ -24,6 +24,15 @@ MIN = "min"
logger = logging.getLogger(__name__)
def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Load and return the checkpoint ``Dict`` stored at the given path.

    Args:
        checkpoint_to_load (Union[str, Path]): Filesystem path to a
            checkpoint file written with ``cloudpickle``. ``~`` is expanded.

    Raises:
        ValueError: If the resolved path does not exist.
    """
    resolved_path = Path(checkpoint_to_load).expanduser()
    if resolved_path.exists():
        # Checkpoints are serialized with cloudpickle, so deserialize the
        # same way.
        with resolved_path.open("rb") as checkpoint_file:
            return cloudpickle.load(checkpoint_file)
    raise ValueError(f"Checkpoint path {resolved_path} does not exist.")
@PublicAPI(stability="beta")
@dataclass
class CheckpointStrategy:
@@ -169,13 +178,7 @@ class CheckpointManager:
return checkpoint_to_load
else:
# Load checkpoint from path.
checkpoint_path = Path(checkpoint_to_load).expanduser()
if not checkpoint_path.exists():
raise ValueError(
f"Checkpoint path {checkpoint_path} " f"does not exist."
)
with checkpoint_path.open("rb") as f:
return cloudpickle.load(f)
return load_checkpoint_from_path(checkpoint_to_load)
def write_checkpoint(self, checkpoint: Dict):
"""Writes checkpoint to disk."""

View file

@@ -367,12 +367,14 @@ def test_checkpoint(ray_start_2_cpus):
def train_func():
    # A fresh run must start with no checkpoint available to resume from.
    assert train.load_checkpoint() is None
    # Persist one checkpoint per epoch for three epochs.
    for epoch in range(3):
        time.sleep(1)
        train.save_checkpoint(epoch=epoch)
    return 1
trainer = Trainer(config, num_workers=2)
trainer.start()
trainer.run(train_func)
assert trainer.latest_checkpoint == trainer.best_checkpoint
checkpoint = trainer.latest_checkpoint
assert checkpoint is not None
@@ -384,10 +386,12 @@ def test_checkpoint(ray_start_2_cpus):
assert checkpoint["epoch"] == 2
for i in range(checkpoint["epoch"], 5):
time.sleep(1)
train.save_checkpoint(epoch=i)
return 1
trainer.run(train_func_checkpoint, checkpoint=checkpoint)
assert trainer.latest_checkpoint == trainer.best_checkpoint
checkpoint = trainer.latest_checkpoint
assert checkpoint is not None
@@ -496,6 +500,7 @@ def test_persisted_checkpoint(ray_start_2_cpus, logdir):
assert trainer.best_checkpoint_path.is_file()
assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}"
assert trainer.best_checkpoint_path.parent.name == "checkpoints"
assert trainer.best_checkpoint == trainer.latest_checkpoint
latest_checkpoint = trainer.latest_checkpoint
def validate():
@@ -529,6 +534,8 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
assert trainer.latest_checkpoint_dir.is_dir()
assert trainer.best_checkpoint_path.is_file()
assert trainer.best_checkpoint_path.name == f"checkpoint_{1:06d}"
assert trainer.latest_checkpoint["loss"] == 5
assert trainer.best_checkpoint["loss"] == 3
checkpoint_dir = trainer.latest_checkpoint_dir
file_names = [f.name for f in checkpoint_dir.iterdir()]
@@ -545,6 +552,28 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
trainer.run(validate, checkpoint=trainer.best_checkpoint_path)
def test_load_checkpoint_from_path(ray_start_2_cpus, tmpdir):
    """The static loader should return the same dict as ``best_checkpoint``."""
    config = TestConfig()
    # Rank checkpoints by ascending loss so the loss=3 checkpoint wins.
    checkpoint_strategy = CheckpointStrategy(
        checkpoint_score_attribute="loss", checkpoint_score_order="min"
    )

    def train_func_checkpoint():
        train.save_checkpoint(loss=3)
        train.save_checkpoint(loss=7)

    trainer = Trainer(config, num_workers=2, logdir=tmpdir)
    trainer.start()
    trainer.run(train_func_checkpoint, checkpoint_strategy=checkpoint_strategy)

    best = trainer.best_checkpoint
    assert best["loss"] == 3
    loaded = Trainer.load_checkpoint_from_path(trainer.best_checkpoint_path)
    assert loaded == best
def test_persisted_checkpoint_strategy_failure(ray_start_2_cpus):
logdir = "/tmp/test/trainer/test_persisted_checkpoint_strategy_failure"
config = TestConfig()

View file

@@ -21,6 +21,7 @@ from ray.train.checkpoint import (
CheckpointStrategy,
TuneCheckpointManager,
CheckpointManager,
load_checkpoint_from_path,
)
from ray.train.constants import (
TUNE_INSTALLED,
@@ -466,8 +467,8 @@ class Trainer:
Default behavior is to return the most recent checkpoint.
Returns ``None`` if ``run()`` has not been called or if
``train.checkpoint()`` has not been called from ``train_func`` within
the most recent call to ``run``.
``train.save_checkpoint()`` has not been called from ``train_func``
within the most recent call to ``run``.
"""
return self.checkpoint_manager.best_checkpoint_path
@@ -482,6 +483,37 @@ class Trainer:
"""
return self.checkpoint_manager.latest_checkpoint
@property
def best_checkpoint(self) -> Optional[Dict]:
    """Best saved checkpoint from the latest run.

    "Best" is defined by the input ``CheckpointStrategy``; the default
    strategy keeps the most recent checkpoint.

    Returns ``None`` if ``run()`` has not been called or if
    ``train.save_checkpoint()`` has not been called from ``train_func``
    within the most recent call to ``run``.
    """
    path = self.best_checkpoint_path
    # No path means no checkpoint was ever persisted for the latest run.
    return None if path is None else load_checkpoint_from_path(path)
@staticmethod
def load_checkpoint_from_path(checkpoint_file_path: Union[str, Path]) -> Dict:
    """Convenience method to load a checkpoint from path.

    An error will be raised if the provided path does not exist.

    Args:
        checkpoint_file_path (Union[str, Path]): The path to the checkpoint
            to load. If the checkpoint saved in this path has not been
            created by Ray Train, there is no guarantee that it can be
            loaded in successfully.

    Returns:
        The deserialized checkpoint ``Dict``.
    """
    # Delegate to the shared module-level utility used by the
    # checkpoint manager.
    loaded_checkpoint = load_checkpoint_from_path(checkpoint_file_path)
    return loaded_checkpoint
def shutdown(self):
"""Shuts down the training execution service."""
ray.get(self._backend_executor_actor.shutdown.remote())