Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[SGD] v2 Horovod backend (#18047)
* [SGD] add Horovod backend
* address comments: set CUDA_VISIBLE_DEVICES, refactor code
* fix gpu test
* fix lint/test import
* address comments, add example cluster config
* delay horovod imports
Parent: 6133a561e9
Commit: a3123b6860

13 changed files with 639 additions and 19 deletions
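For orientation, a minimal sketch of how the new backend is driven, based on the Trainer usage in the tests added by this commit; the trivial training function and the explicit ray.init() call are illustrative only.

    import horovod.torch as hvd
    import ray
    from ray.util.sgd.v2 import Trainer

    def train_func():
        # The HorovodBackend added in this commit sets HOROVOD_RANK,
        # HOROVOD_SIZE, HOROVOD_HOSTNAME and the rendezvous variables on
        # each worker before this function runs, so hvd.init() can succeed.
        hvd.init()
        return hvd.rank()

    ray.init()
    trainer = Trainer("horovod", num_workers=2)  # "horovod" resolves to HorovodConfig
    trainer.start()
    print(trainer.run(train_func))  # one result per worker, e.g. [0, 1]
    trainer.shutdown()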
@@ -434,7 +434,7 @@
  conditions: ["RAY_CI_SGD_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
    - SGD_TESTING=1 ./ci/travis/install-dependencies.sh
    - SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh
    - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client python/ray/util/sgd/...
    - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client python/ray/util/sgd/...
    - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/...
@@ -42,6 +42,8 @@ MOCK_MODULES = [
    "gym.spaces",
    "horovod",
    "horovod.ray",
    "horovod.ray.runner",
    "horovod.ray.utils",
    "hyperopt",
    "hyperopt.hp",
    "kubernetes",
@@ -118,7 +118,8 @@ def train_fn(data_dir=None,
def main(num_workers, use_gpu, **kwargs):
    settings = RayExecutor.create_settings(timeout_s=30)
    executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers)
    executor.run(train_fn, kwargs=kwargs)
    executor.start()
    executor.run(train_fn, **kwargs)


if __name__ == "__main__":

@@ -133,7 +134,7 @@ if __name__ == "__main__":
        metavar="N",
        help="input batch size for training (default: 64)")
    parser.add_argument(
        "--epochs",
        "--num-epochs",
        type=int,
        default=5,
        metavar="N",

@@ -151,10 +152,10 @@ if __name__ == "__main__":
        metavar="M",
        help="SGD momentum (default: 0.5)")
    parser.add_argument(
        "--no-cuda",
        "--use-cuda",
        action="store_true",
        default=False,
        help="disables CUDA training")
        help="enables CUDA training")
    parser.add_argument(
        "--seed",
        type=int,

@@ -183,17 +184,10 @@ if __name__ == "__main__":
        "will be downloaded if needed)")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        default=None,
        help="Address of Ray cluster.")
    parser.add_argument(
        "--server-address",
        required=False,
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using "
        "Ray Client.")
        help="Address of Ray cluster.")

    args = parser.parse_args()

@@ -201,8 +195,6 @@ if __name__ == "__main__":

    if args.address:
        ray.init(args.address)
    elif args.server_address:
        ray.init(f"ray://{args.server_address}")
    else:
        ray.init()
@@ -1,4 +1,6 @@
import logging
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Callable, TypeVar, List, Optional, Dict, Union

@@ -12,7 +14,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
    DEFAULT_RESULTS_DIR
from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
from ray.util.sgd.v2.utils import construct_path
from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids
from ray.util.sgd.v2.worker_group import WorkerGroup

T = TypeVar("T")
@@ -105,8 +107,64 @@ class BackendExecutor:
            self._num_gpus_per_worker)
        if initialization_hook:
            self.worker_group.execute(initialization_hook)

        if self._num_gpus_per_worker > 0:
            self._setup_gpus()

        self._backend.on_start(self.worker_group, self._backend_config)

    def _setup_gpus(self):
        """Sets CUDA_VISIBLE_DEVICES on all workers.

        For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
        visible to all workers on that worker's node.

        This allows GPU workers on the same node to communicate with one
        another.

        Example:

            Setup:
            - Node1:
                - Worker1: {0, 1}
                - Worker2: {2, 3}
            - Node2:
                - Worker3: {0, 1}

            CUDA_VISIBLE_DEVICES:
            - Worker1: "0,1,2,3"
            - Worker2: "0,1,2,3"
            - Worker3: "0,1"

        """

        def get_node_id_and_gpu():
            node_id = get_node_id()
            gpu_ids = get_gpu_ids()
            return node_id, gpu_ids

        node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)

        node_id_to_worker_id = defaultdict(set)
        node_id_to_gpu_ids = defaultdict(set)

        for worker_id, (node_id, gpu_ids) in enumerate(node_ids_and_gpu_ids):
            node_id_to_worker_id[node_id].add(worker_id)
            node_id_to_gpu_ids[node_id].update(gpu_ids)

        futures = []
        for node_id, gpu_ids in node_id_to_gpu_ids.items():
            all_gpu_ids = ",".join([str(gpu_id) for gpu_id in gpu_ids])

            def set_gpu_ids():
                os.environ["CUDA_VISIBLE_DEVICES"] = all_gpu_ids

            for worker_id in node_id_to_worker_id[node_id]:
                futures.append(
                    self.worker_group.execute_single_async(
                        worker_id, set_gpu_ids))
        ray.get(futures)

    def start_training(
            self,
            train_func: Callable[[], T],
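For reference, the grouping done by _setup_gpus can be sketched standalone. The worker-to-node layout below is hypothetical (it mirrors the docstring example), and sorted() is added here only to make the printed output deterministic; the method itself joins the IDs in set order.

    from collections import defaultdict

    # Hypothetical (node_id, gpu_ids) pairs, one per worker: two workers on
    # node "A", one on node "B".
    node_ids_and_gpu_ids = [("A", [0, 1]), ("A", [2, 3]), ("B", [0, 1])]

    node_id_to_worker_id = defaultdict(set)
    node_id_to_gpu_ids = defaultdict(set)
    for worker_id, (node_id, gpu_ids) in enumerate(node_ids_and_gpu_ids):
        node_id_to_worker_id[node_id].add(worker_id)
        node_id_to_gpu_ids[node_id].update(gpu_ids)

    for node_id, gpu_ids in node_id_to_gpu_ids.items():
        visible = ",".join(str(g) for g in sorted(gpu_ids))
        for worker_id in node_id_to_worker_id[node_id]:
            print(f"worker {worker_id}: CUDA_VISIBLE_DEVICES={visible}")
    # worker 0: CUDA_VISIBLE_DEVICES=0,1,2,3
    # worker 1: CUDA_VISIBLE_DEVICES=0,1,2,3
    # worker 2: CUDA_VISIBLE_DEVICES=0,1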
python/ray/util/sgd/v2/backends/horovod.py (new file, 101 lines)

@@ -0,0 +1,101 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional, Set

import ray
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface
from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
from ray.util.sgd.v2.worker_group import WorkerGroup

try:
    from horovod.ray.runner import Coordinator
    from horovod.ray.utils import detect_nics, nics_to_env_var
except ImportError:
    Coordinator = None
    detect_nics = None
    nics_to_env_var = None

logger = logging.getLogger(__name__)


@dataclass
class HorovodConfig(BackendConfig):
    """Configurations for Horovod setup.

    See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501

    Args:
        nics (Optional[Set[str]]): Network interfaces that can be used for
            communication.
        verbose (int): Horovod logging verbosity.
    """
    nics: Optional[Set[str]] = None
    verbose: int = 1

    def __post_init__(self):
        if Coordinator is None:
            raise ValueError(
                "`horovod[ray]` is not installed. "
                "Please install 'horovod[ray]' to use this backend.")

    @property
    def backend_cls(self):
        return HorovodBackend


def init_env_vars(world_rank: int, world_size: int):
    """Initialize Horovod environment variables."""
    os.environ["HOROVOD_HOSTNAME"] = get_node_id()
    os.environ["HOROVOD_RANK"] = str(world_rank)
    os.environ["HOROVOD_SIZE"] = str(world_size)


class HorovodBackend(BackendInterface):
    def on_start(self, worker_group: WorkerGroup,
                 backend_config: HorovodConfig):

        # TODO(matt): Implement placement group strategies in BackendExecutor.

        # Initialize workers with Horovod environment variables
        setup_futures = []
        for rank in range(len(worker_group)):
            setup_futures.append(
                worker_group.execute_single_async(rank, init_env_vars, rank,
                                                  len(worker_group)))
        ray.get(setup_futures)

        # Use the Horovod Ray Coordinator, with backend_config as settings.
        self.coordinator = Coordinator(backend_config)

        # Get the node IDs and hostnames of all workers.
        node_ids = worker_group.execute(get_node_id)
        hostnames = worker_group.execute(get_hostname)
        # Register each hostname with the coordinator. Assumes the hostname
        # ordering is the same.
        for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
            self.coordinator.register(hostname, node_id, rank)
        all_info = self.coordinator.finalize_registration()

        setup_futures = []
        for rank, local_cross_env_var in all_info.items():
            setup_futures.append(
                worker_group.execute_single_async(rank, update_env_vars,
                                                  local_cross_env_var))
        ray.get(setup_futures)

        coordinator_envs = self.coordinator.establish_rendezvous()

        nics = detect_nics(
            backend_config,
            all_host_names=list(self.coordinator.hostnames),
            node_workers=worker_group.workers)
        coordinator_envs.update(nics_to_env_var(nics))

        worker_group.execute(update_env_vars, coordinator_envs)

    def on_shutdown(self, worker_group: WorkerGroup,
                    backend_config: HorovodConfig):
        # Currently no additional steps are needed
        pass
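A brief sketch of configuring this backend explicitly rather than through the "horovod" string. Passing a BackendConfig instance to Trainer is assumed here, and the NIC name is purely illustrative; HorovodConfig.__post_init__ above raises a ValueError if horovod[ray] is not installed.

    from ray.util.sgd.v2 import Trainer
    from ray.util.sgd.v2.backends.horovod import HorovodConfig

    # Restrict Horovod to one network interface and raise log verbosity.
    # "eth0" is only an example value.
    config = HorovodConfig(nics={"eth0"}, verbose=2)
    trainer = Trainer(config, num_workers=4, use_gpu=True)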
python/ray/util/sgd/v2/examples/horovod/__init__.py (new file, 0 lines)

python/ray/util/sgd/v2/examples/horovod/horovod-cluster.yaml (new file, 55 lines)
@@ -0,0 +1,55 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: horovod-cluster

# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers defaults to 0.
min_workers: 3
max_workers: 3

# Cloud-provider specific configuration.
provider:
    type: aws
    region: us-west-2

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: ubuntu

available_node_types:
    ray.head.default:
        min_workers: 0
        max_workers: 0
        resources: {}
        node_config:
            InstanceType: g3.8xlarge
            ImageId: latest_dlami
            InstanceMarketOptions:
                MarketType: spot
            BlockDeviceMappings:
                - DeviceName: /dev/sda1
                  Ebs:
                      VolumeSize: 300

    ray.worker.default:
        min_workers: 3
        max_workers: 3
        resources: {}
        node_config:
            InstanceType: g3.8xlarge
            ImageId: latest_dlami
            InstanceMarketOptions:
                MarketType: spot
            BlockDeviceMappings:
                - DeviceName: /dev/sda1
                  Ebs:
                      VolumeSize: 300

setup_commands:
    # This replaces the standard anaconda Ray installation
    - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl
    - pip install ray[tune]

    # Install Horovod
    - HOROVOD_WITH_GLOO=1 HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install torch torchvision horovod
python/ray/util/sgd/v2/examples/horovod/horovod_example.py (new file, 221 lines)

@@ -0,0 +1,221 @@
import argparse
import os

import horovod.torch as hvd
import ray
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from filelock import FileLock
from ray.util.sgd.v2 import Trainer
from torchvision import datasets, transforms


def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


def train_func(config):
    data_dir = config.get("data_dir", None)
    seed = config.get("seed", 42)
    use_cuda = config.get("use_cuda", False)
    batch_size = config.get("batch_size", 64)
    use_adasum = config.get("use_adasum", False)
    lr = config.get("lr", 0.01)
    momentum = config.get("momentum", 0.5)
    num_epochs = config.get("num_epochs", 10)
    log_interval = config.get("log_interval", 10)

    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(seed)

    if use_cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    data_dir = data_dir or "~/data"
    with FileLock(os.path.expanduser("~/.horovod_lock")):
        train_dataset = \
            datasets.MNIST(data_dir, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ]))
    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler, **kwargs)

    model = Net()

    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not use_adasum else 1

    if use_cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()

    # Horovod: scale learning rate by lr_scaler.
    optimizer = optim.SGD(
        model.parameters(), lr=lr * lr_scaler, momentum=momentum)

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        op=hvd.Adasum if use_adasum else hvd.Average)

    results = []
    for epoch in range(1, num_epochs + 1):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        train_sampler.set_epoch(epoch)
        num_batches = len(train_loader)
        for batch_idx, (data, target) in enumerate(train_loader):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                # Horovod: use train_sampler to determine the number of
                # examples in this worker's partition.
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, batch_idx * len(data), len(train_sampler),
                    100. * batch_idx / len(train_loader), loss.item()))
            if batch_idx == num_batches - 1:
                results.append(loss.item())
    return results


def main(num_workers, use_gpu, kwargs):
    trainer = Trainer("horovod", use_gpu=use_gpu, num_workers=num_workers)
    trainer.start()
    loss_per_epoch = trainer.run(train_func, config=kwargs)
    trainer.shutdown()
    print(loss_per_epoch)


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)")
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=5,
        metavar="N",
        help="number of epochs to train (default: 5)")
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        metavar="LR",
        help="learning rate (default: 0.01)")
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.5,
        metavar="M",
        help="SGD momentum (default: 0.5)")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="enables CUDA training")
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        metavar="S",
        help="random seed (default: 42)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status")
    parser.add_argument(
        "--use-adasum",
        action="store_true",
        default=False,
        help="use adasum algorithm to do reduction")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=2,
        help="Number of Ray workers to use for training.")
    parser.add_argument(
        "--data-dir",
        help="location of the training dataset in the local filesystem ("
        "will be downloaded if needed)")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        default=None,
        help="Address of Ray cluster.")

    args = parser.parse_args()

    if args.address:
        ray.init(args.address)
    else:
        ray.init()

    use_cuda = args.use_gpu if args.use_gpu is not None else False

    kwargs = {
        "data_dir": args.data_dir,
        "seed": args.seed,
        "use_cuda": use_cuda,
        "batch_size": args.batch_size,
        "use_adasum": args.use_adasum if args.use_adasum else False,
        "lr": args.lr,
        "momentum": args.momentum,
        "num_epochs": args.num_epochs,
        "log_interval": args.log_interval
    }

    main(num_workers=args.num_workers, use_gpu=use_cuda, kwargs=kwargs)
@@ -4,6 +4,7 @@ import pytest
from unittest.mock import patch

import ray
from ray.cluster_utils import Cluster
from ray.util.sgd import v2 as sgd
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor
from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig

@@ -22,6 +23,34 @@ def ray_start_2_cpus():
    ray.shutdown()


@pytest.fixture
def ray_2_node_2_gpu():
    cluster = Cluster()
    for _ in range(2):
        cluster.add_node(num_cpus=2, num_gpus=2)

    ray.init(address=cluster.address)

    yield

    ray.shutdown()
    cluster.shutdown()


@pytest.fixture
def ray_2_node_4_gpu():
    cluster = Cluster()
    for _ in range(2):
        cluster.add_node(num_cpus=2, num_gpus=4)

    ray.init(address=cluster.address)

    yield

    ray.shutdown()
    cluster.shutdown()


def gen_execute_special(special_f):
    def execute_async_special(self, f):
        """Runs f on worker 0, special_f on worker 1."""

@@ -249,6 +278,80 @@ def test_torch_start_shutdown(ray_start_2_cpus, init_method):
    assert not any(e.finish_training())


@pytest.mark.parametrize("worker_results", [(1, ["0"]), (2, ["0,1", "0,1"]),
                                            (3, ["0", "0,1", "0,1"]),
                                            (4, ["0,1", "0,1", "0,1", "0,1"])])
def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results):
    config = TestConfig()

    def get_resources():
        return os.environ["CUDA_VISIBLE_DEVICES"]

    num_workers, expected_results = worker_results

    e = BackendExecutor(
        config,
        num_workers=num_workers,
        num_cpus_per_worker=0,
        num_gpus_per_worker=1)
    e.start()
    e.start_training(get_resources)
    results = e.finish_training()
    results.sort()
    assert results == expected_results


@pytest.mark.parametrize(
    "worker_results",
    [(1, ["0"]), (2, ["0", "0"]), (3, ["0,1", "0,1", "0,1"]),
     (4, ["0,1", "0,1", "0,1", "0,1"]), (5, ["0", "0,1", "0,1", "0,1", "0,1"]),
     (6, ["0", "0", "0,1", "0,1", "0,1", "0,1"]),
     (7, ["0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1"]),
     (8, ["0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1"])])
def test_cuda_visible_devices_fractional(ray_2_node_2_gpu, worker_results):
    config = TestConfig()

    def get_resources():
        return os.environ["CUDA_VISIBLE_DEVICES"]

    num_workers, expected_results = worker_results

    e = BackendExecutor(
        config,
        num_workers=num_workers,
        num_cpus_per_worker=0,
        num_gpus_per_worker=0.5)
    e.start()
    e.start_training(get_resources)
    results = e.finish_training()
    results.sort()
    assert results == expected_results


@pytest.mark.parametrize("worker_results",
                         [(1, ["0,1"]), (2, ["0,1,2,3", "0,1,2,3"]),
                          (3, ["0,1", "0,1,2,3", "0,1,2,3"]),
                          (4, ["0,1,2,3", "0,1,2,3", "0,1,2,3", "0,1,2,3"])])
def test_cuda_visible_devices_multiple(ray_2_node_4_gpu, worker_results):
    config = TestConfig()

    def get_resources():
        return os.environ["CUDA_VISIBLE_DEVICES"]

    num_workers, expected_results = worker_results

    e = BackendExecutor(
        config,
        num_workers=num_workers,
        num_cpus_per_worker=0,
        num_gpus_per_worker=2)
    e.start()
    e.start_training(get_resources)
    results = e.finish_training()
    results.sort()
    assert results == expected_results


if __name__ == "__main__":
    import pytest
    import sys
@@ -2,6 +2,7 @@ import time
from pathlib import Path
from unittest.mock import patch

import horovod.torch as hvd_torch
import pytest
import ray
import ray.util.sgd.v2 as sgd

@@ -17,6 +18,9 @@ from ray.util.sgd.v2.examples.train_fashion_mnist import train_func as \
    fashion_mnist_train_func
from ray.util.sgd.v2.examples.train_linear import train_func as \
    linear_train_func

from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \
    horovod_torch_train_func
from ray.util.sgd.v2.worker_group import WorkerGroup


@@ -537,6 +541,61 @@ def test_torch_fashion_mnist_gpu(ray_start_2_cpus_2_gpus):
    assert result[-1] < result[0]


def test_horovod_simple(ray_start_2_cpus):
    def simple_fn():
        hvd_torch.init()
        return hvd_torch.rank()

    num_workers = 2
    trainer = Trainer("horovod", num_workers)
    trainer.start()
    result = trainer.run(simple_fn)
    trainer.shutdown()

    assert result == list(range(num_workers))


def test_horovod_torch_mnist(ray_start_2_cpus):
    num_workers = 2
    num_epochs = 2
    trainer = Trainer("horovod", num_workers)
    trainer.start()
    results = trainer.run(
        horovod_torch_train_func,
        config={
            "num_epochs": num_epochs,
            "lr": 1e-3
        })
    trainer.shutdown()

    assert len(results) == num_workers
    for worker_result in results:
        assert len(worker_result) == num_epochs
        assert worker_result[num_epochs - 1] < worker_result[0]


@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Only run if multiple GPUs are available.")
def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus):
    num_workers = 2
    num_epochs = 2
    trainer = Trainer("horovod", num_workers, use_gpu=True)
    trainer.start()
    results = trainer.run(
        horovod_torch_train_func,
        config={
            "num_epochs": num_epochs,
            "lr": 1e-3
        })
    trainer.shutdown()

    assert len(results) == num_workers
    for worker_result in results:
        assert len(worker_result) == num_epochs
        assert worker_result[num_epochs - 1] < worker_result[0]


def test_init_failure(ray_start_2_cpus):
    with pytest.raises(TypeError):
        Trainer(5)
@@ -6,6 +6,7 @@ from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \

from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \
    InactiveWorkerGroupError, SGDBackendError
from ray.util.sgd.v2.backends.horovod import HorovodConfig
from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
from ray.util.sgd.v2.backends.torch import TorchConfig
from ray.util.sgd.v2.callbacks.callback import SGDCallback

@@ -33,6 +34,7 @@ S = TypeVar("S")
logger = logging.getLogger(__name__)

BACKEND_NAME_TO_CONFIG_CLS = {
    "horovod": HorovodConfig,
    "tensorflow": TensorflowConfig,
    "torch": TorchConfig
}
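With this mapping, the "horovod" string used elsewhere in the diff resolves to its config dataclass. A minimal sketch of the lookup; the module path ray.util.sgd.v2.trainer is assumed (the file name is not shown in this diff), and horovod[ray] must be installed for HorovodConfig to construct.

    from ray.util.sgd.v2.backends.horovod import HorovodConfig
    from ray.util.sgd.v2.trainer import BACKEND_NAME_TO_CONFIG_CLS  # assumed path

    config = BACKEND_NAME_TO_CONFIG_CLS["horovod"]()
    assert isinstance(config, HorovodConfig)
    assert config.backend_cls.__name__ == "HorovodBackend"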
@@ -1,8 +1,9 @@
import os
from contextlib import closing
import socket
from pathlib import Path
from threading import Thread
from typing import Tuple
from typing import Tuple, Dict, List, Any

import ray

@@ -49,3 +50,28 @@ class PropagatingThread(Thread):
        if self.exc:
            raise self.exc
        return self.ret


def get_node_id() -> str:
    """Returns the ID of the node that this worker is on."""
    return ray.get_runtime_context().node_id.hex()


def get_hostname() -> str:
    """Returns the hostname that this worker is on."""
    return socket.gethostname()


def get_gpu_ids() -> List[int]:
    """Returns the list of CUDA device IDs available to this worker."""
    return ray.get_gpu_ids()


def update_env_vars(env_vars: Dict[str, Any]):
    """Updates the environment variables on this worker process.

    Args:
        env_vars (Dict): Environment variables to set.
    """
    sanitized = {k: str(v) for k, v in env_vars.items()}
    os.environ.update(sanitized)
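A short standalone sketch of update_env_vars, showing the stringification that lets the Horovod backend pass integer ranks straight through. The variable names match what init_env_vars sets in backends/horovod.py above; rank 0 of 2 is shown for illustration.

    import os

    def update_env_vars(env_vars):
        # Same behavior as the helper above: values are stringified
        # before being written into os.environ.
        sanitized = {k: str(v) for k, v in env_vars.items()}
        os.environ.update(sanitized)

    update_env_vars({"HOROVOD_RANK": 0, "HOROVOD_SIZE": 2})
    assert os.environ["HOROVOD_RANK"] == "0"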
@@ -14,8 +14,9 @@ class BaseWorker:

    def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Executes the input function and returns the output.

        Args:
            func(Callable): The function to execute.
            func (Callable): The function to execute.
            args, kwargs: The arguments to pass into func.
        """
        return func(*args, **kwargs)