mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Train] Add support for trainer.best_checkpoint
and Trainer.load_checkpoint_path
(#22306)
Closes #22226
This commit is contained in:
parent
640d92c385
commit
4cbbc81f4c
3 changed files with 73 additions and 9 deletions
|
@ -24,6 +24,15 @@ MIN = "min"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Utility function to load a checkpoint Dict from a path.

    Args:
        checkpoint_to_load (Union[str, Path]): Path to a checkpoint file
            written as a binary ``cloudpickle`` payload.

    Returns:
        The deserialized checkpoint ``Dict``.

    Raises:
        ValueError: If the path (after ``~`` expansion) does not exist.
    """
    # Expand "~" so user-relative paths behave like absolute ones.
    resolved = Path(checkpoint_to_load).expanduser()
    if not resolved.exists():
        raise ValueError(f"Checkpoint path {resolved} does not exist.")
    # Checkpoints are stored as binary cloudpickle payloads.
    with resolved.open("rb") as f:
        return cloudpickle.load(f)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
@dataclass
|
||||
class CheckpointStrategy:
|
||||
|
@ -169,13 +178,7 @@ class CheckpointManager:
|
|||
return checkpoint_to_load
|
||||
else:
|
||||
# Load checkpoint from path.
|
||||
checkpoint_path = Path(checkpoint_to_load).expanduser()
|
||||
if not checkpoint_path.exists():
|
||||
raise ValueError(
|
||||
f"Checkpoint path {checkpoint_path} " f"does not exist."
|
||||
)
|
||||
with checkpoint_path.open("rb") as f:
|
||||
return cloudpickle.load(f)
|
||||
return load_checkpoint_from_path(checkpoint_to_load)
|
||||
|
||||
def write_checkpoint(self, checkpoint: Dict):
|
||||
"""Writes checkpoint to disk."""
|
||||
|
|
|
@ -367,12 +367,14 @@ def test_checkpoint(ray_start_2_cpus):
|
|||
def train_func():
|
||||
assert train.load_checkpoint() is None
|
||||
for i in range(3):
|
||||
time.sleep(1)
|
||||
train.save_checkpoint(epoch=i)
|
||||
return 1
|
||||
|
||||
trainer = Trainer(config, num_workers=2)
|
||||
trainer.start()
|
||||
trainer.run(train_func)
|
||||
assert trainer.latest_checkpoint == trainer.best_checkpoint
|
||||
checkpoint = trainer.latest_checkpoint
|
||||
|
||||
assert checkpoint is not None
|
||||
|
@ -384,10 +386,12 @@ def test_checkpoint(ray_start_2_cpus):
|
|||
assert checkpoint["epoch"] == 2
|
||||
|
||||
for i in range(checkpoint["epoch"], 5):
|
||||
time.sleep(1)
|
||||
train.save_checkpoint(epoch=i)
|
||||
return 1
|
||||
|
||||
trainer.run(train_func_checkpoint, checkpoint=checkpoint)
|
||||
assert trainer.latest_checkpoint == trainer.best_checkpoint
|
||||
checkpoint = trainer.latest_checkpoint
|
||||
|
||||
assert checkpoint is not None
|
||||
|
@ -496,6 +500,7 @@ def test_persisted_checkpoint(ray_start_2_cpus, logdir):
|
|||
assert trainer.best_checkpoint_path.is_file()
|
||||
assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}"
|
||||
assert trainer.best_checkpoint_path.parent.name == "checkpoints"
|
||||
assert trainer.best_checkpoint == trainer.latest_checkpoint
|
||||
latest_checkpoint = trainer.latest_checkpoint
|
||||
|
||||
def validate():
|
||||
|
@ -529,6 +534,8 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
|
|||
assert trainer.latest_checkpoint_dir.is_dir()
|
||||
assert trainer.best_checkpoint_path.is_file()
|
||||
assert trainer.best_checkpoint_path.name == f"checkpoint_{1:06d}"
|
||||
assert trainer.latest_checkpoint["loss"] == 5
|
||||
assert trainer.best_checkpoint["loss"] == 3
|
||||
|
||||
checkpoint_dir = trainer.latest_checkpoint_dir
|
||||
file_names = [f.name for f in checkpoint_dir.iterdir()]
|
||||
|
@ -545,6 +552,28 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
|
|||
trainer.run(validate, checkpoint=trainer.best_checkpoint_path)
|
||||
|
||||
|
||||
def test_load_checkpoint_from_path(ray_start_2_cpus, tmpdir):
    """Round-trip check for ``Trainer.load_checkpoint_from_path``.

    Saves two checkpoints with different ``loss`` values, then verifies that
    the checkpoint loaded back from ``best_checkpoint_path`` on disk equals
    the in-memory ``best_checkpoint`` Dict.
    """
    config = TestConfig()

    # "min" order on "loss" means the loss=3 checkpoint is the "best" one.
    checkpoint_strategy = CheckpointStrategy(
        checkpoint_score_attribute="loss", checkpoint_score_order="min"
    )

    def train_func_checkpoint():
        # Two checkpoints in one run: best (3) is not the latest (7).
        train.save_checkpoint(loss=3)
        train.save_checkpoint(loss=7)

    trainer = Trainer(config, num_workers=2, logdir=tmpdir)
    trainer.start()
    trainer.run(train_func_checkpoint, checkpoint_strategy=checkpoint_strategy)

    assert trainer.best_checkpoint["loss"] == 3
    # Loading the persisted best checkpoint must reproduce the in-memory Dict.
    assert (
        Trainer.load_checkpoint_from_path(trainer.best_checkpoint_path)
        == trainer.best_checkpoint
    )
||||
|
||||
|
||||
def test_persisted_checkpoint_strategy_failure(ray_start_2_cpus):
|
||||
logdir = "/tmp/test/trainer/test_persisted_checkpoint_strategy_failure"
|
||||
config = TestConfig()
|
||||
|
|
|
@ -21,6 +21,7 @@ from ray.train.checkpoint import (
|
|||
CheckpointStrategy,
|
||||
TuneCheckpointManager,
|
||||
CheckpointManager,
|
||||
load_checkpoint_from_path,
|
||||
)
|
||||
from ray.train.constants import (
|
||||
TUNE_INSTALLED,
|
||||
|
@ -466,8 +467,8 @@ class Trainer:
|
|||
Default behavior is to return the most recent checkpoint.
|
||||
|
||||
Returns ``None`` if ``run()`` has not been called or if
|
||||
``train.checkpoint()`` has not been called from ``train_func`` within
|
||||
the most recent call to ``run``.
|
||||
``train.save_checkpoint()`` has not been called from ``train_func``
|
||||
within the most recent call to ``run``.
|
||||
"""
|
||||
return self.checkpoint_manager.best_checkpoint_path
|
||||
|
||||
|
@ -482,6 +483,37 @@ class Trainer:
|
|||
"""
|
||||
return self.checkpoint_manager.latest_checkpoint
|
||||
|
||||
@property
def best_checkpoint(self) -> Optional[Dict]:
    """Best saved checkpoint from the latest run.

    "Best" is defined by the input ``CheckpointStrategy``.
    Default behavior is to return the most recent checkpoint.

    Returns ``None`` if ``run()`` has not been called or if
    ``train.save_checkpoint()`` has not been called from ``train_func``
    within the most recent call to ``run``.
    """
    path = self.best_checkpoint_path
    # No persisted best checkpoint means nothing was saved this run.
    if path is None:
        return None
    return load_checkpoint_from_path(path)
|
||||
|
||||
@staticmethod
def load_checkpoint_from_path(checkpoint_file_path: Union[str, Path]) -> Dict:
    """Convenience method to load a checkpoint from path.

    An error will be raised if the provided path does not exist.

    Args:
        checkpoint_file_path (Union[str, Path]): The path to the checkpoint
            to load. If the checkpoint saved in this path has not been
            created by Ray Train, there is no guarantee that it can be
            loaded in successfully.

    Returns:
        The loaded checkpoint ``Dict``.
    """
    # Delegates to the module-level helper shared with CheckpointManager.
    return load_checkpoint_from_path(checkpoint_file_path)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shuts down the training execution service."""
|
||||
ray.get(self._backend_executor_actor.shutdown.remote())
|
||||
|
|
Loading…
Add table
Reference in a new issue