[tune] minor clean up of executor.restore(). (#20916)

Removes two unneeded parameters (checkpoint and block) that were no longer used.
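Below is a minimal sketch of the resulting call-site change; the executor variable is a hypothetical placeholder, and the behavior described follows the diff in this commit:

    # Before: the call site passed the checkpoint explicitly (always
    # trial.checkpoint in the diff below), and restore() accepted an
    # unused block flag.
    executor.restore(trial, trial.checkpoint)

    # After: restore() reads trial.checkpoint itself and always tracks the
    # non-blocking remote restore in self._running.
    executor.restore(trial)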
xwjiang2010 2021-12-07 08:37:48 -08:00 committed by GitHub
parent 2868d1a2cf
commit 011ae389a6
2 changed files with 6 additions and 19 deletions


@@ -460,7 +460,7 @@ class RayTrialExecutor(TrialExecutor):
             return False
         trial.set_runner(runner)
         self._notify_trainable_of_new_resources_if_needed(trial)
-        self.restore(trial, trial.checkpoint)
+        self.restore(trial)
         self.set_status(trial, Trial.RUNNING)
 
         if trial in self._staged_trials:
@@ -863,23 +863,18 @@ class RayTrialExecutor(TrialExecutor):
             self._running[value] = trial
         return checkpoint
 
-    def restore(self, trial, checkpoint=None, block=False) -> None:
+    def restore(self, trial) -> None:
         """Restores training state from a given model checkpoint.
 
         Args:
             trial (Trial): The trial to be restored.
-            checkpoint (Checkpoint): The checkpoint to restore from. If None,
-                the most recent PERSISTENT checkpoint is used. Defaults to
-                None.
-            block (bool): Whether or not to block on restore before returning.
 
         Raises:
             RuntimeError: This error is raised if no runner is found.
             AbortTrialExecution: This error is raised if the trial is
                 ineligible for restoration, given the Tune input arguments.
         """
-        if checkpoint is None or checkpoint.value is None:
-            checkpoint = trial.checkpoint
+        checkpoint = trial.checkpoint
         if checkpoint.value is None:
             return
         if trial.runner is None:
@@ -910,11 +905,8 @@
                     "restoration. Pass in an `upload_dir` for remote "
                     "storage-based restoration")
 
-            if block:
-                ray.get(remote)
-            else:
-                self._running[remote] = trial
-                trial.restoring_from = checkpoint
+            self._running[remote] = trial
+            trial.restoring_from = checkpoint
 
     def export_trial_if_needed(self, trial: Trial) -> Dict:
         """Exports model of this trial based on trial.export_formats.


@@ -210,10 +210,7 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         pass
 
     @abstractmethod
-    def restore(self,
-                trial: Trial,
-                checkpoint: Optional[Checkpoint] = None,
-                block: bool = False) -> None:
+    def restore(self, trial: Trial) -> None:
         """Restores training state from a checkpoint.
 
         If checkpoint is None, try to restore from trial.checkpoint.
@@ -221,8 +218,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
         Args:
             trial (Trial): Trial to be restored.
-            checkpoint (Checkpoint): Checkpoint to restore from.
-            block (bool): Whether or not to block on restore before returning.
 
         Returns:
             False if error occurred, otherwise return True.