[horovod] remove deprecated slot concept, use worker instead (#22708)

Horovod updated the attributes of DistributedTrainableCreator and the args used to create the Horovod RayExecutor (horovod/horovod@a729ba7).

The main change is that Horovod deprecated the "slot" concept in favor of "worker", which is more consistent with the generic Ray worker. This is currently blocking Uber DL trainers from using Ray Tune.

This commit updates the Horovod RayExecutor init args accordingly.

Co-authored-by: Kai Fricke <kai@anyscale.com>
Author: kyle-chen-uber, 2022-03-10 00:16:42 -08:00 (committed by GitHub)
Commit: 592656ca28 (parent: 18d535f290)
6 changed files with 35 additions and 41 deletions
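For context, a minimal before/after sketch of the RayExecutor argument rename handled by this commit. The argument names come from the diffs below; the settings call and the concrete values are illustrative only, not code from this commit.

# Old, slot-based arguments (deprecated upstream) vs. the new worker-based ones.
from horovod.ray import RayExecutor

settings = RayExecutor.create_settings(timeout_s=30)

# Old form (deprecated):
# executor = RayExecutor(
#     settings,
#     num_hosts=1,        # hosts (nodes) per trial
#     num_slots=2,        # workers ("slots") per host
#     cpus_per_slot=1,
#     use_gpu=False,
# )

# New form used by this commit:
executor = RayExecutor(
    settings,
    num_workers=2,          # replaces num_hosts/num_slots
    cpus_per_worker=1,      # replaces cpus_per_slot
    use_gpu=False,
)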

File 1 of 6:

@@ -415,7 +415,7 @@ install_dependencies() {
   # This must be run last (i.e., torch cannot be re-installed after this)
   if [ "${INSTALL_HOROVOD-}" = 1 ]; then
     # TODO: eventually pin this to master.
-    HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git@06aa579c9966035453f92208706157dee14c14ab
+    HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git@a1f17d81f01543196b2c23240da692d9ae310942
   fi
   CC=gcc pip install psutil setproctitle==1.2.2 colorama --target="${WORKSPACE_DIR}/python/ray/thirdparty_files"

File 2 of 6:

@@ -61,7 +61,7 @@ def train(config):
     print(hvd.size())
     np.random.seed(1 + hvd.rank())
     torch.manual_seed(1234)
-    # To ensure consistent initialization across slots,
+    # To ensure consistent initialization across workers,
     hvd.broadcast_parameters(net.state_dict(), root_rank=0)
     hvd.broadcast_optimizer_state(optimizer, root_rank=0)

@@ -85,14 +85,11 @@ def train(config):
     print(f"Took {total:0.3f} s. Avg: {total / num_steps:0.3f} s.")


-def tune_horovod(
-    hosts_per_trial, slots_per_host, num_samples, use_gpu, mode="square", x_max=1.0
-):
+def tune_horovod(num_workers, num_samples, use_gpu, mode="square", x_max=1.0):
     horovod_trainable = DistributedTrainableCreator(
         train,
         use_gpu=use_gpu,
-        num_hosts=hosts_per_trial,
-        num_slots=slots_per_host,
+        num_workers=num_workers,
         replicate_pem=False,
     )
     analysis = tune.run(

@@ -121,8 +118,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--smoke-test", action="store_true", help=("Finish quickly for testing.")
     )
-    parser.add_argument("--hosts-per-trial", type=int, default=1)
-    parser.add_argument("--slots-per-host", type=int, default=2)
+    parser.add_argument("--num-workers", type=int, default=2)
     parser.add_argument(
         "--server-address",
         type=str,

@@ -141,8 +137,7 @@ if __name__ == "__main__":
         # ray.init(address="auto") # assumes ray is started with ray up

     tune_horovod(
-        hosts_per_trial=args.hosts_per_trial,
-        slots_per_host=args.slots_per_host,
+        num_workers=args.num_workers,
         num_samples=2 if args.smoke_test else 10,
         use_gpu=args.gpu,
         mode=args.mode,

File 3 of 6:

@@ -1,4 +1,4 @@
-from typing import Callable, Dict, Type
+from typing import Callable, Dict, Type, Optional
 from contextlib import contextmanager
 import os

@@ -79,12 +79,12 @@ class _HorovodTrainable(DistributedTrainable):
     # Callable function for training.
     _function = None
+    # Number of workers to allocate per trial.
+    _num_workers: Optional[int] = (None,)
     # Number of hosts (nodes) to allocate per trial
-    _num_hosts: int = 1
-    # Number of workers (slots) to place on each host.
-    _num_slots: int = 1
+    _num_hosts: Optional[int] = (None,)
     # Number of CPU resources to reserve for each worker.
-    _num_cpus_per_slot: int = 1
+    _num_cpus_per_worker: int = 1
     # Whether to reserve and pass GPU resources through.
     _use_gpu: bool = False
     # bool: Whether a the function has completed training

@@ -97,7 +97,7 @@ class _HorovodTrainable(DistributedTrainable):
     @property
     def num_workers(self):
-        return self._num_hosts * self._num_slots
+        return self._num_workers

     def setup(self, config: Dict):
         trainable = wrap_function(self.__class__._function)

@@ -115,10 +115,9 @@ class _HorovodTrainable(DistributedTrainable):
         self.executor = RayExecutor(
             settings,
-            cpus_per_slot=self._num_cpus_per_slot,
+            cpus_per_worker=self._num_cpus_per_worker,
             use_gpu=self._use_gpu,
-            num_hosts=self._num_hosts,
-            num_slots=self._num_slots,
+            num_workers=self._num_workers,
         )

         new_config = DistributedTrainable.build_config(self, config)

@@ -163,9 +162,9 @@
 def DistributedTrainableCreator(
     func: Callable,
     use_gpu: bool = False,
-    num_hosts: int = 1,
-    num_slots: int = 1,
-    num_cpus_per_slot: int = 1,
+    num_hosts: Optional[int] = None,
+    num_workers: int = 1,
+    num_cpus_per_worker: int = 1,
     timeout_s: int = 30,
     replicate_pem: bool = False,
 ) -> Type[_HorovodTrainable]:

@@ -180,8 +179,8 @@ def DistributedTrainableCreator(
     of a trial will be placed evenly across different machines.

     It is recommended that if `num_hosts` per trial > 1, you set
-    num_slots == the size (or number of GPUs) of a single host.
-    If num_hosts == 1, then you can set num_slots to be <=
+    num_workers == the size (or number of GPUs) of a single host.
+    If num_hosts == 1, then you can set num_workers to be <=
     the size (number of GPUs) of a single host.

     This above assumption can be relaxed - please file a feature request

@@ -201,11 +200,11 @@ def DistributedTrainableCreator(
             a config dict for hyperparameters and should initialize
             horovod via horovod.init.
         use_gpu (bool); Whether to allocate a GPU per worker.
-        num_cpus_per_slot (int): Number of CPUs to request
+        num_cpus_per_worker (int): Number of CPUs to request
             from Ray per worker.
         num_hosts (int): Number of hosts that each trial is expected
             to use.
-        num_slots (int): Number of slots (workers) to start on each host.
+        num_workers (int): Number of workers to start on each host.
         timeout_s (int): Seconds for Horovod rendezvous to timeout.
         replicate_pem (bool): THIS MAY BE INSECURE. If true, this will
             replicate the underlying Ray cluster ssh key across all hosts.

@@ -225,7 +224,7 @@ def DistributedTrainableCreator(
         from ray.tune.integration.horovod import DistributedTrainableCreator

         trainable_cls = DistributedTrainableCreator(
-            train, num_hosts=1, num_slots=2, use_gpu=True)
+            train, num_hosts=1, num_workers=2, use_gpu=True)

         tune.run(trainable_cls)

@@ -246,8 +245,8 @@ def DistributedTrainableCreator(
     class WrappedHorovodTrainable(_HorovodTrainable):
         _function = func
         _num_hosts = num_hosts
-        _num_slots = num_slots
-        _num_cpus_per_slot = num_cpus_per_slot
+        _num_workers = num_workers
+        _num_cpus_per_worker = num_cpus_per_worker
         _use_gpu = use_gpu
         _ssh_identity_file = ssh_identity_file
         _ssh_str = sshkeystr

@@ -257,8 +256,8 @@ def DistributedTrainableCreator(
         def default_resource_request(cls, config: Dict):
            return PlacementGroupFactory(
                [{}]
-                + [{"CPU": cls._num_cpus_per_slot, "GPU": int(use_gpu)}]
-                * (num_hosts * num_slots)
+                + [{"CPU": cls._num_cpus_per_worker, "GPU": int(use_gpu)}]
+                * (num_workers)
            )

     return WrappedHorovodTrainable
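For reference, a short sketch of the updated DistributedTrainableCreator signature end to end, mirroring the docstring example and the tests in this diff. train_fn, its body, and the reported metric are placeholders, not code from this commit.

import horovod.torch as hvd
from ray import tune
from ray.tune.integration.horovod import DistributedTrainableCreator


def train_fn(config):
    hvd.init()
    # ... build the model, broadcast parameters, run training steps ...
    for _ in range(2):
        tune.report(loss=0.0)


trainable_cls = DistributedTrainableCreator(
    train_fn,
    num_workers=2,          # replaces num_slots
    num_cpus_per_worker=1,  # replaces num_cpus_per_slot
    use_gpu=False,
)
tune.run(trainable_cls, num_samples=2, stop={"training_iteration": 2})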

File 4 of 6:

@@ -74,7 +74,7 @@ def test_horovod_simple(start_client_server_2_cpus):
     assert ray.util.client.ray.is_connected()
     from ray.tune.examples.horovod_simple import tune_horovod

-    tune_horovod(hosts_per_trial=1, slots_per_host=2, num_samples=2, use_gpu=False)
+    tune_horovod(num_workers=2, num_samples=2, use_gpu=False)


 def test_xgboost_example(start_client_server):

File 5 of 6:

@@ -43,14 +43,14 @@ def ray_connect_cluster():
 def test_single_step(ray_start_2_cpus):
-    trainable_cls = DistributedTrainableCreator(_train_simple, num_hosts=1, num_slots=2)
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
     trainer = trainable_cls()
     trainer.train()
     trainer.stop()


 def test_step_after_completion(ray_start_2_cpus):
-    trainable_cls = DistributedTrainableCreator(_train_simple, num_hosts=1, num_slots=2)
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
     trainer = trainable_cls(config={"epochs": 1})
     with pytest.raises(RuntimeError):
         for i in range(10):

@@ -61,13 +61,13 @@ def test_validation(ray_start_2_cpus):
     def bad_func(a, b, c):
         return 1

-    t_cls = DistributedTrainableCreator(bad_func, num_slots=2)
+    t_cls = DistributedTrainableCreator(bad_func, num_workers=2)
     with pytest.raises(ValueError):
         t_cls()


 def test_set_global(ray_start_2_cpus):
-    trainable_cls = DistributedTrainableCreator(_train_simple, num_slots=2)
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
     trainable = trainable_cls()
     result = trainable.train()
     trainable.stop()

@@ -76,7 +76,7 @@ def test_set_global(ray_start_2_cpus):
 @pytest.mark.parametrize("enabled_checkpoint", [True, False])
 def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
-    trainable_cls = DistributedTrainableCreator(_train_simple, num_slots=2)
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
     analysis = tune.run(
         trainable_cls,
         config={"enable_checkpoint": enabled_checkpoint},

@@ -92,7 +92,7 @@ def test_resource_tune(ray_connect_cluster, use_gpu):
     if use_gpu and ray.cluster_resources().get("GPU", 0) == 0:
         pytest.skip("No GPU available.")
     trainable_cls = DistributedTrainableCreator(
-        _train_simple, num_slots=2, use_gpu=use_gpu
+        _train_simple, num_workers=2, use_gpu=use_gpu
     )
     analysis = tune.run(trainable_cls, num_samples=2, stop={"training_iteration": 2})
     assert analysis.trials[0].last_result["training_iteration"] == 2

File 6 of 6:

@@ -47,7 +47,7 @@ def train(config, checkpoint_dir=None):
     optimizer = hvd.DistributedOptimizer(optimizer)
     np.random.seed(1 + hvd.rank())
     torch.manual_seed(1234)
-    # To ensure consistent initialization across slots,
+    # To ensure consistent initialization across workers,
     hvd.broadcast_parameters(net.state_dict(), root_rank=0)
     hvd.broadcast_optimizer_state(optimizer, root_rank=0)

@@ -107,7 +107,7 @@ if __name__ == "__main__":
         train,
         use_gpu=False if args.smoke_test else True,
         num_hosts=1 if args.smoke_test else 2,
-        num_slots=2 if args.smoke_test else 2,
+        num_workers=2 if args.smoke_test else 2,
         replicate_pem=False,
         timeout_s=300,
     )
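For completeness, a sketch of the multi-host configuration used by the example above: num_hosts is unchanged and only the slot-based names are swapped for their worker-based equivalents. train_fn stands in for the train(config, checkpoint_dir=None) function above, and the values shown are illustrative.

from ray.tune.integration.horovod import DistributedTrainableCreator


def train_fn(config, checkpoint_dir=None):
    ...  # Horovod training loop, as in the example above


trainable_cls = DistributedTrainableCreator(
    train_fn,
    use_gpu=True,
    num_hosts=2,       # hosts (nodes) per trial, unchanged
    num_workers=2,     # replaces num_slots (workers per host)
    replicate_pem=False,
    timeout_s=300,
)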