From 35dc3cf21b6ce349aba272f916bef7d697ec0247 Mon Sep 17 00:00:00 2001
From: matthewdeng
Date: Mon, 15 Nov 2021 14:36:33 -0800
Subject: [PATCH] [train] fix Train/Tune integration on Client (#20351)

* [train] fix Train/Tune integration on Client

* remove force_on_current_node
---
 python/ray/train/trainer.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py
index 0e98dcb3b..d8bc43419 100644
--- a/python/ray/train/trainer.py
+++ b/python/ray/train/trainer.py
@@ -21,8 +21,6 @@ from ray.train.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \
 # Ray Train should be usable even if Tune is not installed.
 from ray.train.utils import construct_path
 from ray.train.worker_group import WorkerGroup
-from ray.util.ml_utils.node import force_on_current_node, \
-    get_current_node_resource_key
 
 if TUNE_INSTALLED:
     from ray import tune
@@ -144,11 +142,6 @@ class Trainer:
 
         remote_executor = ray.remote(num_cpus=0)(BackendExecutor)
 
-        if not ray.is_initialized():
-            ray.init()
-        # Assign BackendExecutor to head node.
-        remote_executor = force_on_current_node(remote_executor)
-
         self._backend_executor_actor = remote_executor.remote(
             backend_config=self._backend_config,
             num_workers=num_workers,
@@ -157,7 +150,7 @@ class Trainer:
             additional_resources_per_worker=resources_per_worker,
             max_retries=max_retries)
 
-        if tune is not None and tune.is_session_enabled():
+        if self._is_tune_enabled():
             self.checkpoint_manager = TuneCheckpointManager()
         else:
             self.checkpoint_manager = CheckpointManager()
@@ -203,6 +196,10 @@ class Trainer:
         else:
             raise TypeError(f"Invalid type for backend: {type(backend)}.")
 
+    def _is_tune_enabled(self):
+        """Whether or not this Trainer is part of a Tune session."""
+        return tune is not None and tune.is_session_enabled()
+
     def start(self, initialization_hook: Optional[Callable[[], None]] = None):
         """Starts the training execution service.
 
@@ -790,9 +787,7 @@ def _create_tune_trainable(train_func, dataset, backend_config, num_workers,
 
         @classmethod
        def default_resource_request(cls, config: Dict) -> PlacementGroupFactory:
-            node_resource_key = get_current_node_resource_key()
             trainer_bundle = [{"CPU": 1}]
-            backend_executor_bundle = [{node_resource_key: 0.01}]
             worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
             worker_resources_extra = {} if resources_per_worker is None else \
                 resources_per_worker
@@ -800,7 +795,7 @@ def _create_tune_trainable(train_func, dataset, backend_config, num_workers,
                 **worker_resources,
                 **worker_resources_extra
             } for _ in range(num_workers)]
-            bundles = trainer_bundle + backend_executor_bundle + worker_bundles
+            bundles = trainer_bundle + worker_bundles
             return PlacementGroupFactory(bundles, strategy="PACK")
 
     return TrainTrainable
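
For context, the Train/Tune-on-Client scenario this patch fixes can be exercised as
below. This is a minimal sketch assuming the Ray 1.9-era API (a `ray.train.Trainer`
wrapped via `to_tune_trainable` and a Ray Client connection through a `ray://` address);
the address, backend choice, and training function body are placeholders, not taken
from the patch.

    import ray
    from ray import tune
    from ray.train import Trainer

    def train_func(config):
        # Placeholder training function; a real one would run a training
        # loop and report metrics with ray.train.report(...).
        pass

    # Connect through Ray Client. Before this patch, the Trainer tried to
    # pin its BackendExecutor to the "current" (head) node via
    # force_on_current_node, which does not resolve correctly when the
    # driver runs behind a Client connection.
    ray.init("ray://<head-node-host>:10001")  # placeholder address

    trainer = Trainer(backend="torch", num_workers=2)
    trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(trainable, num_samples=2)

With the node-pinning removed, the BackendExecutor is scheduled like any other
zero-CPU actor, so the same script works whether the driver is local or connected
over Ray Client, and the Tune placement group no longer needs the extra
head-node bundle.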