[SGD] Retry sgd.local_rank() (#18824)

* finish

* fix

* wip

* address comment

* update

* fix test

* fix failing test

* address comments

* fix test

* fix
Amog Kamsetty 2021-09-22 15:48:38 -07:00 committed by GitHub
parent 73c3cff18b
commit 00dd190df9
14 changed files with 224 additions and 56 deletions

@ -153,6 +153,10 @@ system. Let's take following simple examples:
results = trainer.run(train_func_distributed)
trainer.shutdown()
See :ref:`sgd-porting-code` for a more comprehensive example.
.. group-tab:: TensorFlow
This example shows how you can use RaySGD to set up `Multi-worker training
@ -250,4 +254,7 @@ system. Let's take following simple examples:
trainer.shutdown()
See :ref:`sgd-porting-code` for a more comprehensive example.
**Next steps:** Check out the :ref:`User Guide <sgd-user-guide>`!

@ -83,6 +83,30 @@ training.
sampler=DistributedSampler(dataset))
**Step 3:** Set the proper CUDA device if you are using GPUs.
If you are using GPUs, you need to make sure the CUDA devices are properly set up inside your training function.
This involves 3 steps:
1. Use the local rank to set the default CUDA device for the worker.
2. Move the model to the default CUDA device (or a specific CUDA device).
3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``.
.. code-block:: python
def train_func():
device = torch.device(f"cuda:{sgd.local_rank()}" if
torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
# Create model.
model = NeuralNetwork()
model = model.to(device)
model = DistributedDataParallel(
model,
device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)
.. group-tab:: TensorFlow
.. note::

@ -73,7 +73,7 @@ py_test(
py_test(
name = "test_worker_group",
size = "small",
size = "medium",
srcs = ["tests/test_worker_group.py"],
tags = ["team:ml", "exclusive"],
deps = [":sgd_v2_lib"]

@ -3,11 +3,11 @@ from ray.util.sgd.v2.backends import (BackendConfig, HorovodConfig,
from ray.util.sgd.v2.callbacks import SGDCallback
from ray.util.sgd.v2.checkpoint import CheckpointStrategy
from ray.util.sgd.v2.session import (load_checkpoint, save_checkpoint, report,
world_rank)
world_rank, local_rank)
from ray.util.sgd.v2.trainer import Trainer, SGDIterator
__all__ = [
"BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint",
"report", "save_checkpoint", "SGDCallback", "SGDIterator",
"local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator",
"TensorflowConfig", "TorchConfig", "Trainer", "world_rank"
]
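Since ``local_rank`` is now exported from the public ``ray.util.sgd.v2`` namespace, here is a minimal usage sketch (the worker count and printed result are illustrative; the pattern follows the docs above):

.. code-block:: python

    import torch

    from ray.util import sgd
    from ray.util.sgd.v2 import Trainer

    def train_func():
        # Pin this worker to the GPU matching its rank on its node.
        if torch.cuda.is_available():
            torch.cuda.set_device(sgd.local_rank())
        return sgd.local_rank()

    trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
    trainer.start()
    print(trainer.run(train_func))  # e.g. [0, 1] on a single 2-GPU node
    trainer.shutdown()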

@ -15,8 +15,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
TUNE_CHECKPOINT_ID
from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids, \
check_for_failure
from ray.util.sgd.v2.utils import construct_path, check_for_failure
from ray.util.sgd.v2.worker_group import WorkerGroup
if TUNE_INSTALLED:
@ -309,12 +308,8 @@ class BackendExecutor:
"""
def get_node_id_and_gpu():
node_id = get_node_id()
gpu_ids = get_gpu_ids()
return node_id, gpu_ids
node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)
node_ids_and_gpu_ids = [(w.metadata.node_id, w.metadata.gpu_ids)
for w in self.worker_group.workers]
node_id_to_worker_id = defaultdict(set)
node_id_to_gpu_ids = defaultdict(set)
@ -336,6 +331,37 @@ class BackendExecutor:
worker_id, set_gpu_ids))
ray.get(futures)
def _create_local_rank_map(self) -> Dict:
"""Create mapping from worker world_rank to local_rank.
Example:
Worker 0: 0.0.0.0
Worker 1: 0.0.0.0
Worker 2: 0.0.0.1
Worker 3: 0.0.0.0
Worker 4: 0.0.0.1
Workers 0, 1, 3 are on 0.0.0.0.
Workers 2, 4 are on 0.0.0.1.
Expected Output:
{
0 -> 0,
1 -> 1,
2 -> 0,
3 -> 2,
4 -> 1
}
"""
rank_mapping = {}
ip_dict = defaultdict(int)
for world_rank in range(len(self.worker_group)):
worker = self.worker_group.workers[world_rank]
node_ip = worker.metadata.node_ip
rank_mapping[world_rank] = ip_dict[node_ip]
ip_dict[node_ip] += 1
return rank_mapping
def start_training(
self,
train_func: Callable[[], T],
@ -371,11 +397,12 @@ class BackendExecutor:
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
# First initialize the session.
def initialize_session(world_rank, train_func, checkpoint):
def initialize_session(world_rank, local_rank, train_func, checkpoint):
try:
init_session(
training_func=train_func,
world_rank=world_rank,
local_rank=local_rank,
checkpoint=checkpoint,
detailed_autofilled_metrics=use_detailed_autofilled_metrics
)
@ -388,6 +415,8 @@ class BackendExecutor:
checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)
local_rank_map = self._create_local_rank_map()
futures = []
for world_rank in range(len(self.worker_group)):
futures.append(
@ -395,6 +424,7 @@ class BackendExecutor:
world_rank,
initialize_session,
world_rank=world_rank,
local_rank=local_rank_map[world_rank],
train_func=train_func,
checkpoint=checkpoint_dict))
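For clarity, a standalone sketch of the bookkeeping that ``_create_local_rank_map`` performs above; ``compute_local_ranks`` and the node IPs are illustrative only:

.. code-block:: python

    from collections import defaultdict

    def compute_local_ranks(node_ips):
        # node_ips[world_rank] is the IP of the node hosting that worker.
        rank_mapping = {}
        ip_dict = defaultdict(int)
        for world_rank, node_ip in enumerate(node_ips):
            rank_mapping[world_rank] = ip_dict[node_ip]
            ip_dict[node_ip] += 1
        return rank_mapping

    # Placement from the docstring: workers 0, 1, 3 on 0.0.0.0; workers 2, 4 on 0.0.0.1.
    assert compute_local_ranks(
        ["0.0.0.0", "0.0.0.0", "0.0.0.1", "0.0.0.0", "0.0.0.1"]) == {
            0: 0, 1: 1, 2: 0, 3: 2, 4: 1}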

@ -5,7 +5,7 @@ from typing import Optional, Set
import ray
from ray.util.sgd.v2.backends.backend import BackendConfig, Backend
from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
from ray.util.sgd.v2.utils import update_env_vars
from ray.util.sgd.v2.worker_group import WorkerGroup
try:
@ -44,9 +44,9 @@ class HorovodConfig(BackendConfig):
return HorovodBackend
def init_env_vars(world_rank: int, world_size: int):
def init_env_vars(world_rank: int, world_size: int, node_id: str):
"""Initialize Horovod environment variables."""
os.environ["HOROVOD_HOSTNAME"] = get_node_id()
os.environ["HOROVOD_HOSTNAME"] = node_id
os.environ["HOROVOD_RANK"] = str(world_rank)
os.environ["HOROVOD_SIZE"] = str(world_size)
@ -60,9 +60,11 @@ class HorovodBackend(Backend):
# Initialize workers with Horovod environment variables
setup_futures = []
for rank in range(len(worker_group)):
worker_node_id = worker_group.workers[rank].metadata.node_id
setup_futures.append(
worker_group.execute_single_async(rank, init_env_vars, rank,
len(worker_group)))
len(worker_group),
worker_node_id))
ray.get(setup_futures)
# Use Horovod Ray Coordinator
@ -70,8 +72,8 @@ class HorovodBackend(Backend):
self.coordinator = Coordinator(backend_config)
# Get all the hostnames of all workers
node_ids = worker_group.execute(get_node_id)
hostnames = worker_group.execute(get_hostname)
node_ids = [w.metadata.node_id for w in worker_group.workers]
hostnames = [w.metadata.hostname for w in worker_group.workers]
# Register each hostname to the coordinator. Assumes the hostname
# ordering is the same.
for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):

@ -86,6 +86,9 @@ def train_func(config: Dict):
lr = config["lr"]
epochs = config["epochs"]
device = torch.device(f"cuda:{sgd.local_rank()}"
if torch.cuda.is_available() else "cpu")
# Create data loaders.
train_dataloader = DataLoader(
training_data,
@ -97,10 +100,11 @@ def train_func(config: Dict):
sampler=DistributedSampler(test_data))
# Create model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork()
model = model.to(device)
model = DistributedDataParallel(model)
model = DistributedDataParallel(
model,
device_ids=[device.index] if torch.cuda.is_available() else None)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

@ -32,12 +32,14 @@ class Session:
def __init__(self,
training_func: Callable,
world_rank: int,
local_rank: int,
checkpoint: Optional[Dict] = None,
detailed_autofilled_metrics: bool = False):
# The Thread object that is running the training function.
self.training_thread = PropagatingThread(
target=training_func, daemon=True)
self.world_rank = world_rank
self.local_rank = local_rank
self.loaded_checkpoint = checkpoint
# This lock is used to control the execution of the training thread.
@ -263,6 +265,29 @@ def world_rank() -> int:
return session.world_rank
def local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. code-block:: python
import torch
from ray.util import sgd
def train_func():
if torch.cuda.is_available():
torch.cuda.set_device(sgd.local_rank())
...
trainer = Trainer(backend="torch", use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
"""
session = get_session()
return session.local_rank
def load_checkpoint() -> Optional[Dict]:
"""Loads checkpoint data onto the worker.

@ -54,9 +54,10 @@ def ray_2_node_4_gpu():
def gen_execute_special(special_f):
def execute_async_special(self, f):
"""Runs f on worker 0, special_f on other workers."""
futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)]
futures = [self.workers[0].actor._BaseWorkerMixin__execute.remote(f)]
for worker in self.workers[1:]:
futures.append(worker._BaseWorkerMixin__execute.remote(special_f))
futures.append(
worker.actor._BaseWorkerMixin__execute.remote(special_f))
return futures
return execute_async_special
@ -123,6 +124,18 @@ def test_train(ray_start_2_cpus, tmp_path):
assert e.finish_training() == [1, 1]
def test_local_ranks(ray_start_2_cpus, tmp_path):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
def train():
return sgd.local_rank()
e.start_training(train, run_dir=tmp_path)
assert set(e.finish_training()) == {0, 1}
def test_train_failure(ray_start_2_cpus, tmp_path):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)

@ -2,7 +2,8 @@ import time
import pytest
from ray.util.sgd.v2.session import init_session, shutdown_session, \
get_session, world_rank, report, save_checkpoint, TrainingResultType, \
get_session, world_rank, local_rank, report, save_checkpoint, \
TrainingResultType, \
load_checkpoint
@ -11,7 +12,7 @@ def session():
def f():
return 1
init_session(training_func=f, world_rank=0)
init_session(training_func=f, world_rank=0, local_rank=0)
yield get_session()
shutdown_session()
@ -34,6 +35,13 @@ def test_world_rank(session):
world_rank()
def test_local_rank(session):
assert local_rank() == 0
shutdown_session()
with pytest.raises(ValueError):
local_rank()
def test_train(session):
session.start()
output = session.finish()
@ -45,7 +53,7 @@ def test_report():
for i in range(2):
report(loss=i)
init_session(training_func=train, world_rank=0)
init_session(training_func=train, world_rank=0, local_rank=0)
session = get_session()
session.start()
assert session.get_next().data["loss"] == 0
@ -62,7 +70,7 @@ def test_report_fail():
report(i)
return 1
init_session(training_func=train, world_rank=0)
init_session(training_func=train, world_rank=0, local_rank=0)
session = get_session()
session.start()
assert session.get_next() is None
@ -96,7 +104,7 @@ def test_checkpoint():
assert next.type == TrainingResultType.CHECKPOINT
assert next.data["epoch"] == expected
init_session(training_func=train, world_rank=0)
init_session(training_func=train, world_rank=0, local_rank=0)
session = get_session()
session.start()
validate_zero(0)
@ -110,7 +118,7 @@ def test_checkpoint():
assert next.type == TrainingResultType.CHECKPOINT
assert next.data == {}
init_session(training_func=train, world_rank=1)
init_session(training_func=train, world_rank=1, local_rank=1)
session = get_session()
session.start()
validate_nonzero()
@ -129,7 +137,7 @@ def test_load_checkpoint_after_save():
checkpoint = load_checkpoint()
assert checkpoint["epoch"] == i
init_session(training_func=train, world_rank=0)
init_session(training_func=train, world_rank=0, local_rank=0)
session = get_session()
session.start()
for i in range(2):
@ -145,7 +153,7 @@ def test_locking():
import _thread
_thread.interrupt_main()
init_session(training_func=train_1, world_rank=0)
init_session(training_func=train_1, world_rank=0, local_rank=0)
session = get_session()
with pytest.raises(KeyboardInterrupt):
session.start()
@ -156,7 +164,7 @@ def test_locking():
report(loss=i)
train_1()
init_session(training_func=train_2, world_rank=0)
init_session(training_func=train_2, world_rank=0, local_rank=0)
session = get_session()
session.start()
time.sleep(3)

@ -88,7 +88,7 @@ def gen_execute_single_async_special(special_f):
assert len(self.workers) == 2
if i == 0 and hasattr(self, "should_fail") and self.should_fail:
kwargs["train_func"] = special_f
return self.workers[i]._BaseWorkerMixin__execute.remote(
return self.workers[i].actor._BaseWorkerMixin__execute.remote(
f, *args, **kwargs)
return execute_single_async_special
@ -126,7 +126,7 @@ class KillCallback(SGDCallback):
print(results)
assert all(r["loss"] == 1 for r in results)
if self.counter == self.fail_on:
ray.kill(self.worker_group.workers[0])
ray.kill(self.worker_group.workers[0].actor)
time.sleep(3)
self.counter += 1
@ -752,6 +752,27 @@ def test_worker_failure_2(ray_start_2_cpus):
assert results == [1, 1]
def test_worker_failure_local_rank(ray_start_2_cpus):
test_config = TestConfig()
def train():
return sgd.local_rank()
def train_actor_failure():
import sys
sys.exit(0)
return sgd.local_rank()
new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
new_backend_executor_cls):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
results = trainer.run(train)
assert set(results) == {0, 1}
def test_worker_start_failure(ray_start_2_cpus):
test_config = TestConfig()

@ -480,7 +480,7 @@ class SGDWorkerGroup:
self._worker_group = worker_group
def __getitem__(self, item) -> ActorHandle:
return self._worker_group.workers[item]
return self._worker_group.workers[item].actor
def shutdown(self, patience_s: float = 5):
"""Shutdown all the workers.

@ -87,21 +87,6 @@ class PropagatingThread(Thread):
return self.ret
def get_node_id() -> str:
"""Returns the ID of the node that this worker is on."""
return ray.get_runtime_context().node_id.hex()
def get_hostname() -> str:
"""Returns the hostname that this worker is on."""
return socket.gethostname()
def get_gpu_ids() -> List[int]:
"""Return list of CUDA device IDs available to this worker."""
return ray.get_gpu_ids()
def update_env_vars(env_vars: Dict[str, Any]):
"""Updates the environment variables on this worker process.

@ -1,7 +1,10 @@
import socket
from dataclasses import dataclass
import logging
from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple
import ray
from ray.actor import ActorHandle
from ray.types import ObjectRef
T = TypeVar("T")
@ -22,6 +25,32 @@ class BaseWorkerMixin:
return func(*args, **kwargs)
@dataclass
class WorkerMetadata:
"""Metadata for each worker/actor.
This information is expected to stay the same throughout the lifetime of
the actor.
Args:
node_id (str): ID of the node this worker is on.
node_ip (str): IP address of the node this worker is on.
hostname (str): Hostname that this worker is on.
gpu_ids (List[int]): List of CUDA IDs available to this worker.
"""
node_id: str
node_ip: str
hostname: str
gpu_ids: Optional[List[int]]
@dataclass
class Worker:
"""Class representing a Worker."""
actor: ActorHandle
metadata: WorkerMetadata
def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
"""Create the executable class to use as the Ray actors."""
if not executable_cls:
@ -37,6 +66,20 @@ def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
return _WrappedExecutable
def construct_metadata() -> WorkerMetadata:
"""Creates metadata for this worker.
This function is expected to be run on the actor.
"""
node_id = ray.get_runtime_context().node_id.hex()
node_ip = ray.util.get_node_ip_address()
hostname = socket.gethostname()
gpu_ids = ray.get_gpu_ids()
return WorkerMetadata(
node_id=node_id, node_ip=node_ip, hostname=hostname, gpu_ids=gpu_ids)
class WorkerGroup:
"""Group of Ray Actors that can execute arbitrary functions.
@ -118,8 +161,11 @@ class WorkerGroup:
self.start()
def _create_worker(self):
return self._remote_cls.remote(*self._actor_cls_args,
**self._actor_cls_kwargs)
actor = self._remote_cls.remote(*self._actor_cls_args,
**self._actor_cls_kwargs)
actor_metadata = ray.get(
actor._BaseWorkerMixin__execute.remote(construct_metadata))
return Worker(actor=actor, metadata=actor_metadata)
def start(self):
"""Starts all the workers in this worker group."""
@ -145,9 +191,11 @@ class WorkerGroup:
logger.debug(f"Shutting down {len(self.workers)} workers.")
if patience_s <= 0:
for worker in self.workers:
ray.kill(worker)
ray.kill(worker.actor)
else:
done_refs = [w.__ray_terminate__.remote() for w in self.workers]
done_refs = [
w.actor.__ray_terminate__.remote() for w in self.workers
]
# Wait for actors to die gracefully.
done, not_done = ray.wait(done_refs, timeout=patience_s)
if not_done:
@ -155,7 +203,7 @@ class WorkerGroup:
"force kill.")
# If all actors are not able to die gracefully, then kill them.
for worker in self.workers:
ray.kill(worker)
ray.kill(worker.actor)
logger.debug("Shutdown successful.")
self.workers = []
@ -180,7 +228,7 @@ class WorkerGroup:
"create a new WorkerGroup or restart this one.")
return [
w._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
w.actor._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
for w in self.workers
]
@ -214,7 +262,8 @@ class WorkerGroup:
if worker_index >= len(self.workers):
raise ValueError(f"The provided worker_index {worker_index} is "
f"not valid for {self.num_workers} workers.")
return self.workers[worker_index]._BaseWorkerMixin__execute.remote(
return self.workers[worker_index].actor._BaseWorkerMixin__execute\
.remote(
func, *args, **kwargs)
def execute_single(self, worker_index: int, func: Callable[..., T], *args,
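With workers now wrapped in ``Worker`` objects that carry ``WorkerMetadata``, callers read placement information directly off the handles instead of executing the removed ``get_node_id``/``get_hostname``/``get_gpu_ids`` helpers remotely. A minimal sketch, assuming ``worker_group`` is a started ``WorkerGroup``:

.. code-block:: python

    # Metadata is gathered once in _create_worker, so no extra remote
    # calls are needed to look up where each worker lives.
    node_ids = [w.metadata.node_id for w in worker_group.workers]
    hostnames = [w.metadata.hostname for w in worker_group.workers]
    gpu_ids = [w.metadata.gpu_ids for w in worker_group.workers]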