[tune] minor clean up of executor.restore(). (#20916)

Removes two unneeded parameters that were not used anymore
This commit is contained in:
xwjiang2010 2021-12-07 08:37:48 -08:00 committed by GitHub
parent 2868d1a2cf
commit 011ae389a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 19 deletions

View file

@ -460,7 +460,7 @@ class RayTrialExecutor(TrialExecutor):
return False return False
trial.set_runner(runner) trial.set_runner(runner)
self._notify_trainable_of_new_resources_if_needed(trial) self._notify_trainable_of_new_resources_if_needed(trial)
self.restore(trial, trial.checkpoint) self.restore(trial)
self.set_status(trial, Trial.RUNNING) self.set_status(trial, Trial.RUNNING)
if trial in self._staged_trials: if trial in self._staged_trials:
@ -863,23 +863,18 @@ class RayTrialExecutor(TrialExecutor):
self._running[value] = trial self._running[value] = trial
return checkpoint return checkpoint
def restore(self, trial, checkpoint=None, block=False) -> None: def restore(self, trial) -> None:
"""Restores training state from a given model checkpoint. """Restores training state from a given model checkpoint.
Args: Args:
trial (Trial): The trial to be restored. trial (Trial): The trial to be restored.
checkpoint (Checkpoint): The checkpoint to restore from. If None,
the most recent PERSISTENT checkpoint is used. Defaults to
None.
block (bool): Whether or not to block on restore before returning.
Raises: Raises:
RuntimeError: This error is raised if no runner is found. RuntimeError: This error is raised if no runner is found.
AbortTrialExecution: This error is raised if the trial is AbortTrialExecution: This error is raised if the trial is
ineligible for restoration, given the Tune input arguments. ineligible for restoration, given the Tune input arguments.
""" """
if checkpoint is None or checkpoint.value is None: checkpoint = trial.checkpoint
checkpoint = trial.checkpoint
if checkpoint.value is None: if checkpoint.value is None:
return return
if trial.runner is None: if trial.runner is None:
@ -910,11 +905,8 @@ class RayTrialExecutor(TrialExecutor):
"restoration. Pass in an `upload_dir` for remote " "restoration. Pass in an `upload_dir` for remote "
"storage-based restoration") "storage-based restoration")
if block: self._running[remote] = trial
ray.get(remote) trial.restoring_from = checkpoint
else:
self._running[remote] = trial
trial.restoring_from = checkpoint
def export_trial_if_needed(self, trial: Trial) -> Dict: def export_trial_if_needed(self, trial: Trial) -> Dict:
"""Exports model of this trial based on trial.export_formats. """Exports model of this trial based on trial.export_formats.

View file

@ -210,10 +210,7 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
pass pass
@abstractmethod @abstractmethod
def restore(self, def restore(self, trial: Trial) -> None:
trial: Trial,
checkpoint: Optional[Checkpoint] = None,
block: bool = False) -> None:
"""Restores training state from a checkpoint. """Restores training state from a checkpoint.
If checkpoint is None, try to restore from trial.checkpoint. If checkpoint is None, try to restore from trial.checkpoint.
@ -221,8 +218,6 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
Args: Args:
trial (Trial): Trial to be restored. trial (Trial): Trial to be restored.
checkpoint (Checkpoint): Checkpoint to restore from.
block (bool): Whether or not to block on restore before returning.
Returns: Returns:
False if error occurred, otherwise return True. False if error occurred, otherwise return True.