[tune] clean up start_trial API. (#20796)

This commit is contained in:
xwjiang2010 2021-12-01 08:46:22 -08:00 committed by GitHub
parent 9e38f6f613
commit 8c0bf41b17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 24 deletions

View file

@ -456,15 +456,11 @@ class RayTrialExecutor(TrialExecutor):
trial_item = self._find_item(self._running, trial)
assert len(trial_item) < 2, trial_item
def _start_trial(self, trial, checkpoint=None, train=True) -> bool:
def _start_trial(self, trial) -> bool:
"""Starts trial and restores last result if trial was paused.
Args:
trial (Trial): The trial to start.
checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
If None, and no trial checkpoint exists, the trial is started
from the beginning.
train (bool): Whether or not to start training.
Returns:
True if trial was started successfully, False otherwise.
@ -477,13 +473,13 @@ class RayTrialExecutor(TrialExecutor):
return False
trial.set_runner(runner)
self._notify_trainable_of_new_resources_if_needed(trial)
self.restore(trial, checkpoint)
self.restore(trial, trial.checkpoint)
self.set_status(trial, Trial.RUNNING)
if trial in self._staged_trials:
self._staged_trials.remove(trial)
if train and not trial.is_restoring:
if not trial.is_restoring:
self._train(trial)
return True
@ -561,26 +557,20 @@ class RayTrialExecutor(TrialExecutor):
finally:
trial.set_runner(None)
def start_trial(self,
trial: Trial,
checkpoint: Optional[Checkpoint] = None,
train: bool = True) -> bool:
def start_trial(self, trial: Trial) -> bool:
"""Starts the trial.
Will not return resources if trial repeatedly fails on start.
Args:
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
train (bool): Whether or not to start training.
Returns:
True if the remote runner has been started. False if trial was
not started (e.g. because of lacking resources/pending PG).
"""
try:
return self._start_trial(trial, checkpoint, train=train)
return self._start_trial(trial)
except AbortTrialExecution:
logger.exception("Trial %s: Error starting runner, aborting!",
trial)

View file

@ -177,9 +177,9 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.fetch_result(trial)
checkpoint = self.trial_executor.pause_trial(trial)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.start_trial(trial, checkpoint)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)

View file

@ -169,17 +169,11 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
pass
@abstractmethod
def start_trial(self,
trial: Trial,
checkpoint: Optional[Checkpoint] = None,
train: bool = True) -> bool:
def start_trial(self, trial: Trial) -> bool:
"""Starts the trial, restoring from the trial's own checkpoint if one exists.
Args:
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
train (bool): Whether or not to start training.
Returns:
True if trial started successfully, False otherwise.