mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[train] fix Train/Tune integration on Client (#20351)
* [train] fix Train/Tune integration on Client * remove force_on_current_node
This commit is contained in:
parent
884bb3de33
commit
35dc3cf21b
1 changed files with 6 additions and 11 deletions
|
@ -21,8 +21,6 @@ from ray.train.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \
|
|||
# Ray Train should be usable even if Tune is not installed.
|
||||
from ray.train.utils import construct_path
|
||||
from ray.train.worker_group import WorkerGroup
|
||||
from ray.util.ml_utils.node import force_on_current_node, \
|
||||
get_current_node_resource_key
|
||||
|
||||
if TUNE_INSTALLED:
|
||||
from ray import tune
|
||||
|
@ -144,11 +142,6 @@ class Trainer:
|
|||
|
||||
remote_executor = ray.remote(num_cpus=0)(BackendExecutor)
|
||||
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
# Assign BackendExecutor to head node.
|
||||
remote_executor = force_on_current_node(remote_executor)
|
||||
|
||||
self._backend_executor_actor = remote_executor.remote(
|
||||
backend_config=self._backend_config,
|
||||
num_workers=num_workers,
|
||||
|
@ -157,7 +150,7 @@ class Trainer:
|
|||
additional_resources_per_worker=resources_per_worker,
|
||||
max_retries=max_retries)
|
||||
|
||||
if tune is not None and tune.is_session_enabled():
|
||||
if self._is_tune_enabled():
|
||||
self.checkpoint_manager = TuneCheckpointManager()
|
||||
else:
|
||||
self.checkpoint_manager = CheckpointManager()
|
||||
|
@ -203,6 +196,10 @@ class Trainer:
|
|||
else:
|
||||
raise TypeError(f"Invalid type for backend: {type(backend)}.")
|
||||
|
||||
def _is_tune_enabled(self):
|
||||
"""Whether or not this Trainer is part of a Tune session."""
|
||||
return tune is not None and tune.is_session_enabled()
|
||||
|
||||
def start(self, initialization_hook: Optional[Callable[[], None]] = None):
|
||||
"""Starts the training execution service.
|
||||
|
||||
|
@ -790,9 +787,7 @@ def _create_tune_trainable(train_func, dataset, backend_config, num_workers,
|
|||
@classmethod
|
||||
def default_resource_request(cls,
|
||||
config: Dict) -> PlacementGroupFactory:
|
||||
node_resource_key = get_current_node_resource_key()
|
||||
trainer_bundle = [{"CPU": 1}]
|
||||
backend_executor_bundle = [{node_resource_key: 0.01}]
|
||||
worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
|
||||
worker_resources_extra = {} if resources_per_worker is None else \
|
||||
resources_per_worker
|
||||
|
@ -800,7 +795,7 @@ def _create_tune_trainable(train_func, dataset, backend_config, num_workers,
|
|||
**worker_resources,
|
||||
**worker_resources_extra
|
||||
} for _ in range(num_workers)]
|
||||
bundles = trainer_bundle + backend_executor_bundle + worker_bundles
|
||||
bundles = trainer_bundle + worker_bundles
|
||||
return PlacementGroupFactory(bundles, strategy="PACK")
|
||||
|
||||
return TrainTrainable
|
||||
|
|
Loading…
Add table
Reference in a new issue