diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 0332c18d0..5d6b91c86 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -456,15 +456,11 @@ class RayTrialExecutor(TrialExecutor): trial_item = self._find_item(self._running, trial) assert len(trial_item) < 2, trial_item - def _start_trial(self, trial, checkpoint=None, train=True) -> bool: + def _start_trial(self, trial) -> bool: """Starts trial and restores last result if trial was paused. Args: trial (Trial): The trial to start. - checkpoint (Optional[Checkpoint]): The checkpoint to restore from. - If None, and no trial checkpoint exists, the trial is started - from the beginning. - train (bool): Whether or not to start training. Returns: True if trial was started successfully, False otherwise. @@ -477,13 +473,13 @@ class RayTrialExecutor(TrialExecutor): return False trial.set_runner(runner) self._notify_trainable_of_new_resources_if_needed(trial) - self.restore(trial, checkpoint) + self.restore(trial, trial.checkpoint) self.set_status(trial, Trial.RUNNING) if trial in self._staged_trials: self._staged_trials.remove(trial) - if train and not trial.is_restoring: + if not trial.is_restoring: self._train(trial) return True @@ -561,26 +557,20 @@ class RayTrialExecutor(TrialExecutor): finally: trial.set_runner(None) - def start_trial(self, - trial: Trial, - checkpoint: Optional[Checkpoint] = None, - train: bool = True) -> bool: + def start_trial(self, trial: Trial) -> bool: """Starts the trial. Will not return resources if trial repeatedly fails on start. Args: trial (Trial): Trial to be started. - checkpoint (Checkpoint): A Python object or path storing the state - of trial. - train (bool): Whether or not to start training. Returns: True if the remote runner has been started. False if trial was not started (e.g. because of lacking resources/pending PG). """ try: - return self._start_trial(trial, checkpoint, train=train) + return self._start_trial(trial) except AbortTrialExecution: logger.exception("Trial %s: Error starting runner, aborting!", trial) diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index a1ef76b22..760347466 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -177,9 +177,9 @@ class RayTrialExecutorTest(unittest.TestCase): self.trial_executor.start_trial(trial) self.assertEqual(Trial.RUNNING, trial.status) self.trial_executor.fetch_result(trial) - checkpoint = self.trial_executor.pause_trial(trial) + self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) - self.trial_executor.start_trial(trial, checkpoint) + self.trial_executor.start_trial(trial) self.assertEqual(Trial.RUNNING, trial.status) self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 77c12ad21..4c6cc999b 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -169,17 +169,11 @@ class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta): pass @abstractmethod - def start_trial(self, - trial: Trial, - checkpoint: Optional[Checkpoint] = None, - train: bool = True) -> bool: + def start_trial(self, trial: Trial) -> bool: """Starts the trial restoring from checkpoint if checkpoint is provided. Args: trial (Trial): Trial to be started. - checkpoint (Checkpoint): A Python object or path storing the state - of trial. - train (bool): Whether or not to start training. Returns: True if trial started successfully, False otherwise.