Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[tune] clean up start_trial API. (#20796)
This commit is contained in:
parent: 9e38f6f613
commit: 8c0bf41b17
3 changed files with 8 additions and 24 deletions
|
@ -456,15 +456,11 @@ class RayTrialExecutor(TrialExecutor):
|
|||
trial_item = self._find_item(self._running, trial)
|
||||
assert len(trial_item) < 2, trial_item
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None, train=True) -> bool:
|
||||
def _start_trial(self, trial) -> bool:
|
||||
"""Starts trial and restores last result if trial was paused.
|
||||
|
||||
Args:
|
||||
trial (Trial): The trial to start.
|
||||
checkpoint (Optional[Checkpoint]): The checkpoint to restore from.
|
||||
If None, and no trial checkpoint exists, the trial is started
|
||||
from the beginning.
|
||||
train (bool): Whether or not to start training.
|
||||
|
||||
Returns:
|
||||
True if trial was started successfully, False otherwise.
|
||||
|
@ -477,13 +473,13 @@ class RayTrialExecutor(TrialExecutor):
|
|||
return False
|
||||
trial.set_runner(runner)
|
||||
self._notify_trainable_of_new_resources_if_needed(trial)
|
||||
self.restore(trial, checkpoint)
|
||||
self.restore(trial, trial.checkpoint)
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
|
||||
if trial in self._staged_trials:
|
||||
self._staged_trials.remove(trial)
|
||||
|
||||
if train and not trial.is_restoring:
|
||||
if not trial.is_restoring:
|
||||
self._train(trial)
|
||||
return True
|
||||
|
||||
|
@ -561,26 +557,20 @@ class RayTrialExecutor(TrialExecutor):
|
|||
finally:
|
||||
trial.set_runner(None)
|
||||
|
||||
def start_trial(self,
|
||||
trial: Trial,
|
||||
checkpoint: Optional[Checkpoint] = None,
|
||||
train: bool = True) -> bool:
|
||||
def start_trial(self, trial: Trial) -> bool:
|
||||
"""Starts the trial.
|
||||
|
||||
Will not return resources if trial repeatedly fails on start.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be started.
|
||||
checkpoint (Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
train (bool): Whether or not to start training.
|
||||
|
||||
Returns:
|
||||
True if the remote runner has been started. False if trial was
|
||||
not started (e.g. because of lacking resources/pending PG).
|
||||
"""
|
||||
try:
|
||||
return self._start_trial(trial, checkpoint, train=train)
|
||||
return self._start_trial(trial)
|
||||
except AbortTrialExecution:
|
||||
logger.exception("Trial %s: Error starting runner, aborting!",
|
||||
trial)
|
||||
|
|
|
@ -177,9 +177,9 @@ class RayTrialExecutorTest(unittest.TestCase):
|
|||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.fetch_result(trial)
|
||||
checkpoint = self.trial_executor.pause_trial(trial)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.start_trial(trial, checkpoint)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
|
|
@ -169,17 +169,11 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_trial(self,
|
||||
trial: Trial,
|
||||
checkpoint: Optional[Checkpoint] = None,
|
||||
train: bool = True) -> bool:
|
||||
def start_trial(self, trial: Trial) -> bool:
|
||||
"""Starts the trial restoring from checkpoint if checkpoint is provided.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be started.
|
||||
checkpoint (Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
train (bool): Whether or not to start training.
|
||||
|
||||
Returns:
|
||||
True if trial started successfully, False otherwise.
|
||||
|
|
Loading…
Add table
Reference in a new issue