diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index ec4df4433..6cdb8de24 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -433,8 +433,7 @@ class RayTrialExecutor(TrialExecutor):
         self.restore(trial)
         self.set_status(trial, Trial.RUNNING)
 
-        if trial in self._staged_trials:
-            self._staged_trials.remove(trial)
+        self._staged_trials.discard(trial)
 
         if not trial.is_restoring:
             self._train(trial)
@@ -503,8 +502,7 @@ class RayTrialExecutor(TrialExecutor):
                     if self._trial_cleanup:  # force trial cleanup within a deadline
                         self._trial_cleanup.add(future)
 
-            if trial in self._staged_trials:
-                self._staged_trials.remove(trial)
+            self._staged_trials.discard(trial)
 
         except Exception:
             logger.exception("Trial %s: Error stopping runner.", trial)
@@ -626,8 +624,13 @@ class RayTrialExecutor(TrialExecutor):
         """
         return (
             trial in self._staged_trials
+            or (
+                len(self._cached_actor_pg) > 0
+                and (self._pg_manager.has_cached_pg(trial.placement_group_factory))
+            )
             or self._pg_manager.can_stage()
             or self._pg_manager.has_ready(trial, update=True)
+            or self._pg_manager.has_staging(trial)
         )
 
     def debug_string(self) -> str:
diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py
index f9e78b125..77334abb0 100644
--- a/python/ray/tune/tests/test_ray_trial_executor.py
+++ b/python/ray/tune/tests/test_ray_trial_executor.py
@@ -15,13 +15,13 @@ from ray.tune.ray_trial_executor import (
     ExecutorEventType,
     RayTrialExecutor,
 )
-from ray.tune.registry import _global_registry, TRAINABLE_CLASS
+from ray.tune.registry import _global_registry, TRAINABLE_CLASS, register_trainable
 from ray.tune.result import PID, TRAINING_ITERATION, TRIAL_ID
 from ray.tune.suggest import BasicVariantGenerator
 from ray.tune.trial import Trial, _TuneCheckpoint
 from ray.tune.resources import Resources
 from ray.cluster_utils import Cluster
-from ray.tune.utils.placement_groups import PlacementGroupFactory
+from ray.tune.utils.placement_groups import PlacementGroupFactory, PlacementGroupManager
 from unittest.mock import patch
 
 
@@ -488,6 +488,62 @@ class RayExecutorPlacementGroupTest(unittest.TestCase):
         self.assertEqual(counter[pgf_2], 3)
         self.assertEqual(counter[pgf_3], 3)
 
+    def testHasResourcesForTrialWithCaching(self):
+        pgm = PlacementGroupManager()
+        pgf1 = PlacementGroupFactory([{"CPU": self.head_cpus}])
+        pgf2 = PlacementGroupFactory([{"CPU": self.head_cpus - 1}])
+
+        executor = RayTrialExecutor(reuse_actors=True)
+        executor._pg_manager = pgm
+        executor.set_max_pending_trials(1)
+
+        def train(config):
+            yield 1
+            yield 2
+            yield 3
+            yield 4
+
+        register_trainable("resettable", train)
+
+        trial1 = Trial("resettable", placement_group_factory=pgf1)
+        trial2 = Trial("resettable", placement_group_factory=pgf1)
+        trial3 = Trial("resettable", placement_group_factory=pgf2)
+
+        assert executor.has_resources_for_trial(trial1)
+        assert executor.has_resources_for_trial(trial2)
+        assert executor.has_resources_for_trial(trial3)
+
+        executor._stage_and_update_status([trial1, trial2, trial3])
+
+        while not pgm.has_ready(trial1):
+            time.sleep(1)
+            executor._stage_and_update_status([trial1, trial2, trial3])
+
+        # Fill staging
+        executor._stage_and_update_status([trial1, trial2, trial3])
+
+        assert executor.has_resources_for_trial(trial1)
+        assert executor.has_resources_for_trial(trial2)
+        assert not executor.has_resources_for_trial(trial3)
+
+        executor._start_trial(trial1)
+        executor._stage_and_update_status([trial1, trial2, trial3])
+        executor.pause_trial(trial1)  # Caches the PG and removes a PG from staging
+
+        assert len(pgm._staging_futures) == 0
+
+        # This will re-schedule a placement group
+        pgm.reconcile_placement_groups([trial1, trial2])
+
+        assert len(pgm._staging_futures) == 1
+
+        assert not pgm.can_stage()
+
+        # We should still have resources for this trial as it has a cached PG
+        assert executor.has_resources_for_trial(trial1)
+        assert executor.has_resources_for_trial(trial2)
+        assert not executor.has_resources_for_trial(trial3)
+
 
 class LocalModeExecutorTest(RayTrialExecutorTest):
     def setUp(self):
diff --git a/python/ray/tune/utils/placement_groups.py b/python/ray/tune/utils/placement_groups.py
index 2905df713..4eb0e4bac 100644
--- a/python/ray/tune/utils/placement_groups.py
+++ b/python/ray/tune/utils/placement_groups.py
@@ -537,6 +537,21 @@ class PlacementGroupManager:
             self.update_status()
         return bool(self._ready[trial.placement_group_factory])
 
+    def has_staging(self, trial: "Trial", update: bool = False) -> bool:
+        """Return True if placement group for trial is staging.
+
+        Args:
+            trial: :obj:`Trial` object.
+            update: Update status first.
+
+        Returns:
+            Boolean.
+
+        """
+        if update:
+            self.update_status()
+        return bool(self._staging[trial.placement_group_factory])
+
     def trial_in_use(self, trial: "Trial"):
         return trial in self._in_use_trials
 
@@ -604,6 +619,10 @@ class PlacementGroupManager:
     def clean_cached_pg(self, pg: PlacementGroup):
         self._cached_pgs.pop(pg)
 
+    def has_cached_pg(self, pgf: PlacementGroupFactory):
+        """Check if a placement group for given factory has been cached"""
+        return any(cached_pgf == pgf for cached_pgf in self._cached_pgs.values())
+
     def remove_from_in_use(self, trial: "Trial") -> PlacementGroup:
         """Return pg back to Core scheduling.