diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py
index 866b669ae..1f2fbfa53 100644
--- a/python/ray/tune/integration/pytorch_lightning.py
+++ b/python/ray/tune/integration/pytorch_lightning.py
@@ -217,7 +217,8 @@ class _TuneCheckpointCallback(TuneCallback):
     def _handle(self, trainer: Trainer, pl_module: LightningModule):
         if trainer.running_sanity_check:
             return
-        with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
+        step = f"epoch={trainer.current_epoch}-step={trainer.global_step}"
+        with tune.checkpoint_dir(step=step) as checkpoint_dir:
             trainer.save_checkpoint(
                 os.path.join(checkpoint_dir, self._filename))
 