Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[train] add placement group support (#20091)
* [train] add placement group support
* fix additional resources
* fix tests
* add comment to add_workers
This commit is contained in:
parent f6399e3389
commit 33af739bf2
6 changed files with 201 additions and 17 deletions
@@ -11,12 +11,16 @@ from ray.ray_constants import env_integer
 from ray.train.checkpoint import CheckpointManager, CheckpointStrategy, \
     TuneCheckpointManager
 from ray.train.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
-    TUNE_INSTALLED, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
+    TUNE_INSTALLED, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
+    TRAIN_ENABLE_WORKER_SPREAD_ENV, \
+    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV
 from ray.train.session import TrainingResultType, TrainingResult
 from ray.train.session import init_session, get_session, shutdown_session
 from ray.train.utils import RayDataset
 from ray.train.utils import check_for_failure
 from ray.train.worker_group import WorkerGroup
+from ray.util.placement_group import get_current_placement_group, \
+    remove_placement_group
 
 if TUNE_INSTALLED:
     from ray import tune
@@ -93,6 +97,7 @@ class BackendExecutor:
         self._max_failures = float("inf")
         self._num_failures = 0
         self._initialization_hook = None
+        self._placement_group = None
 
         if tune is not None and tune.is_session_enabled():
             self.checkpoint_manager = TuneCheckpointManager()
@@ -110,6 +115,8 @@ class BackendExecutor:
               train_cls_args: Optional[Tuple] = None,
               train_cls_kwargs: Optional[Dict] = None):
         """Starts the worker group."""
+        self._create_placement_group()
+        placement_group = self._placement_group or "default"
         self.worker_group = WorkerGroup(
             num_workers=self._num_workers,
             num_cpus_per_worker=self._num_cpus_per_worker,
@@ -118,7 +125,8 @@
             _additional_resources_per_worker,
             actor_cls=train_cls,
             actor_cls_args=train_cls_args,
-            actor_cls_kwargs=train_cls_kwargs)
+            actor_cls_kwargs=train_cls_kwargs,
+            placement_group=placement_group)
         try:
             if initialization_hook:
                 self._initialization_hook = initialization_hook
@@ -137,6 +145,57 @@
             self._increment_failures()
             self._restart()
 
+    def _create_placement_group(self):
+        """Creates a placement group if it does not exist.
+
+        If a placement group is already detected (Tune) this will be a no-op.
+
+        By default the placement group will be created with PACK strategy.
+        This is optimized for colocating GPUs on a minimal number of nodes.
+        This behavior can be overridden to use the SPREAD strategy by defining
+        ``TRAIN_ENABLE_WORKER_SPREAD_ENV``
+
+        If a placement group is created it will be stored as
+        self._placement_group.
+        """
+        current_placement_group = get_current_placement_group()
+        should_capture_child_tasks_in_placement_group = \
+            ray.worker.global_worker \
+                .should_capture_child_tasks_in_placement_group
+        should_create_placement_group = \
+            current_placement_group is None or \
+            not should_capture_child_tasks_in_placement_group
+
+        if should_create_placement_group:
+            additional_resources_per_worker = \
+                self._additional_resources_per_worker or {}
+            bundle = {
+                "CPU": self._num_cpus_per_worker,
+                "GPU": self._num_gpus_per_worker,
+                **additional_resources_per_worker
+            }
+            bundles = [bundle.copy() for _ in range(self._num_workers)]
+
+            use_spread = bool(env_integer(TRAIN_ENABLE_WORKER_SPREAD_ENV, 0))
+            strategy = "SPREAD" if use_spread else "PACK"
+
+            placement_group = ray.util.placement_group(
+                bundles, strategy=strategy)
+            logger.debug("Waiting for placement group to start.")
+            timeout = env_integer(TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)
+            ready, _ = ray.wait([placement_group.ready()], timeout=timeout)
+            if ready:
+                logger.debug("Placement group has started.")
+            else:
+                raise TimeoutError(
+                    "Placement group creation timed out. Make sure "
+                    "your cluster either has enough resources or use "
+                    "an autoscaling cluster. Current resources "
+                    "available: {}, resources requested by the "
+                    "placement group: {}".format(ray.available_resources(),
+                                                 placement_group.bundle_specs))
+            self._placement_group = placement_group
+
     def _share_cuda_visible_devices(self):
         """Sets CUDA_VISIBLE_DEVICES on all workers.
 
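Taken together, _create_placement_group reserves one resource bundle per training worker and then blocks until Ray has scheduled the whole group or the timeout expires. A minimal standalone sketch of the same pattern, reading the same environment variables by hand (this is an illustration, not the committed code; it assumes a running Ray cluster and plain 1-CPU workers):

    import os

    import ray
    from ray.util.placement_group import placement_group

    ray.init()

    num_workers = 4
    # One bundle per worker, mirroring the bundle built above
    # (GPU and custom resources omitted for brevity).
    bundles = [{"CPU": 1} for _ in range(num_workers)]

    # PACK colocates bundles on as few nodes as possible; SPREAD does the opposite.
    use_spread = bool(int(os.environ.get("TRAIN_ENABLE_WORKER_SPREAD", "0")))
    strategy = "SPREAD" if use_spread else "PACK"

    pg = placement_group(bundles, strategy=strategy)

    # Block until the group is scheduled, or give up after the timeout.
    timeout_s = int(os.environ.get("TRAIN_PLACEMENT_GROUP_TIMEOUT_S", "100"))
    ready, _ = ray.wait([pg.ready()], timeout=timeout_s)
    if not ready:
        raise TimeoutError(
            "Placement group was not scheduled; available resources: "
            "{}".format(ray.available_resources()))
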
@@ -528,6 +587,11 @@
                     "expected if one of the workers has crashed.")
             self.worker_group.shutdown()
             self.worker_group = InactiveWorkerGroup()
+
+        if self._placement_group:
+            remove_placement_group(self._placement_group)
+            self._placement_group = None
+
         self.dataset_shards = None
 
     @property
@@ -567,6 +631,9 @@
             initialization_hook = self._initialization_hook
         else:
             initialization_hook = None
+        if self._placement_group:
+            remove_placement_group(self._placement_group)
+            self._placement_group = None
         self.start(initialization_hook=initialization_hook)
 
     def _increment_failures(self):
@@ -49,3 +49,12 @@ TUNE_CHECKPOINT_ID = "_current_checkpoint_id"
 # Backend.share_cuda_visible_devices. 1 for True, 0 for False.
 ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV =\
     "TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"
+
+# Integer value which indicates the number of seconds to wait when creating
+# the worker placement group before timing out.
+TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"
+
+# Integer value which if set will change the placement group strategy from
+# PACK to SPREAD. 1 for True, 0 for False.
+TRAIN_ENABLE_WORKER_SPREAD_ENV =\
+    "TRAIN_ENABLE_WORKER_SPREAD"
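Both knobs are read from the environment when the worker group starts, so they can be set in the shell or in the driver process before training begins. A small illustrative sketch (the values here are examples, not defaults):

    import os

    # Wait up to 10 minutes for the worker placement group to be scheduled.
    os.environ["TRAIN_PLACEMENT_GROUP_TIMEOUT_S"] = "600"

    # Spread workers across nodes instead of packing them (1 = on, 0 = off).
    os.environ["TRAIN_ENABLE_WORKER_SPREAD"] = "1"
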
@@ -1,20 +1,22 @@
 import math
 import os
 import time
 
-import pytest
 from unittest.mock import patch
-
-import ray
-from ray.cluster_utils import Cluster
-import ray.train as train
-from ray.train.backends.backend import BackendConfig, BackendExecutor
-from ray.train.backends.tensorflow import TensorflowConfig
-from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
-from ray.train.worker_group import WorkerGroup
-from ray.train.backends.torch import TorchConfig
+import pytest
+
+import ray
+import ray.train as train
+from ray.cluster_utils import Cluster
 from ray.train.backends.backend import Backend, \
     InactiveWorkerGroupError, TrainBackendError, TrainingWorkerError
+from ray.train.backends.backend import BackendConfig, BackendExecutor
+from ray.train.backends.tensorflow import TensorflowConfig
+from ray.train.backends.torch import TorchConfig
+from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
+    TRAIN_ENABLE_WORKER_SPREAD_ENV
+from ray.train.worker_group import WorkerGroup
+from ray.util.placement_group import get_current_placement_group
 
 
 @pytest.fixture
@@ -25,6 +27,20 @@ def ray_start_2_cpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_4_node_4_cpu():
+    cluster = Cluster()
+    for _ in range(4):
+        cluster.add_node(num_cpus=4)
+
+    ray.init(address=cluster.address)
+
+    yield
+
+    ray.shutdown()
+    cluster.shutdown()
+
+
 @pytest.fixture
 def ray_2_node_2_gpu():
     cluster = Cluster()
@@ -393,6 +409,67 @@ def test_cuda_visible_devices_multiple(ray_2_node_4_gpu, worker_results,
     assert results == expected_results
 
 
+def get_node_id_set():
+    node_id_set = set()
+    for actor_info in ray.state.actors().values():
+        node_id = actor_info["Address"]["NodeID"]
+        node_id_set.add(node_id)
+    return node_id_set
+
+
+@pytest.mark.parametrize("num_workers", [3, 4, 5])
+def test_placement_group_pack(ray_4_node_4_cpu, num_workers):
+    """Tests that workers are packed on nodes."""
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=num_workers)
+    e.start()
+    node_id_set = get_node_id_set()
+    assert len(node_id_set) == math.ceil(num_workers / 4)
+
+
+@pytest.mark.parametrize("num_workers", [3, 4, 5])
+def test_placement_group_spread(ray_4_node_4_cpu, num_workers):
+    """Tests that workers are spread across nodes."""
+    os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV] = "1"
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=num_workers)
+    e.start()
+    node_id_set = get_node_id_set()
+    assert len(node_id_set) == min(num_workers, 4)
+
+
+@pytest.mark.parametrize("placement_group_capture_child_tasks", [True, False])
+def test_placement_group_parent(ray_4_node_4_cpu, tmp_path,
+                                placement_group_capture_child_tasks):
+    """Tests that parent placement group will be used."""
+    num_workers = 2
+    bundle = {"CPU": 1}
+    bundles = [bundle.copy() for _ in range(num_workers + 1)]
+    placement_group = ray.util.placement_group(bundles)
+
+    def train_func():
+        return get_current_placement_group().id
+
+    @ray.remote
+    def test():
+        config = TestConfig()
+        e = BackendExecutor(config, num_workers=2)
+        e.start()
+        e.start_training(train_func, run_dir=tmp_path)
+        return e.finish_training()
+
+    results_future = test.options(
+        placement_group=placement_group,
+        placement_group_capture_child_tasks=placement_group_capture_child_tasks
+    ).remote()
+    results = ray.get(results_future)
+    for worker_result in results:
+        if placement_group_capture_child_tasks:
+            assert worker_result == placement_group.id
+        else:
+            assert worker_result != placement_group.id
+
+
 if __name__ == "__main__":
     import pytest
     import sys
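The node counts asserted in test_placement_group_pack and test_placement_group_spread follow from the fixture's cluster shape (4 nodes with 4 CPUs each) and the 1-CPU-per-worker bundles: PACK fills a node before spilling onto the next one, while SPREAD uses as many distinct nodes as it can. A quick check of the expected values:

    import math

    nodes, cpus_per_node = 4, 4
    for num_workers in (3, 4, 5):
        packed = math.ceil(num_workers / cpus_per_node)  # nodes used under PACK
        spread = min(num_workers, nodes)                 # nodes used under SPREAD
        print(num_workers, packed, spread)
    # 3 workers -> 1 node packed,  3 nodes spread
    # 4 workers -> 1 node packed,  4 nodes spread
    # 5 workers -> 2 nodes packed, 4 nodes spread
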
@@ -107,6 +107,17 @@ def test_bad_resources(ray_start_2_cpus):
         WorkerGroup(num_gpus_per_worker=-1)
 
 
+def test_placement_group(ray_start_2_cpus):
+    """Tests that workers can be removed and added to a placement group."""
+    num_workers = 2
+    bundle = {"CPU": 1}
+    bundles = [bundle.copy() for _ in range(num_workers)]
+    placement_group = ray.util.placement_group(bundles)
+    wg = WorkerGroup(num_workers=num_workers, placement_group=placement_group)
+    wg.remove_workers([0])
+    wg.add_workers(1)
+
+
 if __name__ == "__main__":
     import pytest
     import sys
@@ -93,6 +93,10 @@ class Trainer:
     ):
 
         self._backend = backend
+
+        if num_workers <= 0:
+            raise ValueError("`num_workers` must be a positive integer.")
+
         self._num_workers = num_workers
         self._use_gpu = use_gpu
         self._resources_per_worker = resources_per_worker
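With this check in place, constructing a Trainer with a non-positive worker count fails immediately with a clear error instead of hanging later on resource requests. A quick sketch of the expected behavior (the backend name here is only illustrative):

    import pytest

    from ray.train import Trainer

    # num_workers must be >= 1; zero or negative values are rejected up front.
    with pytest.raises(ValueError):
        Trainer(backend="torch", num_workers=0)
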
@@ -1,11 +1,12 @@
+import logging
 import socket
 from dataclasses import dataclass
-import logging
-from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple
+from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple, Union
 
 import ray
 from ray.actor import ActorHandle
 from ray.types import ObjectRef
+from ray.util.placement_group import PlacementGroup
 
 T = TypeVar("T")
 
@@ -105,6 +106,9 @@ class WorkerGroup:
             remote actors.
         remote_cls_args, remote_cls_kwargs: If ``remote_cls`` is provided,
             these args will be used for the worker initialization.
+        placement_group (PlacementGroup|str): The placement group that workers
+            should be created in. Defaults to "default" which will inherit the
+            parent placement group (if child tasks should be captured).
 
 
     Example:
@@ -125,7 +129,8 @@ class WorkerGroup:
                 additional_resources_per_worker: Optional[Dict[str, float]] = None,
                 actor_cls: Type = None,
                 actor_cls_args: Optional[Tuple] = None,
-                actor_cls_kwargs: Optional[Dict] = None):
+                actor_cls_kwargs: Optional[Dict] = None,
+                placement_group: Union[PlacementGroup, str] = "default"):
 
         if num_workers <= 0:
             raise ValueError("The provided `num_workers` must be greater "
@@ -152,6 +157,8 @@ class WorkerGroup:
         self._actor_cls_args = actor_cls_args or []
         self._actor_cls_kwargs = actor_cls_kwargs or {}
 
+        self._placement_group = placement_group
+
         # TODO(matt): Validate resources. Fast-fail if it is impossible to
         # handle the request, rather than hang indefinitely.
         self._remote_cls = ray.remote(
@@ -279,6 +286,9 @@ class WorkerGroup:
     def remove_workers(self, worker_indexes: List[int]):
         """Removes the workers with the specified indexes.
 
+        The removed workers will go out of scope and their actor processes
+        will be terminated.
+
         Args:
             worker_indexes (List[int]): The indexes of the workers to remove.
         """
@@ -291,14 +301,20 @@ class WorkerGroup:
     def add_workers(self, num_workers: int):
         """Adds ``num_workers`` to this WorkerGroup.
 
+        Note: Adding workers when the cluster/placement group is at capacity
+            may lead to undefined hanging behavior. If you are attempting to
+            replace existing workers in the WorkerGroup, remove_workers() should
+            be called first.
+
         Args:
             num_workers (int): The number of workers to add.
         """
         new_actors = []
         new_actor_metadata = []
         for _ in range(num_workers):
-            actor = self._remote_cls.remote(*self._actor_cls_args,
-                                            **self._actor_cls_kwargs)
+            actor = self._remote_cls.options(
+                placement_group=self._placement_group).remote(
+                    *self._actor_cls_args, **self._actor_cls_kwargs)
             new_actors.append(actor)
             new_actor_metadata.append(
                 actor._BaseWorkerMixin__execute.remote(construct_metadata))
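The add_workers change is what actually ties WorkerGroup to the placement group: every replacement actor is created through .options(placement_group=...), so it lands back in the reserved bundles rather than in free cluster capacity. A self-contained sketch of the same pattern outside of WorkerGroup (the Worker class and resource numbers are illustrative, not part of the diff):

    import ray
    from ray.util.placement_group import placement_group

    ray.init(num_cpus=2)

    # Reserve one 1-CPU bundle per worker up front.
    pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="PACK")
    ray.get(pg.ready())

    @ray.remote(num_cpus=1)
    class Worker:
        def ping(self):
            return "ok"

    # Create the actors inside the reserved bundles, mirroring
    # WorkerGroup.add_workers: remote_cls.options(placement_group=...).remote(...)
    workers = [Worker.options(placement_group=pg).remote() for _ in range(2)]
    print(ray.get([w.ping.remote() for w in workers]))
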