[SGD] Retry sgd.local_rank() (#18824)

* finish * fix * wip * address comment * update * fix test * fix failing test * address comments * fix test * fix
commit 00dd190df9
parent 73c3cff18b

14 changed files with 224 additions and 56 deletions
@@ -153,6 +153,10 @@ system. Let's take following simple examples:
             results = trainer.run(train_func_distributed)
             trainer.shutdown()
 
+        See :ref:`sgd-porting-code` for a more comprehensive example.
+
+
     .. group-tab:: TensorFlow
 
         This example shows how you can use RaySGD to set up `Multi-worker training
@@ -250,4 +254,7 @@ system. Let's take following simple examples:
             trainer.shutdown()
 
+        See :ref:`sgd-porting-code` for a more comprehensive example.
+
+
 **Next steps:** Check out the :ref:`User Guide <sgd-user-guide>`!
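The quick-start snippets above only show fragments of the workflow. For orientation, here is a minimal, self-contained sketch of the full v2 Trainer lifecycle they refer to; it assumes a CPU-only Ray cluster with PyTorch installed, and the body of the training function is illustrative rather than taken from this diff:

# Minimal sketch of the Trainer lifecycle referenced in the docs above
# (illustrative; assumes ray and torch are installed).
import ray
from ray.util.sgd.v2 import Trainer, local_rank, world_rank


def train_func_distributed():
    # Each worker can report both its global rank and its rank on its node.
    return {"world_rank": world_rank(), "local_rank": local_rank()}


if __name__ == "__main__":
    ray.init()
    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    results = trainer.run(train_func_distributed)  # one result per worker
    print(results)
    trainer.shutdown()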
@@ -83,6 +83,30 @@ training.
             sampler=DistributedSampler(dataset))
 
+**Step 3:** Set the proper CUDA device if you are using GPUs.
+
+If you are using GPUs, you need to make sure that the CUDA devices are properly set up inside your training function.
+
+This involves 3 steps:
+
+1. Use the local rank to set the default CUDA device for the worker.
+2. Move the model to the default CUDA device (or a specific CUDA device).
+3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``.
+
+.. code-block:: python
+
+    def train_func():
+        device = torch.device(f"cuda:{sgd.local_rank()}" if
+                              torch.cuda.is_available() else "cpu")
+        torch.cuda.set_device(device)
+
+        # Create model.
+        model = NeuralNetwork()
+        model = model.to(device)
+        model = DistributedDataParallel(
+            model,
+            device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)
+
+
 .. group-tab:: TensorFlow
 
     .. note::
@@ -73,7 +73,7 @@ py_test(
 
 py_test(
     name = "test_worker_group",
-    size = "small",
+    size = "medium",
     srcs = ["tests/test_worker_group.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":sgd_v2_lib"]
@@ -3,11 +3,11 @@ from ray.util.sgd.v2.backends import (BackendConfig, HorovodConfig,
 from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 from ray.util.sgd.v2.session import (load_checkpoint, save_checkpoint, report,
-                                     world_rank)
+                                     world_rank, local_rank)
 from ray.util.sgd.v2.trainer import Trainer, SGDIterator
 
 __all__ = [
     "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint",
-    "report", "save_checkpoint", "SGDCallback", "SGDIterator",
+    "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator",
     "TensorflowConfig", "TorchConfig", "Trainer", "world_rank"
 ]
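Because ``local_rank`` now appears in ``__all__``, it can be imported directly from the v2 package next to ``world_rank``; a one-line sketch of the import this enables:

# After this change, local_rank is importable from the package root,
# alongside the existing session helpers.
from ray.util.sgd.v2 import local_rank, world_rank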
@@ -15,8 +15,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
     TUNE_CHECKPOINT_ID
 from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
 from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
-from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids, \
-    check_for_failure
+from ray.util.sgd.v2.utils import construct_path, check_for_failure
 from ray.util.sgd.v2.worker_group import WorkerGroup
 
 if TUNE_INSTALLED:
@@ -309,12 +308,8 @@ class BackendExecutor:
 
         """
 
-        def get_node_id_and_gpu():
-            node_id = get_node_id()
-            gpu_ids = get_gpu_ids()
-            return node_id, gpu_ids
-
-        node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)
+        node_ids_and_gpu_ids = [(w.metadata.node_id, w.metadata.gpu_ids)
+                                for w in self.worker_group.workers]
 
         node_id_to_worker_id = defaultdict(set)
         node_id_to_gpu_ids = defaultdict(set)
@@ -336,6 +331,37 @@ class BackendExecutor:
                         worker_id, set_gpu_ids))
         ray.get(futures)
 
+    def _create_local_rank_map(self) -> Dict:
+        """Create mapping from worker world_rank to local_rank.
+
+        Example:
+            Worker 0: 0.0.0.0
+            Worker 1: 0.0.0.0
+            Worker 2: 0.0.0.1
+            Worker 3: 0.0.0.0
+            Worker 4: 0.0.0.1
+
+            Workers 0, 1, 3 are on 0.0.0.0.
+            Workers 2, 4 are on 0.0.0.1.
+
+            Expected Output:
+            {
+                0 -> 0,
+                1 -> 1,
+                2 -> 0,
+                3 -> 2,
+                4 -> 1
+            }
+        """
+        rank_mapping = {}
+        ip_dict = defaultdict(int)
+        for world_rank in range(len(self.worker_group)):
+            worker = self.worker_group.workers[world_rank]
+            node_ip = worker.metadata.node_ip
+            rank_mapping[world_rank] = ip_dict[node_ip]
+            ip_dict[node_ip] += 1
+        return rank_mapping
+
     def start_training(
         self,
         train_func: Callable[[], T],
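The docstring example above can be checked with a few lines of plain Python; a standalone sketch of the same counting scheme, with the worker group replaced by a hard-coded list of node IPs (the addresses are the illustrative ones from the docstring):

from collections import defaultdict

# Node IPs indexed by world rank, mirroring the docstring example above.
worker_ips = ["0.0.0.0", "0.0.0.0", "0.0.0.1", "0.0.0.0", "0.0.0.1"]

rank_mapping = {}
ip_dict = defaultdict(int)
for world_rank, node_ip in enumerate(worker_ips):
    # Each node hands out local ranks 0, 1, 2, ... in world-rank order.
    rank_mapping[world_rank] = ip_dict[node_ip]
    ip_dict[node_ip] += 1

assert rank_mapping == {0: 0, 1: 1, 2: 0, 3: 2, 4: 1}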
@@ -371,11 +397,12 @@ class BackendExecutor:
             ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
 
         # First initialize the session.
-        def initialize_session(world_rank, train_func, checkpoint):
+        def initialize_session(world_rank, local_rank, train_func, checkpoint):
             try:
                 init_session(
                     training_func=train_func,
                     world_rank=world_rank,
+                    local_rank=local_rank,
                     checkpoint=checkpoint,
                     detailed_autofilled_metrics=use_detailed_autofilled_metrics
                 )
@@ -388,6 +415,8 @@ class BackendExecutor:
 
         checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)
 
+        local_rank_map = self._create_local_rank_map()
+
         futures = []
         for world_rank in range(len(self.worker_group)):
             futures.append(
@@ -395,6 +424,7 @@ class BackendExecutor:
                     world_rank,
                     initialize_session,
                     world_rank=world_rank,
+                    local_rank=local_rank_map[world_rank],
                     train_func=train_func,
                     checkpoint=checkpoint_dict))
 
@@ -5,7 +5,7 @@ from typing import Optional, Set
 
 import ray
 from ray.util.sgd.v2.backends.backend import BackendConfig, Backend
-from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
+from ray.util.sgd.v2.utils import update_env_vars
 from ray.util.sgd.v2.worker_group import WorkerGroup
 
 try:
@@ -44,9 +44,9 @@ class HorovodConfig(BackendConfig):
         return HorovodBackend
 
 
-def init_env_vars(world_rank: int, world_size: int):
+def init_env_vars(world_rank: int, world_size: int, node_id: str):
     """Initialize Horovod environment variables."""
-    os.environ["HOROVOD_HOSTNAME"] = get_node_id()
+    os.environ["HOROVOD_HOSTNAME"] = node_id
     os.environ["HOROVOD_RANK"] = str(world_rank)
     os.environ["HOROVOD_SIZE"] = str(world_size)
 
@@ -60,9 +60,11 @@ class HorovodBackend(Backend):
         # Initialize workers with Horovod environment variables
         setup_futures = []
         for rank in range(len(worker_group)):
+            worker_node_id = worker_group.workers[rank].metadata.node_id
             setup_futures.append(
                 worker_group.execute_single_async(rank, init_env_vars, rank,
-                                                  len(worker_group)))
+                                                  len(worker_group),
+                                                  worker_node_id))
         ray.get(setup_futures)
 
         # Use Horovod Ray Coordinator
@@ -70,8 +72,8 @@ class HorovodBackend(Backend):
         self.coordinator = Coordinator(backend_config)
 
         # Get all the hostnames of all workers
-        node_ids = worker_group.execute(get_node_id)
-        hostnames = worker_group.execute(get_hostname)
+        node_ids = [w.metadata.node_id for w in worker_group.workers]
+        hostnames = [w.metadata.hostname for w in worker_group.workers]
         # Register each hostname to the coordinator. assumes the hostname
         # ordering is the same.
         for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
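Taken together with the previous hunk, Horovod workers now get their ``HOROVOD_*`` variables from actor metadata passed in by the driver instead of calling back into Ray utilities. A minimal sketch of what ``init_env_vars`` leaves in a worker's environment for rank 1 of a two-worker job (the node ID string is a made-up placeholder; in practice it comes from ``Worker.metadata.node_id``):

import os


# Same assignments as init_env_vars in the diff above; the node ID here is a
# hypothetical placeholder.
def init_env_vars(world_rank: int, world_size: int, node_id: str) -> None:
    os.environ["HOROVOD_HOSTNAME"] = node_id
    os.environ["HOROVOD_RANK"] = str(world_rank)
    os.environ["HOROVOD_SIZE"] = str(world_size)


init_env_vars(world_rank=1, world_size=2, node_id="fake-node-id-abc123")
assert os.environ["HOROVOD_RANK"] == "1"
assert os.environ["HOROVOD_SIZE"] == "2"
assert os.environ["HOROVOD_HOSTNAME"] == "fake-node-id-abc123"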
@@ -86,6 +86,9 @@ def train_func(config: Dict):
     lr = config["lr"]
     epochs = config["epochs"]
 
+    device = torch.device(f"cuda:{sgd.local_rank()}"
+                          if torch.cuda.is_available() else "cpu")
+
     # Create data loaders.
     train_dataloader = DataLoader(
         training_data,
@@ -97,10 +100,11 @@ def train_func(config: Dict):
         sampler=DistributedSampler(test_data))
 
     # Create model.
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = NeuralNetwork()
     model = model.to(device)
-    model = DistributedDataParallel(model)
+    model = DistributedDataParallel(
+        model,
+        device_ids=[device.index] if torch.cuda.is_available() else None)
 
     loss_fn = nn.CrossEntropyLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -32,12 +32,14 @@ class Session:
     def __init__(self,
                  training_func: Callable,
                  world_rank: int,
+                 local_rank: int,
                  checkpoint: Optional[Dict] = None,
                  detailed_autofilled_metrics: bool = False):
         # The Thread object that is running the training function.
         self.training_thread = PropagatingThread(
             target=training_func, daemon=True)
         self.world_rank = world_rank
+        self.local_rank = local_rank
         self.loaded_checkpoint = checkpoint
 
         # This lock is used to control the execution of the training thread.
@@ -263,6 +265,29 @@ def world_rank() -> int:
     return session.world_rank
 
 
+def local_rank() -> int:
+    """Get the local rank of this worker (rank of the worker on its node).
+
+    .. code-block:: python
+
+        import time
+        from ray.util import sgd
+
+        def train_func():
+            if torch.cuda.is_available():
+                torch.cuda.set_device(sgd.local_rank())
+            ...
+
+        trainer = Trainer(backend="torch", use_gpu=True)
+        trainer.start()
+        trainer.run(train_func)
+        trainer.shutdown()
+
+    """
+    session = get_session()
+    return session.local_rank
+
+
 def load_checkpoint() -> Optional[Dict]:
     """Loads checkpoint data onto the worker.
 
@@ -54,9 +54,10 @@ def ray_2_node_4_gpu():
 def gen_execute_special(special_f):
     def execute_async_special(self, f):
         """Runs f on worker 0, special_f on other workers."""
-        futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)]
+        futures = [self.workers[0].actor._BaseWorkerMixin__execute.remote(f)]
         for worker in self.workers[1:]:
-            futures.append(worker._BaseWorkerMixin__execute.remote(special_f))
+            futures.append(
+                worker.actor._BaseWorkerMixin__execute.remote(special_f))
         return futures
 
     return execute_async_special
@@ -123,6 +124,18 @@ def test_train(ray_start_2_cpus, tmp_path):
     assert e.finish_training() == [1, 1]
 
 
+def test_local_ranks(ray_start_2_cpus, tmp_path):
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=2)
+    e.start()
+
+    def train():
+        return sgd.local_rank()
+
+    e.start_training(train, run_dir=tmp_path)
+    assert set(e.finish_training()) == {0, 1}
+
+
 def test_train_failure(ray_start_2_cpus, tmp_path):
     config = TestConfig()
     e = BackendExecutor(config, num_workers=2)
 
@@ -2,7 +2,8 @@ import time
 
 import pytest
 from ray.util.sgd.v2.session import init_session, shutdown_session, \
-    get_session, world_rank, report, save_checkpoint, TrainingResultType, \
+    get_session, world_rank, local_rank, report, save_checkpoint, \
+    TrainingResultType, \
     load_checkpoint
 
 
@@ -11,7 +12,7 @@ def session():
     def f():
         return 1
 
-    init_session(training_func=f, world_rank=0)
+    init_session(training_func=f, world_rank=0, local_rank=0)
     yield get_session()
     shutdown_session()
 
@@ -34,6 +35,13 @@ def test_world_rank(session):
         world_rank()
 
 
+def test_local_rank(session):
+    assert local_rank() == 0
+    shutdown_session()
+
+    with pytest.raises(ValueError):
+        local_rank()
+
+
 def test_train(session):
     session.start()
     output = session.finish()
@@ -45,7 +53,7 @@ def test_report():
         for i in range(2):
             report(loss=i)
 
-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next().data["loss"] == 0
@@ -62,7 +70,7 @@ def test_report_fail():
             report(i)
         return 1
 
-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next() is None
@@ -96,7 +104,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data["epoch"] == expected
 
-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     validate_zero(0)
@@ -110,7 +118,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data == {}
 
-    init_session(training_func=train, world_rank=1)
+    init_session(training_func=train, world_rank=1, local_rank=1)
     session = get_session()
     session.start()
     validate_nonzero()
@@ -129,7 +137,7 @@ def test_load_checkpoint_after_save():
             checkpoint = load_checkpoint()
             assert checkpoint["epoch"] == i
 
-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     for i in range(2):
@@ -145,7 +153,7 @@ def test_locking():
         import _thread
         _thread.interrupt_main()
 
-    init_session(training_func=train_1, world_rank=0)
+    init_session(training_func=train_1, world_rank=0, local_rank=0)
     session = get_session()
     with pytest.raises(KeyboardInterrupt):
         session.start()
@@ -156,7 +164,7 @@ def test_locking():
             report(loss=i)
         train_1()
 
-    init_session(training_func=train_2, world_rank=0)
+    init_session(training_func=train_2, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     time.sleep(3)
@@ -88,7 +88,7 @@ def gen_execute_single_async_special(special_f):
         assert len(self.workers) == 2
         if i == 0 and hasattr(self, "should_fail") and self.should_fail:
             kwargs["train_func"] = special_f
-        return self.workers[i]._BaseWorkerMixin__execute.remote(
+        return self.workers[i].actor._BaseWorkerMixin__execute.remote(
             f, *args, **kwargs)
 
     return execute_single_async_special
@@ -126,7 +126,7 @@ class KillCallback(SGDCallback):
         print(results)
         assert all(r["loss"] == 1 for r in results)
         if self.counter == self.fail_on:
-            ray.kill(self.worker_group.workers[0])
+            ray.kill(self.worker_group.workers[0].actor)
             time.sleep(3)
         self.counter += 1
 
@@ -752,6 +752,27 @@ def test_worker_failure_2(ray_start_2_cpus):
     assert results == [1, 1]
 
 
+def test_worker_failure_local_rank(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        return sgd.local_rank()
+
+    def train_actor_failure():
+        import sys
+        sys.exit(0)
+        return sgd.local_rank()
+
+    new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        results = trainer.run(train)
+        assert set(results) == {0, 1}
+
+
 def test_worker_start_failure(ray_start_2_cpus):
     test_config = TestConfig()
 
@@ -480,7 +480,7 @@ class SGDWorkerGroup:
         self._worker_group = worker_group
 
     def __getitem__(self, item) -> ActorHandle:
-        return self._worker_group.workers[item]
+        return self._worker_group.workers[item].actor
 
     def shutdown(self, patience_s: float = 5):
         """Shutdown all the workers.
@@ -87,21 +87,6 @@ class PropagatingThread(Thread):
         return self.ret
 
 
-def get_node_id() -> str:
-    """Returns the ID of the node that this worker is on."""
-    return ray.get_runtime_context().node_id.hex()
-
-
-def get_hostname() -> str:
-    """Returns the hostname that this worker is on."""
-    return socket.gethostname()
-
-
-def get_gpu_ids() -> List[int]:
-    """Return list of CUDA device IDs available to this worker."""
-    return ray.get_gpu_ids()
-
-
 def update_env_vars(env_vars: Dict[str, Any]):
     """Updates the environment variables on this worker process.
 
@@ -1,7 +1,10 @@
+import socket
+from dataclasses import dataclass
 import logging
 from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple
 
 import ray
+from ray.actor import ActorHandle
 from ray.types import ObjectRef
 
 T = TypeVar("T")
@@ -22,6 +25,32 @@ class BaseWorkerMixin:
         return func(*args, **kwargs)
 
 
+@dataclass
+class WorkerMetadata:
+    """Metadata for each worker/actor.
+
+    This information is expected to stay the same throughout the lifetime of
+    actor.
+
+    Args:
+        node_id (str): ID of the node this worker is on.
+        node_ip (str): IP address of the node this worker is on.
+        hostname (str): Hostname that this worker is on.
+        gpu_ids (List[int]): List of CUDA IDs available to this worker.
+    """
+    node_id: str
+    node_ip: str
+    hostname: str
+    gpu_ids: Optional[List[int]]
+
+
+@dataclass
+class Worker:
+    """Class representing a Worker."""
+    actor: ActorHandle
+    metadata: WorkerMetadata
+
+
 def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     """Create the executable class to use as the Ray actors."""
     if not executable_cls:
@@ -37,6 +66,20 @@ def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     return _WrappedExecutable
 
 
+def construct_metadata() -> WorkerMetadata:
+    """Creates metadata for this worker.
+
+    This function is expected to be run on the actor.
+    """
+    node_id = ray.get_runtime_context().node_id.hex()
+    node_ip = ray.util.get_node_ip_address()
+    hostname = socket.gethostname()
+    gpu_ids = ray.get_gpu_ids()
+
+    return WorkerMetadata(
+        node_id=node_id, node_ip=node_ip, hostname=hostname, gpu_ids=gpu_ids)
+
+
 class WorkerGroup:
     """Group of Ray Actors that can execute arbitrary functions.
 
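With workers now represented as a ``Worker`` dataclass, call sites reach the actor handle through ``.actor`` and read placement information from ``.metadata`` without another remote call. A small standalone sketch of that access pattern (the actor handle is stubbed out and the metadata values are invented for illustration):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class WorkerMetadata:
    node_id: str
    node_ip: str
    hostname: str
    gpu_ids: Optional[List[int]]


@dataclass
class Worker:
    actor: object  # would be a ray.actor.ActorHandle in the real code
    metadata: WorkerMetadata


worker = Worker(
    actor=None,  # placeholder for the Ray actor handle
    metadata=WorkerMetadata(
        node_id="deadbeef", node_ip="10.0.0.2", hostname="node-2",
        gpu_ids=[0, 1]))

# Callers now read placement info directly from the driver-side metadata:
print(worker.metadata.node_ip, worker.metadata.gpu_ids)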
@@ -118,8 +161,11 @@ class WorkerGroup:
             self.start()
 
     def _create_worker(self):
-        return self._remote_cls.remote(*self._actor_cls_args,
-                                       **self._actor_cls_kwargs)
+        actor = self._remote_cls.remote(*self._actor_cls_args,
+                                        **self._actor_cls_kwargs)
+        actor_metadata = ray.get(
+            actor._BaseWorkerMixin__execute.remote(construct_metadata))
+        return Worker(actor=actor, metadata=actor_metadata)
 
     def start(self):
         """Starts all the workers in this worker group."""
@@ -145,9 +191,11 @@ class WorkerGroup:
         logger.debug(f"Shutting down {len(self.workers)} workers.")
         if patience_s <= 0:
             for worker in self.workers:
-                ray.kill(worker)
+                ray.kill(worker.actor)
         else:
-            done_refs = [w.__ray_terminate__.remote() for w in self.workers]
+            done_refs = [
+                w.actor.__ray_terminate__.remote() for w in self.workers
+            ]
             # Wait for actors to die gracefully.
             done, not_done = ray.wait(done_refs, timeout=patience_s)
             if not_done:
@@ -155,7 +203,7 @@ class WorkerGroup:
                     "force kill.")
                 # If all actors are not able to die gracefully, then kill them.
                 for worker in self.workers:
-                    ray.kill(worker)
+                    ray.kill(worker.actor)
 
         logger.debug("Shutdown successful.")
         self.workers = []
@@ -180,7 +228,7 @@ class WorkerGroup:
                 "create a new WorkerGroup or restart this one.")
 
         return [
-            w._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
+            w.actor._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
             for w in self.workers
         ]
 
@@ -214,7 +262,8 @@ class WorkerGroup:
         if worker_index >= len(self.workers):
             raise ValueError(f"The provided worker_index {worker_index} is "
                              f"not valid for {self.num_workers} workers.")
-        return self.workers[worker_index]._BaseWorkerMixin__execute.remote(
+        return self.workers[worker_index].actor._BaseWorkerMixin__execute\
+            .remote(
             func, *args, **kwargs)
 
     def execute_single(self, worker_index: int, func: Callable[..., T], *args,